/*!
 * \file layout/utils.cc
 * \brief Arithmetic helpers used by layout and fragment inference.
 */

#include "utils.h"

#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace tl {

using namespace tir;
using namespace arith;

18
19
20
bool CanProveDivisible(const PrimExpr &lhs, const PrimExpr &rhs) {
  const auto *clhs = lhs.as<IntImmNode>();
  const auto *crhs = rhs.as<IntImmNode>();
21
22
23
24
25
26
27
28
29
30
31
32
  if (crhs && crhs->value == 0) {
    return false;
  } else if (clhs && crhs) {
    return clhs->value % crhs->value == 0;
  }

  return false;
}

/*!
 * \brief Collector that collects the outgoing split reference of each IterMark.
 *
33
34
 *  These out-going splits can then be used to check if the iterators are
 * independent.
35
36
 */
class IterMarkSplitCollector {
37
public:
38
39
40
  // mark all IterMarks that are visited.
  std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
  // each iter mark to its outgoing splits that are referenced.
41
42
  std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash,
                     ObjectPtrEqual>
43
44
45
46
47
      mark2splits_;
  /*!
   * \brief Collect all mark2splits recursively from indices.
   * \param indices The iterator of interest.
   */
48
  void Collect(const Array<IterSumExpr> &indices) {
49
50
51
52
53
54
55
56
    for (IterSumExpr sum_expr : indices) {
      for (IterSplitExpr split : sum_expr->args) {
        this->CollectInternal(split->source);
        mark2splits_[split->source].push_back(split);
      }
    }
  }

57
58
59
  void CollectInternal(const IterMark &mark) {
    if (visited_.count(mark))
      return;
60
    visited_.insert(mark);
61
    if (auto *op = mark->source.as<IterSumExprNode>()) {
62
63
64
65
66
67
68
69
      for (IterSplitExpr split : op->args) {
        this->CollectInternal(split->source);
        mark2splits_[split->source].push_back(split);
      }
    }
  }
};

70
71
72
Array<IterSplitExpr> get_unused_iters(const IterMark &mark,
                                      const std::vector<IterSplitExpr> &splits,
                                      Analyzer *analyzer) {
73
74
75
76
77
78
79
80
  PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
  std::vector<bool> used(splits.size(), false);
  std::vector<IterSplitExpr> results;
  size_t i = 0;
  for (; i < splits.size();) {
    size_t j = 0;
    size_t lowest = splits.size();
    for (; j < splits.size(); ++j) {
81
82
83
84
      if (used[j])
        continue;
      if (!used[j] && analyzer->CanProveEqual(splits[j]->lower_factor,
                                              expected_lower_factor)) {
85
86
87
        break;
      }
      if (lowest == splits.size() ||
88
89
          CanProveDivisible(splits[lowest]->lower_factor,
                            splits[j]->lower_factor)) {
90
91
92
93
94
        lowest = j;
      }
    }
    if (j == splits.size()) {
      ICHECK(lowest != splits.size());
95
96
97
98
99
      ICHECK(CanProveDivisible(splits[lowest]->lower_factor,
                               expected_lower_factor));
      results.emplace_back(
          mark, expected_lower_factor,
          FloorDiv(splits[lowest]->lower_factor, expected_lower_factor), 1);
100
101
102
103
104
105
106
      expected_lower_factor = splits[lowest]->lower_factor;
    } else {
      used[j] = true;
      i++;
      expected_lower_factor = splits[j]->lower_factor * splits[j]->extent;
    }
  }
107
108
  bool match_full_iter =
      analyzer->CanProveEqual(expected_lower_factor, mark->extent);
109
  if (!match_full_iter) {
110
111
    results.emplace_back(mark, expected_lower_factor,
                         FloorDiv(mark->extent, expected_lower_factor), 1);
112
113
114
115
  }
  return results;
}

116
117
118
119
120
121
Array<IterSplitExpr> DivideUnusedIterators(const Array<PrimExpr> &exprs,
                                           const Array<IterVar> input_iters,
                                           Analyzer *analyzer) {
  auto iter_sum = exprs.Map([&](const auto &e) {
    return NormalizeToIterSum(e, ToVMap(input_iters), analyzer);
  });
122
123
124
125
  IterMarkSplitCollector collector;
  collector.Collect(iter_sum);
  Array<IterSplitExpr> results;

126
  for (const IterMark &mark : collector.visited_) {
127
128
129
130
131
    if (!mark->source.as<Var>()) {
      std::ostringstream oss;
      oss << "Not a normalized iterator: " << mark;
      throw NormalizeIterException(oss.str());
    }
132
133
  }

134
  for (const IterVar &iter : input_iters) {
135
    IterMark iv_mark;
136
    for (const IterMark &mark : collector.visited_) {
137
      if (mark->source.as<Var>()->same_as(iter->var)) {
138
139
140
141
142
        iv_mark = mark;
        break;
      }
    }
    if (iv_mark.defined()) {
143
144
      auto splits =
          get_unused_iters(iv_mark, collector.mark2splits_[iv_mark], analyzer);
145
146
147
148
149
150
151
152
153
154
155
      // Put the small axis last
      results.insert(results.end(), splits.rbegin(), splits.rend());
    } else if (!is_one(iter->dom->extent)) {
      auto mark = IterMark(iter->var, iter->dom->extent);
      auto split = IterSplitExpr(mark, 1, iter->dom->extent, 1);
      results.push_back(split);
    }
  }
  return results;
}

