"vscode:/vscode.git/clone" did not exist on "fd199a4a7c8b5bc8325a4cce44622783ff268445"
layout.cc 15 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file layout/layout.cc
 *
 */

#include "layout.h"

#include <tvm/arith/pattern.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include "arith/pattern_match.h"
#include "utils.h"

namespace tvm {
namespace tl {

using namespace tir;

23
/*!
 * \brief Return the placeholder variable cached under the given name.
 *
 * The first request for a name creates a Var with that name hint and stores
 * it in a function-local static table; every later request for the same name
 * returns the stored Var.
 *
 * \param s Name used both as the cache key and as the Var's name hint.
 * \return The Var associated with \p s.
 */
static Var getPlaceholder(const std::string &s) {
  static std::unordered_map<std::string, Var> map;
  // Single lookup; only construct a Var when the name is not cached yet.
  auto it = map.find(s);
  if (it == map.end()) {
    it = map.emplace(s, Var(s)).first;
  }
  return it->second;
}

/*! \brief Placeholder variable for the replication index, cached as "_rep". */
Var ReplicationPlaceholder() {
  const std::string key = "_rep";
  return getPlaceholder(key);
}
32
33
34
/*!
 * \brief Placeholder variable for the idx-th input dimension.
 *
 * The name is an underscore followed by the character `'i' + idx`, i.e.
 * "_i", "_j", "_k", ... for idx = 0, 1, 2, ...
 *
 * \param idx Zero-based input dimension.
 * \return The cached placeholder Var for that dimension.
 */
Var InputPlaceholder(size_t idx) {
  std::string name(1, '_');
  name.push_back(char('i' + idx));
  return getPlaceholder(name);
}
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49

/*!
 * \brief Build the domain of every input placeholder.
 * \return A map sending the i-th input placeholder to the range built from
 *         0 and input_size_[i].
 */
Map<Var, Range> LayoutNode::getVarMap() const {
  Map<Var, Range> ranges;
  const size_t ndim = InputDim();
  for (size_t dim = 0; dim != ndim; ++dim) {
    ranges.Set(InputPlaceholder(dim), {0, input_size_[dim]});
  }
  return ranges;
}

/*!
 * \brief Build the placeholder domains of a fragment.
 * \return The input-placeholder ranges from LayoutNode::getVarMap plus an
 *         entry for the replication placeholder built from 0 and
 *         ReplicateExtent().
 */
Map<Var, Range> FragmentNode::getVarMap() const {
  Map<Var, Range> ranges = LayoutNode::getVarMap();
  ranges.Set(ReplicationPlaceholder(), {0, ReplicateExtent()});
  return ranges;
}

50
51
/*!
 * \brief Construct a layout node from its input extents and forward indices.
 *
 * \param input_size Extent of each input dimension.
 * \param forward_index Output index expressions, written in terms of the
 *        input placeholders; each one is stored after simplification.
 */
LayoutNode::LayoutNode(Array<PrimExpr> input_size,
                       Array<PrimExpr> forward_index) {
  // input_size_ must be set first: UpdateAnalyzer reads it through
  // getVarMap() to bind the placeholder ranges.
  input_size_ = input_size;
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  forward_index_ = forward_index.Map(
      [&](const PrimExpr &e) { return analyzer.Simplify(e); });
}

/*!
 * \brief Construct a layout from iteration variables and index expressions.
 *
 * Each iteration variable is replaced by the input placeholder of the same
 * position, and its extent becomes the size of that input dimension.
 *
 * \param forward_var Iteration variables; each domain must start at 0.
 * \param forward_index Output index expressions over \p forward_var.
 */
Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
  Map<Var, PrimExpr> vmap;
  Array<PrimExpr> input_size;
  for (size_t i = 0; i < forward_var.size(); i++) {
    vmap.Set(forward_var[i]->var, InputPlaceholder(i));
    CHECK(is_zero(forward_var[i]->dom->min))
        << "Layout iteration variable " << i << " must have a domain min of 0";
    input_size.push_back(forward_var[i]->dom->extent);
  }
  // Rewrite the user expressions in terms of the canonical placeholders.
  forward_index =
      forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });

  auto n = make_object<LayoutNode>(input_size, forward_index);
  data_ = std::move(n);
}

