// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file layout/layout.cc
 * \brief Implementation of the Layout and Fragment classes.
 */

#include "layout.h"

#include <tvm/arith/pattern.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include "arith/pattern_match.h"
#include "utils.h"

namespace tvm {
namespace tl {

using namespace tir;

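// Returns a process-wide cached Var for the given name, so repeated requests
// for the same placeholder share a single Var object.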
static Var getPlaceholder(const std::string &s) {
  static std::unordered_map<std::string, Var> map;
  if (map.find(s) == map.end()) {
    map[s] = Var(s);
  }
  return map[s];
}

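// Placeholder variable ("_rep") for the replication index of a Fragment.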
Var ReplicationPlaceholder() { return getPlaceholder("_rep"); }
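// Placeholder variable for the idx-th logical input dimension:
// "_i", "_j", "_k", ... (the letter is derived from 'i' + idx).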
Var InputPlaceholder(size_t idx) {
  return getPlaceholder(std::string{'_', char('i' + idx)});
}

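// Binds each input placeholder to its iteration range [0, input_size_[i]).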
Map<Var, Range> LayoutNode::getVarMap() const {
  Map<Var, Range> map;
  for (size_t i = 0; i < InputDim(); i++) {
    map.Set(InputPlaceholder(i), {0, input_size_[i]});
  }
  return map;
}

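// Same as LayoutNode::getVarMap, plus the replication placeholder bound to
// [0, ReplicateExtent()).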
Map<Var, Range> FragmentNode::getVarMap() const {
  auto map = LayoutNode::getVarMap();
  map.Set(ReplicationPlaceholder(), {0, ReplicateExtent()});
  return map;
}

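// Canonical constructor: records the input shape and stores the forward index
// simplified under the placeholder ranges.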
LayoutNode::LayoutNode(Array<PrimExpr> input_size,
                       Array<PrimExpr> forward_index) {
  input_size_ = input_size;
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  forward_index_ = forward_index.Map(
      [&](const PrimExpr &e) { return analyzer.Simplify(e); });
}

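// Builds a Layout from user-supplied IterVars by substituting each loop
// variable with the corresponding input placeholder. Every domain must start
// at zero.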
Layout::Layout(Array<IterVar> forward_var, Array<PrimExpr> forward_index) {
  Map<Var, PrimExpr> vmap;
  Array<PrimExpr> input_size;
  for (size_t i = 0; i < forward_var.size(); i++) {
    vmap.Set(forward_var[i]->var, InputPlaceholder(i));
    CHECK(is_zero(forward_var[i]->dom->min));
    input_size.push_back(forward_var[i]->dom->extent);
  }
  forward_index =
      forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });

  auto n = make_object<LayoutNode>(input_size, forward_index);
  data_ = std::move(n);
}

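// Constructor taking the forward index already expressed in placeholder
// variables. Illustrative example (not taken from any caller): a row-major
// 8x4 -> 32 flattening could be written as
//   Layout({8, 4}, {InputPlaceholder(0) * 4 + InputPlaceholder(1)});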
Layout::Layout(Array<PrimExpr> input_size, Array<PrimExpr> forward_index) {
  auto n = make_object<LayoutNode>(input_size, forward_index);
  data_ = std::move(n);
}

void LayoutNode::VisitAttrs(AttrVisitor *v) {
  v->Visit("input_size", &input_size_);
  v->Visit("forward_index", &forward_index_);
}

void LayoutNode::UpdateAnalyzer(arith::Analyzer *analyzer) const {
  for (const auto &[var, dom] : getVarMap()) {
    analyzer->Bind(var, dom);
  }
}

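// Derives the output extents by bounding each forward index with the
// analyzer: the extent is max(index) + 1. Indices the analyzer cannot bound
// (e.g. xor-based swizzles) fall back to the matching input extent.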
Array<PrimExpr> LayoutNode::OutputShape() const {
  Array<PrimExpr> ret(OutputDim(), 1);
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  for (size_t i = 0; i < ret.size(); i++) {
    auto ist = analyzer.int_set(forward_index_[i] + 1);
    if (arith::is_neg_inf(ist.min()) && arith::is_pos_inf(ist.max())) {
      // The analyzer cannot bound the index (e.g. a xor-based swizzle
      // expression), so fall back to the input extent for this dimension.
      ret.Set(i, input_size_[i]);
    } else {
      // CHECK(is_one(ist.min())) << ist.min();
      ret.Set(i, ist.max());
    }
  }
  return ret;
}

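// Evaluates the forward index at the given coordinates by substituting the
// input placeholders; an empty argument list returns the symbolic index.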
Array<PrimExpr> LayoutNode::Forward(const Array<PrimExpr> &vars) const {
  if (vars.empty())
    return forward_index_;
  ICHECK_EQ(vars.size(), InputDim());
  Map<Var, PrimExpr> vmap;
  for (size_t i = 0; i < InputDim(); i++) {
    vmap.Set(InputPlaceholder(i), vars[i]);
  }
  return forward_index_.Map(
      [&](const PrimExpr &e) { return Substitute(e, vmap); });
}

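// Tiles the fragment repeats[i] times along each input dimension. When
// repeat_on_thread is true the copies are spread across additional thread
// groups (stride ThreadExtent()); otherwise each thread's local fragment
// grows (requires a 1-D output index). When lower_dim_first is true the
// trailing dimension gets the smallest stride in the linearized repeat index.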
Fragment FragmentNode::Repeat(const Array<PrimExpr> &repeats,
                              bool repeat_on_thread,
                              bool lower_dim_first) const {
  ICHECK_EQ(repeats.size(), InputDim());
  Array<PrimExpr> new_input_size;
  Map<Var, PrimExpr> vmap;
  for (size_t i = 0; i < InputDim(); i++) {
    new_input_size.push_back(input_size_[i] * repeats[i]);
    vmap.Set(InputPlaceholder(i),
             FloorMod(InputPlaceholder(i), InputShape()[i]));
  }

  PrimExpr repeats_index = 0, repeat_stride = 1;
  if (lower_dim_first) {
    for (int i = InputDim() - 1; i >= 0; i--) {
      repeats_index +=
          repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
      repeat_stride *= repeats[i];
    }
  } else {
    for (size_t i = 0; i < InputDim(); i++) {
      repeats_index +=
          repeat_stride * FloorDiv(InputPlaceholder(i), InputShape()[i]);
      repeat_stride *= repeats[i];
    }
  }

  if (repeat_on_thread) {
    PrimExpr thread_size = ThreadExtent();
    auto new_forward_index = forward_index_.Map(
        [&](const PrimExpr &e) { return Substitute(e, vmap); });
    auto new_forward_thread =
        Substitute(forward_thread_, vmap) + thread_size * repeats_index;
    return Fragment(new_input_size, new_forward_index, new_forward_thread,
                    replicate_size_, NullOpt);
  } else {
    ICHECK(OutputDim() == 1);
    PrimExpr frag_len = OutputShape()[0];
    Array<PrimExpr> new_forward_index = {Substitute(forward_index_[0], vmap) +
                                         frag_len * repeats_index};
    PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
    return Fragment(new_input_size, new_forward_index, new_forward_thread,
                    replicate_size_, NullOpt);
  }
}

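// Multiplies the replication extent by `repeats`; the extra replicas are
// placed on successive thread groups of size ThreadExtent().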
Fragment FragmentNode::Replicate(int repeats) const {
  ICHECK(repeats >= 1);
  Map<Var, PrimExpr> vmap;
  vmap.Set(ReplicationPlaceholder(),
           FloorMod(ReplicationPlaceholder(), ReplicateExtent()));
  PrimExpr new_forward_thread =
      Substitute(forward_thread_, vmap) +
      ThreadExtent() * FloorDiv(ReplicationPlaceholder(), ReplicateExtent());
  return Fragment(input_size_, forward_index_, new_forward_thread,
                  ReplicateExtent() * repeats, NullOpt);
}

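// Removes a common constant factor between the replication extent and the
// (1-D) fragment size: the fragment index is divided by the factor, its low
// part (index % factor) is folded into the replication coordinate, and the
// replication extent shrinks by the same factor.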
Fragment FragmentNode::DeReplicate() const {
  ICHECK(OutputDim() == 1);
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  int factor = 1;
  auto rep_size = as_const_int(ReplicateExtent());
  auto idx_size = as_const_int(OutputShape()[0]);
  if (rep_size && idx_size) {
    factor = arith::ZeroAwareGCD(*rep_size, *idx_size);
  }
  if (factor == 1)
    return GetRef<Fragment>(this);

  Map<Var, PrimExpr> vmap;
  vmap.Set(ReplicationPlaceholder(), ReplicationPlaceholder() * factor +
                                         FloorMod(forward_index_[0], factor));
  PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
  Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)};
  return Fragment(input_size_, new_forward_index, new_forward_thread,
                  int(*rep_size) / factor, NullOpt);
}

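// Inverts the layout. The forward index must form a bijective affine iter
// map (verified by DetectIterMap); InverseAffineIterMap then yields the
// backward expressions. Input dimensions the inverse map does not mention
// are mapped to zero.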
Layout LayoutNode::Inverse() const {
  arith::Analyzer analyzer;
  arith::IterMapResult res =
      arith::DetectIterMap(forward_index_, getVarMap(), 1,
                           arith::IterMapLevel::Bijective, &analyzer);
  ICHECK(res->errors.empty()) << res->errors;

  auto outputs_shape = OutputShape();
  Array<PrimExpr> outputs;
  for (size_t i = 0; i < OutputDim(); i++) {
    outputs.push_back(InputPlaceholder(i));
  }

  auto inv = arith::InverseAffineIterMap(res->indices, outputs);

  Array<PrimExpr> backward_index;
  for (size_t i = 0; i < InputDim(); i++) {
    if (inv.find(InputPlaceholder(i)) != inv.end()) {
      backward_index.push_back(inv[InputPlaceholder(i)]);
    } else {
      backward_index.push_back(0);
    }
  }

  return Layout(outputs_shape, backward_index);
}

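// Infers a per-thread fragment index from a thread mapping alone: the parts
// of the input iterators that forward_thread does not use are collected
// (the replication placeholder excluded) and flattened into a single index.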
PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
                              const PrimExpr &forward_thread,
                              arith::Analyzer *analyzer) {
  Array<arith::IterSplitExpr> splits = DivideUnusedIterators(
      {forward_thread}, ToIterVars(input_iters), analyzer);

  Array<arith::IterSplitExpr> split_without_rep;
  for (const auto &split : splits) {
    CHECK(split->source->source.as<Var>());
    if (split->source->source.as<Var>().value().same_as(
            ReplicationPlaceholder()))
      continue;
    split_without_rep.push_back(split);
  }
  return MakeFlattenedExpression(split_without_rep);
}

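// Fragment constructor: simplifies the thread mapping and, if no forward
// index is supplied, infers one from the iterator parts unused by the thread
// mapping.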
FragmentNode::FragmentNode(Array<PrimExpr> input_size,
                           Array<PrimExpr> forward_index,
                           PrimExpr forward_thread, PrimExpr replicate_size) {
  input_size_ = input_size;
  replicate_size_ = replicate_size;
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  forward_thread_ = analyzer.Simplify(forward_thread);
  if (forward_index.empty()) {
    forward_index = {
        infer_fragment_index(getVarMap(), forward_thread_, &analyzer)};
  }
  forward_index_ = forward_index.Map(
      [&](const PrimExpr &e) { return analyzer.Simplify(e); });
}

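// Builds a Fragment from user-supplied IterVars; loop variables map to input
// placeholders, and the optional thread_replicate IterVar maps to the
// replication placeholder with its extent becoming the replication size.
// All domains must start at zero.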
Fragment::Fragment(Array<IterVar> forward_var, Array<PrimExpr> forward_index,
                   PrimExpr forward_thread, IterVar thread_replicate) {
  Map<Var, PrimExpr> vmap;
  Array<PrimExpr> input_size;
  PrimExpr replicate_size = 1;
  for (size_t i = 0; i < forward_var.size(); i++) {
    vmap.Set(forward_var[i]->var, InputPlaceholder(i));
    CHECK(is_zero(forward_var[i]->dom->min));
    input_size.push_back(forward_var[i]->dom->extent);
  }
  if (thread_replicate.defined()) {
    ICHECK(is_zero(thread_replicate->dom->min));
    replicate_size = thread_replicate->dom->extent;
    vmap.Set(thread_replicate->var, ReplicationPlaceholder());
  }
  forward_index =
      forward_index.Map([&](const PrimExpr &e) { return Substitute(e, vmap); });
  forward_thread = Substitute(forward_thread, vmap);

  auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
                                     replicate_size);
  data_ = std::move(n);
}

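// Constructor taking expressions already written in placeholder form; only an
// explicitly named replication variable is rewritten. Illustrative example
// (not taken from any caller): a 16x8 fragment spread over 16 threads, each
// thread owning one row of 8 elements, could be written as
//   Fragment({16, 8}, {InputPlaceholder(1)}, InputPlaceholder(0),
//            /*replicate_size=*/1, NullOpt);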
Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
                   PrimExpr forward_thread, PrimExpr replicate_size,
                   Optional<Var> replicate_var) {
  if (replicate_var.defined()) {
    forward_thread = Substitute(
        forward_thread, {{replicate_var.value(), ReplicationPlaceholder()}});
  }
  auto n = make_object<FragmentNode>(input_size, forward_index, forward_thread,
                                     replicate_size);
  data_ = std::move(n);
}

void FragmentNode::VisitAttrs(tvm::AttrVisitor *v) {
  LayoutNode::VisitAttrs(v);
  v->Visit("forward_thread", &forward_thread_);
  v->Visit("replicate_size", &replicate_size_);
}

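// Number of threads the fragment spans: max(forward_thread_) + 1 under the
// placeholder ranges. The thread index is expected to start at zero.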
PrimExpr FragmentNode::ThreadExtent() const {
  arith::Analyzer analyzer;
  UpdateAnalyzer(&analyzer);
  auto ist = analyzer.int_set(forward_thread_ + 1);
  CHECK(is_one(ist.min()));
  return ist.max();
}

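// Evaluates the thread mapping at the given coordinates and, optionally, a
// concrete replication index.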
PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr> &vars,
                                     const Optional<PrimExpr> &rep_var) const {
  Map<Var, PrimExpr> vmap;
  ICHECK_EQ(vars.size(), InputDim());
  for (size_t i = 0; i < InputDim(); i++) {
    vmap.Set(InputPlaceholder(i), vars[i]);
  }
  if (rep_var.defined())
    vmap.Set(ReplicationPlaceholder(), rep_var.value());

  return Substitute(forward_thread_, vmap);
}

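// Inverts the fragment by treating replication as an extra input dimension
// and the thread index as an extra output dimension, then inverting the
// combined layout.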
Layout FragmentNode::Inverse() const {
  auto input_size_copy = input_size_;
  input_size_copy.push_back(ReplicateExtent());
  auto forward_index_copy = forward_index_;
  forward_index_copy.push_back(
      Substitute(forward_thread_,
                 {{ReplicationPlaceholder(), InputPlaceholder(InputDim())}}));
  auto fwd = Layout(input_size_copy, forward_index_copy);
  auto bwd = fwd->Inverse();
  return bwd;
}

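// Condenses the replication dimension: CompressIterator rewrites the thread
// mapping so that only the part of the replication iterator it actually uses
// remains, and the replication extent shrinks accordingly.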
Fragment FragmentNode::CondenseReplicateVar() const {
  arith::Analyzer analyzer;
  auto input_iters = getVarMap();
  input_iters.Set(ReplicationPlaceholder(), {0, ReplicateExtent()});
  PrimExpr new_forward_thread;
  IterVar new_thread_replicate;
  std::tie(new_forward_thread, new_thread_replicate) =
      CompressIterator(forward_thread_, ToIterVars(input_iters),
                       ReplicationPlaceholder(), &analyzer);
  return Fragment(input_size_, forward_index_, new_forward_thread,
                  new_thread_replicate->dom->extent, new_thread_replicate->var);
}

void LayoutNode::DebugOutput() const {
  LOG_DEBUG << "Layout Shape: " << InputShape() << " -> " << OutputShape();
  LOG_DEBUG << "Layout Index: " << forward_index_;
}

void FragmentNode::DebugOutput() const {
  LOG_DEBUG << "Fragment Shape: " << InputShape() << " -> " << OutputShape();
  LOG_DEBUG << "Fragment Replicate: " << ReplicateExtent();
  LOG_DEBUG << "Fragment ThreadExtent: " << ThreadExtent();
  LOG_DEBUG << "Fragment Index: " << forward_index_;
  LOG_DEBUG << "Fragment ThreadIndex: " << forward_thread_;
}

bool LayoutNode::SEqualReduce(const LayoutNode *other,
                              SEqualReducer equal) const {
  return equal(this->InputShape(), other->InputShape()) &&
         equal(this->forward_index_, other->forward_index_);
}

bool FragmentNode::SEqualReduce(const FragmentNode *other,
                                SEqualReducer equal) const {
  return equal(this->ReplicateExtent(), other->ReplicateExtent()) &&
         equal(this->InputShape(), other->InputShape()) &&
         equal(this->ThreadExtent(), other->ThreadExtent()) &&
         equal(this->forward_index_, other->forward_index_) &&
         equal(this->forward_thread_, other->forward_thread_);
}

TVM_REGISTER_NODE_TYPE(LayoutNode);
TVM_REGISTER_NODE_TYPE(FragmentNode);

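// Packed-function entry points registered with the TVM FFI (typically
// invoked from the Python bindings).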
TVM_REGISTER_GLOBAL("tl.Layout").set_body([](TVMArgs args, TVMRetValue *ret) {
  *ret = Layout(Array<IterVar>(args[0]), Array<PrimExpr>(args[1]));
});

TVM_REGISTER_GLOBAL("tl.Layout_input_shape").set_body_typed([](Layout layout) {
  return layout->InputShape();
});

TVM_REGISTER_GLOBAL("tl.Layout_output_shape").set_body_typed([](Layout layout) {
  return layout->OutputShape();
});

TVM_REGISTER_GLOBAL("tl.Layout_inverse").set_body_typed([](Layout layout) {
  return layout->Inverse();
});

TVM_REGISTER_GLOBAL("tl.Layout_index").set_body_typed([](Layout layout) {
  return layout->GetForwardIndex();
});

TVM_REGISTER_GLOBAL("tl.Fragment").set_body([](TVMArgs args, TVMRetValue *ret) {
  *ret = Fragment(args[0], args[1], args[2], args[3]);
});

TVM_REGISTER_GLOBAL("tl.Fragment_thread_size")
    .set_body_typed([](Fragment fragment) { return fragment->ThreadExtent(); });

TVM_REGISTER_GLOBAL("tl.Fragment_thread").set_body_typed([](Fragment fragment) {
  return fragment->GetForwardThread();
});

TVM_REGISTER_GLOBAL("tl.Fragment_repeat")
    .set_body_typed([](Fragment fragment, Array<PrimExpr> repeats,
                       bool repeat_on_thread, bool lower_dim_first) {
      return fragment->Repeat(repeats, repeat_on_thread, lower_dim_first);
    });

TVM_REGISTER_GLOBAL("tl.Fragment_replicate")
    .set_body_typed([](Fragment fragment, int repeats) {
      return fragment->Replicate(repeats);
    });

TVM_REGISTER_GLOBAL("tl.Fragment_condense_rep_var")
    .set_body_typed([](Fragment fragment) {
      return fragment->CondenseReplicateVar();
    });

TVM_REGISTER_GLOBAL("tl.make_swizzled_layout")
    .set_body_typed([](int stride, int continuous, int element_size) {
      return makeGemmABLayout(stride, continuous, element_size, 0);
    });

} // namespace tl
} // namespace tvm