156
PrimExpr MakeFlattenedExpression(const Array<arith::IterSplitExpr> &splits) {
157
158
159
  Array<arith::IterSplitExpr> lists;
  PrimExpr scale = 1;
  for (int i = splits.size() - 1; i >= 0; i--) {
160
161
    auto scaled_split = arith::IterSplitExpr(
        splits[i]->source, splits[i]->lower_factor, splits[i]->extent, scale);
162
163
164
165
166
167
168
    lists.push_back(scaled_split);
    scale *= splits[i]->extent;
  }
  return arith::NormalizeIterMapToExpr(arith::IterSumExpr(lists, 0));
}

class IterSumMutator {
169
170
public:
  IterSumMutator(const Map<IterSplitExpr, IterSplitExpr> &replace_map)
171
172
173
      : replace_map_(replace_map) {}

  // override the original mutate function.
174
  IterSumExpr Mutate(const IterSumExpr &iter_sum) {
175
    Array<IterSplitExpr> args;
176
    for (const auto &split : iter_sum->args) {
177
178
179
      if (replace_map_.count(split)) {
        args.push_back(replace_map_[split]);
      } else {
180
181
        auto split_ = IterSplitExpr(Mutate(split->source), split->lower_factor,
                                    split->extent, split->scale);
182
183
184
185
186
187
        args.push_back(split_);
      }
    }
    return IterSumExpr(args, iter_sum->base);
  }

188
189
  IterMark Mutate(const IterMark &mark) {
    if (auto *op = mark->source.as<IterSumExprNode>()) {
190
191
192
193
194
195
      return IterMark(Mutate(GetRef<IterSumExpr>(op)), mark->extent);
    } else {
      return mark;
    }
  }

196
private:
197
198
199
  Map<IterSplitExpr, IterSplitExpr> replace_map_;
};

200
201
202
203
204
205
std::pair<PrimExpr, IterVar> CompressIterator(const PrimExpr &expr,
                                              const Array<IterVar> input_iters,
                                              const Var &var,
                                              arith::Analyzer *analyzer) {
  auto iter_sum =
      arith::NormalizeToIterSum(expr, ToVMap(input_iters), analyzer);
206
207
208
  IterMarkSplitCollector collector;
  collector.Collect({iter_sum});
  IterMark mark;
209
  for (const IterMark &m : collector.visited_) {
210
211
212
213
214
215
216
217
218
219
220
221
    ICHECK(m->source.as<Var>()) << "Not a normalized iterator: " << mark;
    if (m->source.as<Var>().value().same_as(var)) {
      mark = m;
      break;
    }
  }
  std::vector<tvm::arith::IterSplitExpr> splits;
  if (mark.defined()) {
    splits = collector.mark2splits_[mark];
  }

  PrimExpr extent = 1;
222
  for (const auto &split : splits) {
223
224
225
226
227
228
229
230
231
    extent *= split->extent;
  }
  extent = analyzer->Simplify(extent);

  auto new_var = Var(var->name_hint, var->type_annotation);
  auto new_iter_var = IterVar(Range(0, extent), new_var, IterVarType::kDataPar);
  auto new_mark = IterMark(new_var, extent);
  PrimExpr scale = 1;
  Map<IterSplitExpr, IterSplitExpr> replace_map;
232
233
234
  for (const auto &split : splits) {
    auto rescaled =
        arith::IterSplitExpr(new_mark, scale, split->extent, split->scale);
235
236
237
238
239
    replace_map.Set(split, rescaled);
    scale *= split->extent;
  }

  IterSumMutator mutator(replace_map);
240
241
  PrimExpr reaplced =
      analyzer->Simplify(NormalizeIterMapToExpr(mutator.Mutate(iter_sum)));
242
243
244
245

  return {reaplced, new_iter_var};
}

246
Array<IterVar> ToIterVars(const Map<Var, Range> &vmap) {
247
  Array<IterVar> result;
248
  for (const auto &[var, range] : vmap) {
249
250
251
252
253
    result.push_back(IterVar(range, var, IterVarType::kDataPar));
  }
  return result;
}

254
Map<Var, Range> ToVMap(const Array<IterVar> &ivs) {
255
  Map<Var, Range> result;
256
  for (const auto &iv : ivs) {
257
258
259
260
261
    result.Set(iv->var, iv->dom);
  }
  return result;
}

262
263
} // namespace tl
} // namespace tvm