/*!
 * \brief Construct a layout directly from input extents and forward indices.
 * \param input_size Extent of each input dimension.
 * \param forward_index Output index expressions over the input placeholders.
 */
Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
  data_ = make_object<LayoutNode>(input_size, forward_index);
}

79
/*!
 * \brief Present the node's fields to an attribute visitor.
 * \param v Visitor that receives "input_size" and "forward_index".
 */
void LayoutNode::VisitAttrs(AttrVisitor *v) {
  v->Visit("input_size", &input_size_);
  v->Visit("forward_index", &forward_index_);
}

84
85
/*!
 * \brief Bind every placeholder of this layout to its range.
 * \param analyzer Analyzer that receives one binding per entry of
 *        getVarMap().
 */
void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
  for (const auto &[var, dom] : getVarMap()) {
    analyzer->Bind(var, dom);
  }
}

90
91
92
93
94
95
96
97
/*!
 * \brief List the placeholder variables of the forward mapping.
 * \return One input placeholder per input dimension, in dimension order.
 */
Array<PrimExpr> LayoutNode::GetForwardVars() const {
  Array<PrimExpr> placeholders;
  const size_t ndim = InputDim();
  for (size_t dim = 0; dim != ndim; ++dim) {
    placeholders.push_back(InputPlaceholder(dim));
  }
  return placeholders;
}

98
99
100
101
102
103
104
105
106
107
/*!
 * \brief Derive the extent of every output dimension.
 *
 * For each forward index the analyzer computes the integer set of
 * `index + 1` under the placeholder ranges. Its upper bound is used as the
 * extent. When the set is unbounded on both sides, the extent falls back to
 * the input size of the same position.
 *
 * \return One extent per output dimension.
 */
Array<PrimExpr> LayoutNode::OutputShape() const {
  Array<PrimExpr> ret(OutputDim(), 1);
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  for (size_t i = 0; i < ret.size(); i++) {
    auto ist = analyzer.int_set(forward_index_[i] + 1);
    if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) {
      // X-OR Expression: the analyzer cannot bound it, so reuse the input
      // extent of the same dimension.
      ret.Set(i, input_size_[i]);
    } else {
      // CHECK(is_one(ist.min())) << ist.min();
      ret.Set(i, ist.max());
    }
  }
  return ret;
}

115
116
117
/*!
 * \brief Evaluate the forward mapping on the given input expressions.
 *
 * \param vars Expressions for the input dimensions. An empty array returns
 *        the stored forward indices unchanged; otherwise the size must
 *        equal InputDim().
 * \return The forward indices with every input placeholder replaced by the
 *         corresponding entry of \p vars.
 */
Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
  if (vars.empty())
    return forward_index_;
  ICHECK_EQ(vars.size(), InputDim());
  Map<Var, PrimExpr> vmap;
  for (size_t i = 0; i < InputDim(); i++) {
    vmap.Set(InputPlaceholder(i), vars[i]);
  }
  return forward_index_.Map(
      [&](const PrimExpr &e) { return Substitute(e, vmap); });
}

127
128
/*!
 * \brief Tile this fragment along its input dimensions.
 *
 * Dimension i grows to `input_size_[i] * repeats[i]`. Inside the original
 * expressions each placeholder is replaced by its value modulo the original
 * extent, and the quotients are linearized into a single repeat index.
 *
 * \param repeats Repeat count per input dimension; size must equal
 *        InputDim().
 * \param repeat_on_thread When true the repeat index offsets the thread
 *        index by multiples of ThreadExtent(). When false it offsets the
 *        single forward index by multiples of OutputShape()[0], which
 *        requires OutputDim() == 1.
 * \param lower_dim_first When true the last input dimension gets stride 1
 *        in the linearized repeat index; otherwise the first dimension does.
 * \return The repeated fragment, with the replicate size unchanged.
 */
Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,
                              bool repeat_on_thread,
                              bool lower_dim_first) const {
  ICHECK_EQ(repeats.size(), InputDim());
  Array<PrimExpr> new_input_size;
  Map<Var, PrimExpr> vmap;
  for (size_t i = 0; i < InputDim(); i++) {
    new_input_size.push_back(input_size_[i] * repeats[i]);
    vmap.Set(InputPlaceholder(i),
             FloorMod(InputPlaceholder(i), InputShape()[i]));
  }

  // Linearize the per-dimension tile coordinates into one index.
  PrimExpr repeats_index = 0, repeat_stride = 1;
  if (lower_dim_first) {
    for (int i = static_cast<int>(InputDim()) - 1; i >= 0; i--) {
      repeats_index +=
          repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
      repeat_stride *= repeats[i];
    }
  } else {
    for (size_t i = 0; i < InputDim(); i++) {
      repeats_index +=
          repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
      repeat_stride *= repeats[i];
    }
  }

  if (repeat_on_thread) {
    PrimExpr thread_size = ThreadExtent();
    auto new_forward_index = forward_index_.Map(
        [&](const PrimExpr &e) { return Substitute(e, vmap); });
    auto new_forward_thread =
        Substitute(forward_thread_, vmap) + thread_size * repeats_index;
    return Fragment(new_input_size, new_forward_index, new_forward_thread,
                    replicate_size_, NullOpt);
  } else {
    ICHECK(OutputDim() == 1);
    PrimExpr frag_len = OutputShape()[0];
    Array<PrimExpr> new_forward_index = {Substitute(forward_index_[0], vmap) +
                                         frag_len * repeats_index};
    PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
    return Fragment(new_input_size, new_forward_index, new_forward_thread,
                    replicate_size_, NullOpt);
  }
}

/*!
 * \brief Multiply the replication extent of this fragment.
 *
 * The replication placeholder is split: its value modulo the old extent
 * feeds the original thread expression, and its quotient offsets the thread
 * index by multiples of ThreadExtent().
 *
 * \param repeats Replication factor; must be at least 1.
 * \return A fragment with the same inputs and forward indices and a
 *         replicate size of `ReplicateExtent() * repeats`.
 */
Fragment FragmentNode::Replicate(int repeats) const {
  ICHECK(repeats >= 1);
  Map<Var, PrimExpr> vmap;
  vmap.Set(ReplicationPlaceholder(),
           FloorMod(ReplicationPlaceholder(), ReplicateExtent()));
  PrimExpr new_forward_thread =
      Substitute(forward_thread_, vmap) +
      ThreadExtent() * FloorDiv(ReplicationPlaceholder(), ReplicateExtent());
  return Fragment(input_size_, forward_index_, new_forward_thread,
                  ReplicateExtent() * repeats, NullOpt);
}

/*!
 * \brief Remove the replication shared between the replicate extent and
 *        the output extent.
 *
 * Requires OutputDim() == 1. When both ReplicateExtent() and
 * OutputShape()[0] are constants, their GCD is used as the reduction
 * factor. With a factor of 1 (including when either is not constant) the
 * fragment is returned unchanged.
 *
 * \return A fragment whose forward index is divided by the factor and whose
 *         replicate size is the old constant extent divided by the factor.
 */
Fragment FragmentNode::DeReplicate() const {
  ICHECK(OutputDim() == 1);
  int factor = 1;
  auto rep_size = as_const_int(ReplicateExtent());
  auto idx_size = as_const_int(OutputShape()[0]);
  if (rep_size && idx_size) {
    factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
  }
  if (factor == 1)
    return GetRef<Fragment>(this);

  // factor != 1 implies rep_size is non-null, so the dereference below is
  // safe.
  Map<Var, PrimExpr> vmap;
  vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
                                         FloorMod(forward_index_[0], factor));
  PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
  Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)};
  return Fragment(input_size_, new_forward_index, new_forward_thread,
                  int(*rep_size) / factor, NullOpt);
}

/*!
 * \brief Build the inverse of this layout.
 *
 * The forward indices are analysed with DetectIterMap at the Bijective
 * level, and any reported error aborts. The resulting iterator map is
 * inverted with the input placeholders standing in for the output
 * coordinates. An original input placeholder missing from the inverse is
 * mapped to 0.
 *
 * \return A layout whose input shape is OutputShape() and whose forward
 *         indices recover the original inputs.
 */
Layout LayoutNode::Inverse() const {
  arith::Analyzer analyzer;
  arith::IterMapResult res =
      arith::DetectIterMap(forward_index_, getVarMap(), 1,
                           arith::IterMapLevel::Bijective, &analyzer);
  ICHECK(res->errors.empty()) << res->errors;

  auto outputs_shape = OutputShape();
  // The inverse layout's inputs are this layout's outputs, so they reuse
  // the input placeholders, one per output dimension.
  Array<PrimExpr> outputs;
  for (size_t i = 0; i < OutputDim(); i++) {
    outputs.push_back(InputPlaceholder(i));
  }

  auto inv = arith::InverseAffineIterMap(res->indices, outputs);

  Array<PrimExpr> backward_index;
  for (size_t i = 0; i < InputDim(); i++) {
    if (inv.find(InputPlaceholder(i)) != inv.end()) {
      backward_index.push_back(inv[InputPlaceholder(i)]);
    } else {
      backward_index.push_back(0);
    }
  }

  return Layout(outputs_shape, backward_index);
}

234
235
236
237
238
/*!
 * \brief Infer a fragment's forward index from its thread expression.
 *
 * DivideUnusedIterators yields a list of iterator splits for
 * \p forward_thread. The splits whose source is the replication placeholder
 * are dropped, and the rest are flattened into one expression.
 *
 * \param input_iters Placeholder variables and their ranges.
 * \param forward_thread Thread index expression of the fragment.
 * \param analyzer Analyzer forwarded to DivideUnusedIterators.
 * \return The flattened expression over the remaining splits.
 */
PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
                              const PrimExpr &forward_thread,
                              arith::Analyzer *analyzer) {
  Array<arith::IterSplitExpr> splits = DivideUnusedIterators(
      {forward_thread}, ToIterVars(input_iters), analyzer);

  Array<arith::IterSplitExpr> split_without_rep;
  for (const auto &split : splits) {
    CHECK(split->source->source.as<Var>());
    if (split->source->source.as<Var>().value().same_as(
            ReplicationPlaceholder()))
      continue;
    split_without_rep.push_back(split);
  }
  return MakeFlattenedExpression(split_without_rep);
}

251
252
/*!
 * \brief Construct a fragment node.
 *
 * \param input_size Extent of each input dimension.
 * \param forward_index Index expressions over the placeholders. When empty,
 *        a single index is inferred from the simplified thread expression
 *        via infer_fragment_index.
 * \param forward_thread Thread index expression; stored after
 *        simplification.
 * \param replicate_size Extent of the replication placeholder.
 */
FragmentNode::FragmentNode(Array<PrimExpr> input_size,
                           Array<PrimExpr> forward_index,
                           PrimExpr forward_thread, PrimExpr replicate_size) {
  // Both sizes must be set before UpdateAnalyzer, which reads them through
  // getVarMap() to bind the placeholder ranges.
  input_size_ = input_size;
  replicate_size_ = replicate_size;
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  forward_thread_ = analyzer.Simplify(forward_thread);
  if (forward_index.empty()) {
    forward_index = {
        infer_fragment_index(getVarMap(), forward_thread_, &analyzer)};
  }
  forward_index_ = forward_index.Map(
      [&](const PrimExpr &e) { return analyzer.Simplify(e); });
}

/*!
 * \brief Construct a fragment from iteration variables.
 *
 * Each iteration variable is replaced by the input placeholder of the same
 * position. When \p thread_replicate is defined, its variable is replaced
 * by the replication placeholder and its extent becomes the replicate size;
 * otherwise the replicate size is 1.
 *
 * \param forward_var Iteration variables; each domain must start at 0.
 * \param forward_index Index expressions over \p forward_var.
 * \param forward_thread Thread index expression.
 * \param thread_replicate Optional replication iteration variable; its
 *        domain must start at 0.
 */
Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
                   PrimExpr forward_thread, IterVar thread_replicate) {
  Map<Var, PrimExpr> vmap;
  Array<PrimExpr> input_size;
  PrimExpr replicate_size = 1;
  for (size_t i = 0; i < forward_var.size(); i++) {
    vmap.Set(forward_var[i]->var, InputPlaceholder(i));
    CHECK(is_zero(forward_var[i]->dom->min))
        << "Fragment iteration variable " << i
        << " must have a domain min of 0";
    input_size.push_back(forward_var[i]->dom->extent);
  }
  if (thread_replicate.defined()) {
    ICHECK(is_zero(thread_replicate->dom->min))
        << "Fragment replicate variable must have a domain min of 0";
    replicate_size = thread_replicate->dom->extent;
    vmap.Set(thread_replicate->var, ReplicationPlaceholder());
  }
  // Rewrite the user expressions in terms of the canonical placeholders.
  forward_index =
      forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
  forward_thread = Substitute(forward_thread, vmap);

  auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
                                     replicate_size);
  data_ = std::move(n);
}

/*!
 * \brief Construct a fragment from placeholder-based expressions.
 *
 * \param input_size Extent of each input dimension.
 * \param forward_index Index expressions over the input placeholders.
 * \param forward_thread Thread index expression.
 * \param replicate_size Extent of the replication dimension.
 * \param replicate_var When defined, occurrences of this variable in
 *        \p forward_thread are replaced by the replication placeholder.
 */
Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
                   PrimExpr forward_thread, PrimExpr replicate_size,
                   Optional<Var> replicate_var) {
  if (replicate_var.defined()) {
    forward_thread = Substitute(
        forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
  }
  auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
                                     replicate_size);
  data_ = std::move(n);
}

303
/*!
 * \brief Present the node's fields to an attribute visitor.
 * \param v Visitor that receives the LayoutNode fields followed by
 *        "forward_thread" and "replicate_size".
 */
void FragmentNode::VisitAttrs(tvm::AttrVisitor *v) {
  LayoutNode::VisitAttrs(v);
  v->Visit("forward_thread", &forward_thread_);
  v->Visit("replicate_size", &replicate_size_);
}

/*!
 * \brief Compute the extent of the thread index.
 *
 * The analyzer computes the integer set of `forward_thread_ + 1` under the
 * placeholder ranges. Its lower bound must be 1, and its upper bound is the
 * extent.
 *
 * \return Upper bound of `forward_thread_ + 1`.
 */
PrimExpr FragmentNode::ThreadExtent() const {
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  auto ist = analyzer.int_set(forward_thread_ + 1);
  CHECK(is_one(ist.min()));
  return ist.max();
}

318
319
320
321
322
323
324
325
326
327
328
/*!
 * \brief List the placeholder variables of the forward mapping.
 *
 * The replication placeholder comes first whenever the replicate extent is
 * not known to be at most 1: either a constant greater than 1, or a
 * non-constant expression. The input placeholders follow in dimension
 * order.
 *
 * \return The placeholder variables used by this fragment.
 */
Array<PrimExpr> FragmentNode::GetForwardVars() const {
  Array<PrimExpr> vars;
  // as_const_int yields a null pointer for a non-constant extent, so it must
  // be tested before it is dereferenced.
  const auto rep_extent = as_const_int(ReplicateExtent());
  if (rep_extent == nullptr || *rep_extent > 1) {
    vars.push_back(ReplicationPlaceholder());
  }
  for (size_t i = 0; i < InputDim(); i++) {
    vars.push_back(InputPlaceholder(i));
  }
  return vars;
}

329
330
/*!
 * \brief Evaluate the thread index on the given input expressions.
 *
 * \param vars Expressions for the input dimensions; size must equal
 *        InputDim().
 * \param rep_var When defined, replaces the replication placeholder;
 *        otherwise the placeholder is left in the result.
 * \return The thread expression after substitution.
 */
PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr> &vars,
                                     const Optional<PrimExpr> &rep_var) const {
  Map<Var, PrimExpr> vmap;
  ICHECK_EQ(vars.size(), InputDim());
  for (size_t i = 0; i < InputDim(); i++) {
    vmap.Set(InputPlaceholder(i), vars[i]);
  }
  if (rep_var.defined())
    vmap.Set(ReplicationPlaceholder(), rep_var.value());

  return Substitute(forward_thread_, vmap);
}

/*!
 * \brief Build the inverse of this fragment as a layout.
 *
 * The replication dimension is appended as one more input dimension and the
 * thread expression as one more forward index, with the replication
 * placeholder renamed to the new input placeholder. The combined layout is
 * then inverted.
 *
 * \return The inverse of the combined layout.
 */
Layout FragmentNode::Inverse() const {
  auto input_size_copy = input_size_;
  input_size_copy.push_back(ReplicateExtent());
  auto forward_index_copy = forward_index_;
  forward_index_copy.push_back(
      Substitute(forward_thread_,
                 {{ReplicationPlaceholder(), InputPlaceholder(InputDim())}}));
  auto fwd = Layout(input_size_copy, forward_index_copy);
  auto bwd = fwd->Inverse();
  return bwd;
}

/*!
 * \brief Rebuild the fragment with a compressed replication variable.
 *
 * CompressIterator is applied to the thread expression with the replication
 * placeholder as the iterator to compress. It yields a new thread
 * expression and a new replication IterVar, whose extent and variable
 * define the returned fragment.
 *
 * \return A fragment with the same inputs and forward indices and the
 *         compressed thread expression.
 */
Fragment FragmentNode::CondenseReplicateVar() const {
  arith::Analyzer analyzer;
  auto input_iters = getVarMap();
  input_iters.Set(ReplicationPlaceholder(), {0, ReplicateExtent()});
  PrimExpr new_forward_thread;
  IterVar new_thread_replicate;
  std::tie(new_forward_thread, new_thread_replicate) =
      CompressIterator(forward_thread_, ToIterVars(input_iters),
                       ReplicationPlaceholder(), &analyzer);
  return Fragment(input_size_, forward_index_, new_forward_thread,
                  new_thread_replicate->dom->extent, new_thread_replicate->var);
}

/*! \brief Log the input/output shapes and the forward indices via LOG_DEBUG. */
void LayoutNode::DebugOutput() const {
  LOG_DEBUG << "Layout Shape: " << InputShape() << " -> " << OutputShape();
  LOG_DEBUG << "Layout Index: " << forward_index_;
}

/*!
 * \brief Log the fragment's shapes, replicate extent, thread extent,
 *        forward indices and thread index via LOG_DEBUG.
 */
void FragmentNode::DebugOutput() const {
  LOG_DEBUG << "Fragment Shape: " << InputShape() << " -> " << OutputShape();
  LOG_DEBUG << "Fragment Replicate: " << ReplicateExtent();
  LOG_DEBUG << "Fragment ThreadExtent: " << ThreadExtent();
  LOG_DEBUG << "Fragment Index: " << forward_index_;
  LOG_DEBUG << "Fragment ThreadIndex: " << forward_thread_;
}

380
381
bool LayoutNode::SEqualReduce(const LayoutNode *other,
                              SEqualReducer equal) const {
382
383
384
385
  return equal(this->InputShape(), other->InputShape()) &&
         equal(this->forward_index_, other->forward_index_);
}

386
387
bool FragmentNode::SEqualReduce(const FragmentNode *other,
                                SEqualReducer equal) const {
388
389
390
391
392
393
394
395
396
397
  return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
         equal(this->InputShape(), other->InputShape()) &&
         equal(this->ThreadExtent(), other->ThreadExtent()) &&
         equal(this->forward_index_, other->forward_index_) &&
         equal(this->forward_thread_, other->forward_thread_);
}

// Register LayoutNode and FragmentNode through TVM's node-type registration
// macro.
TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(FragmentNode);

398
// "tl.Layout": build a Layout from (Array<IterVar>, Array<PrimExpr>).
TVM_REGISTER_GLOBAL("tl.Layout").set_body([](TVMArgs args, TVMRetValue *ret) {
  *ret = Layout(Array<IterVar>(args[0]), Array<PrimExpr>(args[1]));
});

// "tl.Layout_input_shape": return layout->InputShape().
TVM_REGISTER_GLOBAL("tl.Layout_input_shape").set_body_typed([](Layout layout) {
  return layout->InputShape();
});

// "tl.Layout_output_shape": return layout->OutputShape().
TVM_REGISTER_GLOBAL("tl.Layout_output_shape").set_body_typed([](Layout layout) {
  return layout->OutputShape();
});

// "tl.Layout_inverse": return layout->Inverse().
TVM_REGISTER_GLOBAL("tl.Layout_inverse").set_body_typed([](Layout layout) {
  return layout->Inverse();
});

// "tl.Layout_index": return layout->GetForwardIndex().
TVM_REGISTER_GLOBAL("tl.Layout_index").set_body_typed([](Layout layout) {
  return layout->GetForwardIndex();
});

418
419
420
421
// "tl.Layout_forward_vars": return layout->GetForwardVars().
TVM_REGISTER_GLOBAL("tl.Layout_forward_vars").set_body_typed([](Layout layout) {
  return layout->GetForwardVars();
});

422
// "tl.Fragment": build a Fragment from four packed arguments, forwarded
// positionally to the Fragment constructor.
TVM_REGISTER_GLOBAL("tl.Fragment").set_body([](TVMArgs args, TVMRetValue *ret) {
  *ret = Fragment(args[0], args[1], args[2], args[3]);
});

426
427
// "tl.Fragment_thread_size": return fragment->ThreadExtent().
TVM_REGISTER_GLOBAL("tl.Fragment_thread_size")
    .set_body_typed([](Fragment fragment) { return fragment->ThreadExtent(); });
428
429
430
431
432
433

// "tl.Fragment_thread": return fragment->GetForwardThread().
TVM_REGISTER_GLOBAL("tl.Fragment_thread").set_body_typed([](Fragment fragment) {
  return fragment->GetForwardThread();
});

// "tl.Fragment_repeat": forward to FragmentNode::Repeat with the same
// argument order.
TVM_REGISTER_GLOBAL("tl.Fragment_repeat")
    .set_body_typed([](Fragment fragment, Array<PrimExpr> repeats,
                       bool repeat_on_thread, bool lower_dim_first) {
      return fragment->Repeat(repeats, repeat_on_thread, lower_dim_first);
    });

439
440
441
442
// "tl.Fragment_replicate": return fragment->Replicate(repeats).
TVM_REGISTER_GLOBAL("tl.Fragment_replicate")
    .set_body_typed([](Fragment fragment, int repeats) {
      return fragment->Replicate(repeats);
    });
443

444
445
446
447
// "tl.Fragment_condense_rep_var": return fragment->CondenseReplicateVar().
TVM_REGISTER_GLOBAL("tl.Fragment_condense_rep_var")
    .set_body_typed([](Fragment fragment) {
      return fragment->CondenseReplicateVar();
    });
448
449
450

// "tl.make_swizzled_layout": call makeGemmABLayout with `continuous` passed
// for both its second and third arguments and 0 as the last argument.
TVM_REGISTER_GLOBAL("tl.make_swizzled_layout")
    .set_body_typed([](int stride, int continuous, int element_size) {
      return makeGemmABLayout(stride, continuous, continuous, element_size, 0);
    });

454
455
} // namespace tl
} // namespace tvm