"docs/vscode:/vscode.git/clone" did not exist on "bbf54a8835811f96bd1e4dc4c2669f94be0bf264"
layout.cc 16.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/*!
 * \file layout/layout.cc
 *
 */

#include "layout.h"

#include <tvm/arith/pattern.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include "arith/pattern_match.h"
#include "utils.h"

namespace tvm {
namespace tl {

using namespace tir;

20
static Var getPlaceholder(const std::string &s) {
21
22
23
24
25
26
27
28
  static std::unordered_map<std::string, Var> map;
  if (map.find(s) == map.end()) {
    map[s] = Var(s);
  }
  return map[s];
}

/*! \brief The shared placeholder Var ("_rep") for the replication index. */
Var ReplicationPlaceholder() { return getPlaceholder("_rep"); }
/*!
 * \brief The shared placeholder Var for input dimension \p idx.
 *
 * Names are "_" followed by the letter 'i' offset by idx: "_i", "_j", "_k", ...
 * NOTE(review): idx >= 18 yields characters past 'z' ('{', '|', ...); the
 * names remain distinct but are no longer letters.
 */
Var InputPlaceholder(size_t idx) {
  return getPlaceholder(std::string{'_', char('i' + idx)});
}

/*!
 * \brief Domain of each input placeholder.
 *
 * InputPlaceholder(d) is mapped to the half-open range [0, input_size_[d]).
 */
Map<Var, Range> LayoutNode::getVarMap() const {
  Map<Var, Range> var_domains;
  const size_t num_inputs = InputDim();
  for (size_t dim = 0; dim != num_inputs; ++dim) {
    var_domains.Set(InputPlaceholder(dim), Range(0, input_size_[dim]));
  }
  return var_domains;
}

/*!
 * \brief Domains of the input placeholders plus the replication placeholder.
 *
 * ReplicationPlaceholder() is mapped to [0, ReplicateExtent()) on top of the
 * input-dimension ranges produced by LayoutNode::getVarMap().
 */
Map<Var, Range> FragmentNode::getVarMap() const {
  Map<Var, Range> var_domains = LayoutNode::getVarMap();
  var_domains.Set(ReplicationPlaceholder(), Range(0, ReplicateExtent()));
  return var_domains;
}

47
48
LayoutNode::LayoutNode(Array<PrimExpr> input_size,
                       Array<PrimExpr> forward_index) {
49
50
51
  input_size_ = input_size;
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
52
53
  forward_index_ = forward_index.Map(
      [&](const PrimExpr &e) { return analyzer.Simplify(e); });
54
55
56
57
58
59
60
61
62
63
}

/*!
 * \brief Build a Layout from explicit iteration variables.
 *
 * Each forward_var[i]->var is renamed to InputPlaceholder(i) inside
 * forward_index, and forward_var[i]->dom->extent becomes the extent of input
 * dimension i. Every domain must start at 0.
 */
Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
  Map<Var, PrimExpr> vmap;
  Array<PrimExpr> input_size;
  for (size_t i = 0; i < forward_var.size(); i++) {
    vmap.Set(forward_var[i]->var, InputPlaceholder(i));
    CHECK(is_zero(forward_var[i]->dom->min))
        << "Layout iteration variable " << forward_var[i]->var
        << " must start at 0, but its domain is " << forward_var[i]->dom;
    input_size.push_back(forward_var[i]->dom->extent);
  }
  forward_index =
      forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });

  auto n = make_object<LayoutNode>(input_size, forward_index);
  data_ = std::move(n);
}

/*!
 * \brief Build a Layout whose forward_index is already expressed in terms of
 *        the input placeholders.
 */
Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
  data_ = make_object<LayoutNode>(input_size, forward_index);
}

76
void LayoutNode::VisitAttrs(AttrVisitor *v) {
77
78
79
80
  v->Visit("input_size", &input_size_);
  v->Visit("forward_index", &forward_index_);
}

81
82
void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
  for (const auto &[var, dom] : getVarMap()) {
83
84
85
86
    analyzer->Bind(var, dom);
  }
}

87
88
89
90
91
92
93
94
Array<PrimExpr> LayoutNode::GetForwardVars() const {
  Array<PrimExpr> vars;
  for (size_t i = 0; i < InputDim(); i++) {
    vars.push_back(InputPlaceholder(i));
  }
  return vars;
}

95
96
97
98
99
100
101
102
103
104
Array<PrimExpr> LayoutNode::OutputShape() const {
  Array<PrimExpr> ret(OutputDim(), 1);
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  for (size_t i = 0; i < ret.size(); i++) {
    auto ist = analyzer.int_set(forward_index_[i] + 1);
    if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) {
      // X-OR Expression
      ret.Set(i, input_size_[i]);
    } else {
105
      // CHECK(is_one(ist.min())) << ist.min();
106
107
108
109
110
111
      ret.Set(i, ist.max());
    }
  }
  return ret;
}

112
113
114
Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
  if (vars.empty())
    return forward_index_;
115
116
117
118
119
  ICHECK_EQ(vars.size(), InputDim());
  Map<Var, PrimExpr> vmap;
  for (size_t i = 0; i < InputDim(); i++) {
    vmap.Set(InputPlaceholder(i), vars[i]);
  }
120
121
  return forward_index_.Map(
      [&](const PrimExpr &e) { return Substitute(e, vmap); });
122
123
}

124
125
Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,
                              bool repeat_on_thread,
126
127
128
129
130
131
                              bool lower_dim_first) const {
  ICHECK_EQ(repeats.size(), InputDim());
  Array<PrimExpr> new_input_size;
  Map<Var, PrimExpr> vmap;
  for (size_t i = 0; i < InputDim(); i++) {
    new_input_size.push_back(input_size_[i] * repeats[i]);
132
133
    vmap.Set(InputPlaceholder(i),
             FloorMod(InputPlaceholder(i), InputShape()[i]));
134
135
136
137
138
  }

  PrimExpr repeats_index = 0, repeat_stride = 1;
  if (lower_dim_first) {
    for (int i = InputDim() - 1; i >= 0; i--) {
139
140
      repeats_index +=
          repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
141
142
143
144
      repeat_stride *= repeats[i];
    }
  } else {
    for (size_t i = 0; i < InputDim(); i++) {
145
146
      repeats_index +=
          repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
147
148
149
150
151
152
      repeat_stride *= repeats[i];
    }
  }

  if (repeat_on_thread) {
    PrimExpr thread_size = ThreadExtent();
153
154
155
156
157
158
    auto new_forward_index = forward_index_.Map(
        [&](const PrimExpr &e) { return Substitute(e, vmap); });
    auto new_forward_thread =
        Substitute(forward_thread_, vmap) + thread_size * repeats_index;
    return Fragment(new_input_size, new_forward_index, new_forward_thread,
                    replicate_size_, NullOpt);
159
160
161
162
163
164
  } else {
    ICHECK(OutputDim() == 1);
    PrimExpr frag_len = OutputShape()[0];
    Array<PrimExpr> new_forward_index = {Substitute(forward_index_[0], vmap) +
                                         frag_len * repeats_index};
    PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
165
166
    return Fragment(new_input_size, new_forward_index, new_forward_thread,
                    replicate_size_, NullOpt);
167
168
169
170
171
172
  }
}

/*!
 * \brief Multiply the replication extent by \p repeats.
 *
 * The new replication placeholder r ranges over
 * [0, ReplicateExtent() * repeats). In the thread mapping the old replication
 * index becomes r % ReplicateExtent(), and r / ReplicateExtent() adds an
 * offset of that many ThreadExtent() blocks of threads.
 *
 * \param repeats Replication multiplier; must be >= 1.
 */
Fragment FragmentNode::Replicate(int repeats) const {
  ICHECK(repeats >= 1) << "Replicate expects repeats >= 1, but got "
                       << repeats;
  Map<Var, PrimExpr> vmap;
  vmap.Set(ReplicationPlaceholder(),
           FloorMod(ReplicationPlaceholder(), ReplicateExtent()));
  PrimExpr new_forward_thread =
      Substitute(forward_thread_, vmap) +
      ThreadExtent() * FloorDiv(ReplicationPlaceholder(), ReplicateExtent());
  return Fragment(input_size_, forward_index_, new_forward_thread,
                  ReplicateExtent() * repeats, NullOpt);
}

/*!
 * \brief Reduce replication by folding part of it into the output index.
 *
 * Let factor = gcd(ReplicateExtent(), OutputShape()[0]), computed only when
 * both extents are integer constants (otherwise factor = 1). With factor > 1:
 *  - the new output index is index / factor;
 *  - the new replication extent is ReplicateExtent() / factor;
 *  - in the thread mapping the old replication index is rewritten as
 *    r * factor + index % factor, where r is the new replication index.
 * With factor == 1 this fragment is returned unchanged.
 * Requires a single output dimension.
 */
Fragment FragmentNode::DeReplicate() const {
  ICHECK(OutputDim() == 1);
  // Keep the extents alive in locals: as_const_int returns a pointer into the
  // expression's node, and OutputShape()[0] would otherwise be a temporary
  // destroyed at the end of the statement, leaving idx_size dangling.
  const PrimExpr rep_extent = ReplicateExtent();
  const PrimExpr idx_extent = OutputShape()[0];
  const int64_t *rep_size = as_const_int(rep_extent);
  const int64_t *idx_size = as_const_int(idx_extent);
  int factor = 1;
  if (rep_size && idx_size) {
    factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
  }
  if (factor == 1)
    return GetRef<Fragment>(this);

  Map<Var, PrimExpr> vmap;
  vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
                                         FloorMod(forward_index_[0], factor));
  PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
  Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)};
  // rep_size is non-null here: factor != 1 implies both extents were constant.
  return Fragment(input_size_, new_forward_index, new_forward_thread,
                  int(*rep_size) / factor, NullOpt);
}

204
205
206
207
208
Fragment FragmentNode::SetThreadRange(Range thread_range) {
  thread_range_ = thread_range;
  return GetRef<Fragment>(this);
}

209
210
Layout LayoutNode::Inverse() const {
  arith::Analyzer analyzer;
211
212
213
  arith::IterMapResult res =
      arith::DetectIterMap(forward_index_, getVarMap(), 1,
                           arith::IterMapLevel::Bijective, &analyzer);
214
215
  ICHECK(res->errors.empty())
      << "Layout " << DebugOutput() << " has errors: " << res->errors;
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236

  auto outputs_shape = OutputShape();
  Array<PrimExpr> outputs;
  for (size_t i = 0; i < OutputDim(); i++) {
    outputs.push_back(InputPlaceholder(i));
  }

  auto inv = arith::InverseAffineIterMap(res->indices, outputs);

  Array<PrimExpr> backward_index;
  for (size_t i = 0; i < InputDim(); i++) {
    if (inv.find(InputPlaceholder(i)) != inv.end()) {
      backward_index.push_back(inv[InputPlaceholder(i)]);
    } else {
      backward_index.push_back(0);
    }
  }

  return Layout(outputs_shape, backward_index);
}

237
238
239
240
241
PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
                              const PrimExpr &forward_thread,
                              arith::Analyzer *analyzer) {
  Array<arith::IterSplitExpr> splits = DivideUnusedIterators(
      {forward_thread}, ToIterVars(input_iters), analyzer);
242
243

  Array<arith::IterSplitExpr> split_without_rep;
244
  for (const auto &split : splits) {
245
    CHECK(split->source->source.as<Var>());
246
247
248
    if (split->source->source.as<Var>().value().same_as(
            ReplicationPlaceholder()))
      continue;
249
250
251
252
253
    split_without_rep.push_back(split);
  }
  return MakeFlattenedExpression(split_without_rep);
}

254
255
FragmentNode::FragmentNode(Array<PrimExpr> input_size,
                           Array<PrimExpr> forward_index,
256
257
258
259
260
261
262
                           PrimExpr forward_thread, PrimExpr replicate_size) {
  input_size_ = input_size;
  replicate_size_ = replicate_size;
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  forward_thread_ = analyzer.Simplify(forward_thread);
  if (forward_index.empty()) {
263
264
    forward_index = {
        infer_fragment_index(getVarMap(), forward_thread_, &analyzer)};
265
  }
266
267
  forward_index_ = forward_index.Map(
      [&](const PrimExpr &e) { return analyzer.Simplify(e); });
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
}

/*!
 * \brief Build a Fragment from explicit iteration variables.
 *
 * forward_var[i]->var is renamed to InputPlaceholder(i) and its extent becomes
 * input dimension i. If thread_replicate is defined, its var is renamed to
 * ReplicationPlaceholder() and its extent becomes the replication extent
 * (1 otherwise). All domains must start at 0.
 */
Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
                   PrimExpr forward_thread, IterVar thread_replicate) {
  Map<Var, PrimExpr> vmap;
  Array<PrimExpr> input_size;
  PrimExpr replicate_size = 1;
  for (size_t i = 0; i < forward_var.size(); i++) {
    vmap.Set(forward_var[i]->var, InputPlaceholder(i));
    CHECK(is_zero(forward_var[i]->dom->min))
        << "Fragment iteration variable " << forward_var[i]->var
        << " must start at 0, but its domain is " << forward_var[i]->dom;
    input_size.push_back(forward_var[i]->dom->extent);
  }
  if (thread_replicate.defined()) {
    ICHECK(is_zero(thread_replicate->dom->min))
        << "Replication variable " << thread_replicate->var
        << " must start at 0, but its domain is " << thread_replicate->dom;
    replicate_size = thread_replicate->dom->extent;
    vmap.Set(thread_replicate->var, ReplicationPlaceholder());
  }
  forward_index =
      forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
  forward_thread = Substitute(forward_thread, vmap);

  auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
                                     replicate_size);
  data_ = std::move(n);
}

/*!
 * \brief Build a Fragment from placeholder-based expressions.
 *
 * If \p replicate_var is defined, it is renamed to ReplicationPlaceholder()
 * inside \p forward_thread before the node is created.
 */
Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
                   PrimExpr forward_thread, PrimExpr replicate_size,
                   Optional<Var> replicate_var) {
  if (replicate_var.defined()) {
    forward_thread = Substitute(
        forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
  }
  auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
                                     replicate_size);
  data_ = std::move(n);
}

306
void FragmentNode::VisitAttrs(tvm::AttrVisitor *v) {
307
308
309
310
311
312
313
314
315
316
317
318
319
320
  LayoutNode::VisitAttrs(v);
  v->Visit("forward_thread", &forward_thread_);
  v->Visit("replicate_size", &replicate_size_);
}

/*!
 * \brief Number of threads covered by the thread mapping.
 *
 * Computed as the analyzer's upper bound of forward_thread_ + 1 over the
 * placeholder ranges. The mapping's lower bound must be thread 0.
 */
PrimExpr FragmentNode::ThreadExtent() const {
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  auto ist = analyzer.int_set(forward_thread_ + 1);
  CHECK(is_one(ist.min())) << "Thread mapping " << forward_thread_
                           << " must start at thread 0, but the lower bound "
                              "of mapping + 1 is "
                           << ist.min();
  return ist.max();
}

321
322
323
324
325
326
327
328
329
330
331
Array<PrimExpr> FragmentNode::GetForwardVars() const {
  Array<PrimExpr> vars;
  if (*as_const_int(ReplicateExtent()) > 1) {
    vars.push_back(ReplicationPlaceholder());
  }
  for (size_t i = 0; i < InputDim(); i++) {
    vars.push_back(InputPlaceholder(i));
  }
  return vars;
}

332
333
PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr> &vars,
                                     const Optional<PrimExpr> &rep_var) const {
334
335
336
337
338
  Map<Var, PrimExpr> vmap;
  ICHECK_EQ(vars.size(), InputDim());
  for (size_t i = 0; i < InputDim(); i++) {
    vmap.Set(InputPlaceholder(i), vars[i]);
  }
339
340
  if (rep_var.defined())
    vmap.Set(ReplicationPlaceholder(), rep_var.value());
341
342
343
344
345
346
347
348
349

  return Substitute(forward_thread_, vmap);
}

/*!
 * \brief Inverse mapping from (output indices..., thread) back to
 *        (input indices..., replication index).
 *
 * Builds a plain Layout whose inputs are this fragment's inputs plus the
 * replication index (extent ReplicateExtent()). Its outputs are this
 * fragment's index expressions plus the thread mapping, with the replication
 * placeholder renamed to the extra input placeholder. That Layout is then
 * inverted.
 */
Layout FragmentNode::Inverse() const {
  auto input_size_copy = input_size_;
  input_size_copy.push_back(ReplicateExtent());
  auto forward_index_copy = forward_index_;
  forward_index_copy.push_back(
      Substitute(forward_thread_,
                 {{ReplicationPlaceholder(), InputPlaceholder(InputDim())}}));
  auto fwd = Layout(input_size_copy, forward_index_copy);
  auto bwd = fwd->Inverse();
  return bwd;
}

/*!
 * \brief Rebuild the replication dimension via CompressIterator.
 *
 * CompressIterator returns a new thread mapping together with a new
 * replication IterVar. The result uses that IterVar's extent as the
 * replication extent, and renames its var to ReplicationPlaceholder().
 * NOTE(review): presumably this drops replication values that do not change
 * the thread index — confirm against CompressIterator.
 */
Fragment FragmentNode::CondenseReplicateVar() const {
  arith::Analyzer analyzer;
  auto input_iters = getVarMap();
  input_iters.Set(ReplicationPlaceholder(), {0, ReplicateExtent()});
  PrimExpr new_forward_thread;
  IterVar new_thread_replicate;
  std::tie(new_forward_thread, new_thread_replicate) =
      CompressIterator(forward_thread_, ToIterVars(input_iters),
                       ReplicationPlaceholder(), &analyzer);
  return Fragment(input_size_, forward_index_, new_forward_thread,
                  new_thread_replicate->dom->extent, new_thread_replicate->var);
}

370
371
372
373
374
std::string LayoutNode::DebugOutput() const {
  std::stringstream ss;
  ss << "Layout Shape: " << InputShape() << " -> " << OutputShape() << " -> "
     << GetForwardIndex();
  return ss.str();
375
376
}

377
378
379
380
381
382
383
std::string FragmentNode::DebugOutput() const {
  std::stringstream ss;
  ss << "Fragment Shape: " << InputShape() << " -> " << OutputShape();
  ss << " -> replicate: " << ReplicateExtent();
  ss << " -> thread: " << ThreadExtent();
  ss << " -> forward_thread: " << forward_thread_;
  ss << " -> forward_index: " << GetForwardIndex();
384
  ss << " -> thread_range: " << thread_range_;
385
  return ss.str();
386
387
}

388
389
bool LayoutNode::SEqualReduce(const LayoutNode *other,
                              SEqualReducer equal) const {
390
391
392
393
  return equal(this->InputShape(), other->InputShape()) &&
         equal(this->forward_index_, other->forward_index_);
}

394
395
bool FragmentNode::SEqualReduce(const FragmentNode *other,
                                SEqualReducer equal) const {
396
397
398
399
400
401
402
  return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
         equal(this->InputShape(), other->InputShape()) &&
         equal(this->ThreadExtent(), other->ThreadExtent()) &&
         equal(this->forward_index_, other->forward_index_) &&
         equal(this->forward_thread_, other->forward_thread_);
}

403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
bool LayoutNode::IsEqual(const LayoutNode *other, bool skip_index) const {
  bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
  ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
  if (!skip_index) {
    ret &= StructuralEqual()(this->forward_index_, other->forward_index_);
  }
  return ret;
}

bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const {
  // Fragment Layout Comparison can skip the index comparison
  // when the output shape is the same, as we can do
  // a[i, j] = b[j, i] in register level.

  bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
418
419
420
421
  if (!ret) {
    // may be broadcast case
    return true;
  }
422
423
424
425
426
427
428
429
430
  ret &= StructuralEqual()(this->OutputShape(), other->OutputShape());
  ret &= StructuralEqual()(this->ReplicateExtent(), other->ReplicateExtent());
  ret &= StructuralEqual()(this->ThreadExtent(), other->ThreadExtent());
  if (!skip_index) {
    ret &= StructuralEqual()(this->forward_index_, other->forward_index_);
  }
  return ret;
}

TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(FragmentNode);

// tl.Layout(forward_var: Array<IterVar>, forward_index: Array<PrimExpr>)
TVM_REGISTER_GLOBAL("tl.Layout").set_body([](TVMArgs args, TVMRetValue *ret) {
  *ret = Layout(Array<IterVar>(args[0]), Array<PrimExpr>(args[1]));
});

TVM_REGISTER_GLOBAL("tl.Layout_input_shape").set_body_typed([](Layout layout) {
  return layout->InputShape();
});

TVM_REGISTER_GLOBAL("tl.Layout_output_shape").set_body_typed([](Layout layout) {
  return layout->OutputShape();
});

TVM_REGISTER_GLOBAL("tl.Layout_inverse").set_body_typed([](Layout layout) {
  return layout->Inverse();
});

TVM_REGISTER_GLOBAL("tl.Layout_index").set_body_typed([](Layout layout) {
  return layout->GetForwardIndex();
});

TVM_REGISTER_GLOBAL("tl.Layout_forward_vars").set_body_typed([](Layout layout) {
  return layout->GetForwardVars();
});

// tl.Fragment: builds a Fragment from four packed arguments.
TVM_REGISTER_GLOBAL("tl.Fragment").set_body([](TVMArgs args, TVMRetValue *ret) {
  *ret = Fragment(args[0], args[1], args[2], args[3]);
});

TVM_REGISTER_GLOBAL("tl.Fragment_thread_size")
    .set_body_typed([](Fragment fragment) { return fragment->ThreadExtent(); });

TVM_REGISTER_GLOBAL("tl.Fragment_thread").set_body_typed([](Fragment fragment) {
  return fragment->GetForwardThread();
});

TVM_REGISTER_GLOBAL("tl.Fragment_repeat")
    .set_body_typed([](Fragment fragment, Array<PrimExpr> repeats,
                       bool repeat_on_thread, bool lower_dim_first) {
      return fragment->Repeat(repeats, repeat_on_thread, lower_dim_first);
    });

TVM_REGISTER_GLOBAL("tl.Fragment_replicate")
    .set_body_typed([](Fragment fragment, int repeats) {
      return fragment->Replicate(repeats);
    });

TVM_REGISTER_GLOBAL("tl.Fragment_condense_rep_var")
    .set_body_typed([](Fragment fragment) {
      return fragment->CondenseReplicateVar();
    });

// NOTE(review): `continuous` is passed as both the 2nd and 3rd arguments of
// makeGemmABLayout — confirm that is intended.
TVM_REGISTER_GLOBAL("tl.make_swizzled_layout")
    .set_body_typed([](int stride, int continuous, int element_size) {
      return makeGemmABLayout(stride, continuous, continuous, element_size, 0);
    });
} // namespace tl
} // namespace tvm