// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file tl/op/elem.cc
 *
 * Define element-wise operators.
 */

#include "elem.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>

#include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"

namespace tvm {
namespace tl {

using namespace tir;

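/*!
 * \brief Construct a Copy op from TL call arguments.
 *
 * args[0] and args[1] are region calls describing the source and destination
 * accesses; an optional args[2] carries the coalesced-width hint.
 */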
Copy::Copy(Array<PrimExpr> args, BufferMap vmap) : args_(args) {
  Array<Range> rgs[2];
  Buffer bf[2];
  for (int i = 0; i < 2; i++) {
    auto expr = args[i];
    auto call = expr.as<CallNode>();
    ICHECK(call);
    auto region = RegionOp(call->args, vmap);
    rgs[i] = region.GetRanges();
    bf[i] = region.GetBuffer();
  }
  std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]);
  std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]);
  if (args.size() >= 3) {
    coalesced_width = Downcast<IntImm>(args[2]);
  }
}

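/*!
 * \brief Create one data-parallel iter var per non-unit dimension of the
 * copied region; unit-extent dimensions are skipped.
 */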
Array<IterVar> Copy::MakeIterVars() const {
  Array<IterVar> loop_vars;
  size_t idx = 0;
  for (size_t i = 0; i < src_range.size(); i++) {
    if (is_one(src_range[i]->extent))
      continue;
    Var var = Var(std::string{char('i' + idx)});
    idx++;
    loop_vars.push_back(
        {Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
  }
  return loop_vars;
}

// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
Array<PrimExpr> Copy::MakeIndices(const Array<IterVar> &ivs,
                                  int src_dst) const {
  Array<PrimExpr> indices;
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      indices.push_back(ranges[i]->min);
    else {
      indices.push_back(ranges[i]->min + ivs[idx]->var);
      idx++;
    }
  }
  ICHECK(idx == ivs.size());
  return indices;
}

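/*!
 * \brief Build the boundary predicate for the src (src_dst = 0) or dst
 * (src_dst = 1) access, skipping conditions the analyzer can already prove.
 */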
PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer,
                             const Array<IterVar> &ivs, Array<PrimExpr> extents,
                             int src_dst) const {
  Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
  Array<PrimExpr> cond_list;
  ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
  size_t idx = 0;
  for (size_t i = 0; i < ranges.size(); i++) {
    if (is_one(ranges[i]->extent))
      continue;
    PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    cond = ranges[i]->min + ivs[idx]->var >= 0;
    if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
      cond_list.push_back(cond);
    }
    idx++;
  }
  if (cond_list.empty())
    return {};
  else {
    PrimExpr cond = cond_list[0];
    for (size_t i = 1; i < cond_list.size(); i++)
      cond = And(cond, cond_list[i]);
    return cond;
  }
}

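/*!
 * \brief Build the generic element-wise copy as a nest of parallel loops,
 * guarding out-of-bound accesses with the predicates computed above.
 */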
For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  Array<IterVar> loop_vars = MakeIterVars();
  for (const auto &iv : loop_vars)
    analyzer->Bind(iv->var, iv->dom);

  Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
  Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);

  PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
  PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);

  PrimExpr value = BufferLoad(src, src_indices);
  if (src->dtype != dst->dtype)
    value = Cast(dst->dtype, value);
  if (src_predicate.defined())
    value = if_then_else(src_predicate, value, make_zero(dst->dtype));

  Stmt body = BufferStore(dst, value, dst_indices);
  if (dst_predicate.defined())
    body = IfThenElse(dst_predicate, body);

  for (int i = loop_vars.size() - 1; i >= 0; i--) {
    Map<String, ObjectRef> annotations = {};
    if (coalesced_width.defined()) {
      annotations.Set("coalesced_width", coalesced_width);
    }
    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body, NullOpt, annotations);
  }
  return Downcast<For>(body);
}

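/*!
 * \brief Lower the copy, trying the specialized ldmatrix/stmatrix and bulk
 * copy paths first and falling back to a vectorized SIMT loop.
 */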
Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  Target target = T.target;
  bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
  Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer);
  if (ldsm_stmt.defined())
    return ldsm_stmt;

  Stmt bulk_copy_stmt = LowerBulkCopy(T, analyzer);
  if (bulk_copy_stmt.defined())
    return bulk_copy_stmt;
  auto simt_loop = MakeSIMTLoop(analyzer);
  auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));

  For vectorized_thread_loop;
  auto par_op = std::make_unique<ParallelOp>(fused_loop);

  if (is_cpu_target) {
    vectorized_thread_loop = VectorizeLoop(fused_loop);
  } else {
    par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    vectorized_thread_loop = VectorizeLoop(thread_loop);
  }

  if (par_op->GetPredicate(T.thread_var).defined()) {
    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                      vectorized_thread_loop);
  }

  return vectorized_thread_loop;
}

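/*!
 * \brief Try to lower the copy to ldmatrix/stmatrix instructions.
 *
 * Returns an undefined Stmt when the access pattern does not match the
 * required scopes, 8x8 fragment layout, or alignment, so Lower() can fall
 * back to another path.
 */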
Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // Check buffer scope
  bool is_ldmatrix;
  if (TargetHasLdmatrix(T.target) && src.scope() == "shared.dyn" &&
      dst.scope() == "local.fragment") {
    is_ldmatrix = true;
  } else if (TargetHasStmatrix(T.target) && dst.scope() == "shared.dyn" &&
             src.scope() == "local.fragment") {
    is_ldmatrix = false;
  } else {
    return Stmt();
  }

  // Check no predicates
  Array<IterVar> loop_vars = MakeIterVars();
  if (loop_vars.size() < 2)
    return Stmt();
  for (const auto &iv : loop_vars)
    analyzer->Bind(iv->var, iv->dom);
  PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
  PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
  if (src_predicate.defined() || dst_predicate.defined())
    return Stmt();

  Buffer shared_tensor = is_ldmatrix ? src : dst;
  Buffer local_tensor = is_ldmatrix ? dst : src;

  Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0);
  Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
  Array<PrimExpr> local_indices_transformed =
      local_layout->Forward(local_indices);
  local_tensor = T.buffer_remap[local_tensor];
  // Currently only the 1-D output case is supported
  if (local_layout->OutputDim() != 1)
    return Stmt();

  Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1);
  Array<PrimExpr> shared_indices_transformed = shared_indices;
  Layout shared_layout;
  if (T.buffer_remap.count(shared_tensor)) {
    shared_layout = T.layout_map[shared_tensor];
    shared_tensor = T.buffer_remap[shared_tensor];
    shared_indices_transformed = shared_layout->Forward(shared_indices);
  }

  // Check that local_layout follows the 8x8 GEMM fragment layout
  bool is_transposed;
  IterVar col_var = loop_vars[loop_vars.size() - 1];
  IterVar row_var = loop_vars[loop_vars.size() - 2];
  PrimExpr local_layout_thread_map =
      FloorMod(local_layout->ForwardThread(local_indices, NullOpt), 32);
  PrimExpr matrix_8x8_thread_map = makeGemmFragment8x8()->ForwardThread(
      {FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
  PrimExpr matrix_8x8_thread_map_trans =
      makeGemmFragment8x8Transposed()->ForwardThread(
          {FloorMod(row_var, 8), FloorMod(col_var, 8)}, NullOpt);
  PrimExpr local_indices_flattened =
      local_tensor.OffsetOf(local_indices_transformed).back();
  if (analyzer->CanProveEqual(matrix_8x8_thread_map, local_layout_thread_map) &&
      IndiceCanVectorize(local_indices_flattened, col_var->var,
                         col_var->dom->extent, 2, analyzer)) {
    is_transposed = false;
  } else if (analyzer->CanProveEqual(matrix_8x8_thread_map_trans,
                                     local_layout_thread_map) &&
             IndiceCanVectorize(local_indices_flattened, row_var->var,
                                row_var->dom->extent, 2, analyzer)) {
    is_transposed = true;
  } else {
    return Stmt();
  }
  // Check that the shared layout is 16-byte contiguous (8 elements x 2 bytes)
  if (shared_tensor->dtype.bytes() != 2)
    return Stmt();
  PrimExpr flattened_indice =
      shared_tensor.OffsetOf(shared_indices_transformed).back();
  if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var,
                          loop_vars.back()->dom->extent, 8, analyzer))
    return Stmt();

  // Only a full destination range is supported
  for (size_t i = 0; i < dst_range.size(); i++) {
    if (!is_zero(dst_range[i]->min) ||
        !analyzer->CanProveEqual(dst_range[i]->extent, dst->shape[i]))
      return Stmt();
  }

  // Do the lowering: try ldmatrix/stmatrix vectorized by 4, 2, or 1 matrices
  PrimExpr extent = local_tensor->shape[0];
  int num = 1;
  if (analyzer->CanProveEqual(FloorMod(extent, 8), 0))
    num = 4;
  else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0))
    num = 2;

  Array<PrimExpr> args;
  const Op &op = is_ldmatrix ? tl::LDMatrixOp() : tl::STMatrixOp();
  args.push_back(static_cast<int>(is_transposed));
  args.push_back(num);

  // Compute the shared-memory coordinates from the local address.
  // If not transposed:
  //   coords = Inverse(base + 2 * ((thread / 8) % num), warp + (thread % 8) * 4)
  // If transposed:
  //   coords = Inverse(base + 2 * ((thread / 8) % num) + thread % 2,
  //                    warp + (thread % 8) / 2)
  Var local_iter("i");
  Layout inv = local_layout->Inverse();
  Array<PrimExpr> shared_coords;
  PrimExpr warp = FloorDiv(T.thread_var, 32) * 32;
  if (!is_transposed)
    shared_coords = inv->Forward(
        {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
         warp + FloorMod(T.thread_var, 8) * 4});
  else
    shared_coords = inv->Forward(
        {local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
             FloorMod(T.thread_var, 2),
         warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
  shared_coords.pop_back(); // drop the replication dimension
  if (shared_layout.defined())
    shared_coords = shared_layout->Forward(shared_coords);
  PrimExpr shared_addr = shared_tensor.access_ptr(
      is_ldmatrix ? 1 : 2, DataType::Handle(), 1,
      shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
  args.push_back(shared_addr);

  if (is_ldmatrix) {
    // ldmatrix requires matching src/dst dtypes
    if (local_tensor->dtype != shared_tensor->dtype)
      return Stmt();
    PrimExpr local_addr = local_tensor.access_ptr(
        2, DataType::Handle(), 1, local_iter * 2 * num, PrimExpr(2 * num));
    args.push_back(local_addr);
  } else {
    for (int i = 0; i < num; i++) {
      PrimExpr value0 =
          BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i});
      PrimExpr value1 =
          BufferLoad(local_tensor, {local_iter * 2 * num + 2 * i + 1});
      if (local_tensor->dtype != shared_tensor->dtype) {
        value0 = Cast(shared_tensor->dtype, value0);
        value1 = Cast(shared_tensor->dtype, value1);
      }
      PrimExpr value_packed =
          Call(DataType::Int(32), PackB16Op(), {value0, value1});
      args.push_back(value_packed);
    }
  }

  auto body = Evaluate(Call(DataType::Handle(), op, args));
  For for_node =
      For(local_iter, 0, FloorDiv(extent, 2 * num), ForKind::kSerial, body);
  for_node = LoopPragmaUnroll(for_node);
  return for_node;
}

LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  // Use parallel op to infer the layout
  if (par_op_ == nullptr) {
    arith::Analyzer analyzer;
    par_op_ = std::make_unique<ParallelOp>(MakeSIMTLoop(&analyzer));
  }
  return par_op_->InferLayout(T, level);
}

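/*!
 * \brief Construct a Fill op; args[0] is the destination access pointer and
 * args[1] the fill value, cast to the destination dtype when they differ.
 */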
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
  dst = vmap[GetVarFromAccessPtr(args[0])];
  if (args[1]->dtype != dst->dtype) {
    value = Cast(dst->dtype, args[1]);
  } else {
    value = args[1];
  }
}

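/*!
 * \brief Build a parallel loop nest that stores the fill value into every
 * element of the destination buffer.
 */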
For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
  int ndim = dst->shape.size();
  Array<IterVar> loop_vars;
  Array<PrimExpr> dst_indices;
  for (int i = 0; i < ndim; i++) {
    Var var = Var(std::string{char('i' + i)});
    loop_vars.push_back({Range(0, dst->shape[i]), var, IterVarType::kDataPar});
    dst_indices.push_back(var);
  }
  Stmt body = BufferStore(dst, value, dst_indices);
  for (int i = ndim - 1; i >= 0; i--) {
    body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent,
               ForKind::kParallel, body);
  }
  return Downcast<For>(body);
}

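/*!
 * \brief Lower the fill: partition across threads and vectorize for
 * "local.fragment" destinations, or simply vectorize for "local" ones.
 */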
Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

  if (dst.scope() == "local.fragment") {
    auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
    par_op->InferLayout({T.target, T.block_size, T.layout_map},
                        InferLevel::kFree);
    par_op->InferLayout({T.target, T.block_size, T.layout_map},
                        InferLevel::kFree);
    auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
                                     par_op->GetLoopLayout());
    auto vectorized_thread_loop = VectorizeLoop(thread_loop);
    if (par_op->GetPredicate(T.thread_var).defined()) {
      return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
                        vectorized_thread_loop);
    }
    return vectorized_thread_loop;
  } else if (dst.scope() == "local") {
    auto init_loop = MakeSIMTLoop(analyzer);
    auto vectorized_thread_loop = VectorizeLoop(init_loop);
    return vectorized_thread_loop;
  } else {
    LOG(FATAL) << "Unsupported scope " << dst.scope();
  }
}

TIR_REGISTER_TL_OP(Copy, copy)
    .set_num_inputs(3)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));
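
// For reference: in TIR a copy appears as an opaque tl.copy call whose first
// two arguments are region descriptors parsed by the Copy constructor above,
// with an optional coalesced-width IntImm as the third argument. This is an
// illustrative sketch; the exact call is produced by the frontend.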

TIR_REGISTER_TL_OP(Fill, fill)
    .set_num_inputs(2)
    .set_attr<TCallEffectKind>("TCallEffectKind",
                               Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm