loop_partition.cc 9.62 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_partition.cc
 * \brief Partition parallel loops onto threads
 */

#include "loop_partition.h"

#include <tvm/tir/stmt_functor.h>

29
30
#include <utility>

31
32
33
34
35
36
namespace tvm {
namespace tl {

using namespace tir;

class BufferIndiceSimplify : public StmtExprMutator {
37
38
public:
  BufferIndiceSimplify(arith::Analyzer *analyzer) : analyzer_(analyzer) {}
39

40
41
private:
  PrimExpr VisitExpr_(const BufferLoadNode *node) final {
42
    auto visited = StmtExprMutator::VisitExpr_(node);
43
    auto n = Downcast<BufferLoad>(visited);
44
    auto nptr = n.CopyOnWrite();
45
46
    nptr->indices = nptr->indices.Map(
        [&](const auto &e) { return analyzer_->Simplify(e); });
47
48
    return n;
  }
49
  Stmt VisitStmt_(const BufferStoreNode *node) final {
50
    auto visited = StmtExprMutator::VisitStmt_(node);
51
    auto n = Downcast<BufferStore>(visited);
52
    auto nptr = n.CopyOnWrite();
53
54
    nptr->indices = nptr->indices.Map(
        [&](const auto &e) { return analyzer_->Simplify(e); });
55
56
    return n;
  }
57
  arith::Analyzer *analyzer_;
58
59
60
};

// Rewrite the parallel loop into a common loop, which is mapped to threads
61
For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
62
                  const Fragment &loop_layout) {
63
64
65
66
67
68
69
70
  ICHECK(loop_layout.defined());
  ICHECK(thread_var.defined());
  int old_loop_depth = loop_layout->InputDim();
  int new_loop_depth = loop_layout->OutputDim();
  // Create the new loop iter var
  Array<Var> vars;
  for (int i = 0; i < new_loop_depth; i++) {
    Var var = Var(std::string{char('i' + i)});
71
72
    analyzer->Bind(var, Range::FromMinExtent(make_zero(var->dtype),
                                             loop_layout->OutputShape()[i]));
73
74
75
76
77
    vars.push_back(var);
  }
  vars.push_back(thread_var);
  // create the substitute map, and the loop body
  Map<Var, PrimExpr> vmap;
78
  Stmt body = std::move(op);
79
80
81
82
83
84
  Array<PrimExpr> loop_mins;
  Array<PrimExpr> loop_extents;
  auto inverse_info = loop_layout->InverseWithLevel();
  auto inv_loop = inverse_info.first;
  // Must check the guard if the layout can not be proved as bijective
  bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective;
85
  auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end()));
86
87
88
89
90
91
92
93
  // Normalize thread var once so we can reuse the same substitution later.
  Map<Var, PrimExpr> thread_offset_map;
  bool has_thread_offset = false;
  if (loop_layout->ThreadRange().defined()) {
    auto range = loop_layout->ThreadRange();
    thread_offset_map.Set(thread_var, thread_var - range->min);
    has_thread_offset = true;
  }
94
  for (int i = 0; i < old_loop_depth; i++) {
95
96
    const ForNode *loop = body.as<ForNode>();
    ICHECK(loop != nullptr);
97
    vmap.Set(loop->loop_var, indices[i]);
98
99
    loop_mins.push_back(loop->min);
    loop_extents.push_back(loop->extent);
100
101
102
103
    body = loop->body;
  }
  // substitute and re-construct the serial loop
  body = Substitute(body, vmap);
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
  // Guard executes the recovered loop body only if each inverse-mapped iterator
  // falls back into the original For ranges. We first check every axis from the
  // old loop nest (old_loop_depth) and then the extra index produced by inverse
  // layouts that carry a replicate/thread component (`inv_output_shape`). Both
  // must stay within bounds to ensure correctness. Example: layout([i, j]) =
  // floor((i * 16 + j) / 32) may generate extra points when the new loop
  // enumerates 0..31; the guard drops iterations whose inverse-mapped (i, j)
  // or replicate index fall outside their original extents.
  // Example: layout([i, j]) = floor((i * 16 + j) / 32) may produce extra points
  // when the new loop enumerates 0..31; this guard skips iterations where the
  // inverse i, j land outside the original extents. This protects
  // non-surjective loop_layout mappings that otherwise over-cover the parallel
  // space.
  PrimExpr guard = const_true();

  if (need_guard) {
    for (int i = 0; i < old_loop_depth; i++) {
      PrimExpr index = indices[i];
      if (has_thread_offset) {
        index = Substitute(index, thread_offset_map);
      }
      PrimExpr lower_bound = analyzer->Simplify(index >= loop_mins[i]);
      PrimExpr upper_bound =
          analyzer->Simplify(index < loop_mins[i] + loop_extents[i]);
      guard = And(guard, And(lower_bound, upper_bound));
    }
    auto inv_output_shape = inv_loop->OutputShape();
    if (inv_output_shape.size() > static_cast<size_t>(old_loop_depth)) {
      PrimExpr replicate_index = indices[old_loop_depth];
      if (has_thread_offset) {
        replicate_index = Substitute(replicate_index, thread_offset_map);
      }
      PrimExpr replicate_extent = inv_output_shape[old_loop_depth];
      PrimExpr lower_bound = analyzer->Simplify(
          replicate_index >= make_zero(replicate_index.dtype()));
      PrimExpr upper_bound =
          analyzer->Simplify(replicate_index < replicate_extent);
      guard = And(guard, And(lower_bound, upper_bound));
    }
    PrimExpr simplified_guard = analyzer->Simplify(guard);
    if (!analyzer->CanProve(simplified_guard)) {
      body = IfThenElse(simplified_guard, body, Stmt());
    }
  }

149
  for (int i = new_loop_depth - 1; i >= 0; i--) {
150
151
    body = For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i],
               ForKind::kSerial, body);
152
153
154
155
156
    analyzer->Bind(vars[i], Range(0, inv_loop->InputShape()[i]));
  }

  body = BufferIndiceSimplify(analyzer)(body);

157
158
  if (has_thread_offset) {
    body = Substitute(body, thread_offset_map);
159
  }
160
161

  auto for_node = LoopPragmaUnroll(Downcast<For>(body));
162
163
164
165
  return for_node;
}

class LoopPramaUnroller : public StmtExprMutator {
166
public:
167
168
  LoopPramaUnroller() = default;

169
170
private:
  Stmt VisitStmt_(const ForNode *node) final {
171
    if (node->kind == ForKind::kSerial) {
172
173
174
175
      auto analyzer = std::make_shared<arith::Analyzer>();
      if (as_const_int(analyzer->Simplify(node->extent)) == nullptr) {
        return StmtExprMutator::VisitStmt_(node);
      }
176
      For new_for = tvm::ffi::GetRef<For>(node);
177
178
179
180
181
182
183
184
185
186
      auto for_ptr = new_for.CopyOnWrite();
      for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false));
      for_ptr->kind = ForKind::kUnrolled;
      return new_for;
    }
    return StmtExprMutator::VisitStmt_(node);
  }
};

class LoopPartitioner : public StmtExprVisitor {
187
public:
188
189
  LoopPartitioner() = default;

190
  Fragment Partition(const For &op, int num_thread, int vectorize_size) {
191
    this->VisitStmt(op);
192
193
194
195
    DataType dtype = DataType::Int(32);
    if (!loop_vars_.empty()) {
      dtype = loop_vars_.back()->var.dtype();
    }
196
197
198
    PrimExpr flattened = make_const(dtype, 0);
    PrimExpr vector_extent = make_const(dtype, vectorize_size);
    PrimExpr thread_extent_const = make_const(dtype, num_thread);
199
    for (size_t i = 0; i < loop_vars_.size(); i++) {
200
      PrimExpr extent = loop_vars_[i]->dom->extent;
201
202
      flattened = flattened * extent + loop_vars_[i]->var;
    }
203
204
205
206
207
    PrimExpr access_idx = FloorDiv(flattened, vector_extent);
    PrimExpr thd = FloorMod(access_idx, thread_extent_const);
    PrimExpr idx = FloorDiv(access_idx, thread_extent_const) * vector_extent +
                   FloorMod(flattened, vector_extent);

208
209
210
211
212
213
214
215
    auto fragment = Fragment(loop_vars_, {idx}, {thd}, {});
    if (has_fragment_) {
      // for fragment buffer, we don't need to replicate the loop layout
      auto thread_extent = *as_const_int(fragment->ThreadExtent());
      auto num_thread_fragment = num_thread / thread_extent;
      fragment = fragment->Replicate(num_thread_fragment);
    }
    return fragment;
216
217
  }

218
private:
219
220
221
222
223
224
225
226
227
228
229
230
231
232
  void VisitExpr_(const BufferLoadNode *op) final {
    if (op->buffer.scope() == "local.fragment") {
      has_fragment_ = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode *op) final {
    if (op->buffer.scope() == "local.fragment") {
      has_fragment_ = true;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

233
  void VisitStmt_(const ForNode *node) final {
234
235
    if (node->kind == ForKind::kParallel) {
      body_ = node->body;
236
237
238
      loop_vars_.push_back(
          IterVar(Range::FromMinExtent(node->min, node->extent), node->loop_var,
                  IterVarType::kDataPar));
239
240
241
242
243
244
    }
    StmtExprVisitor::VisitStmt_(node);
  }

  Stmt body_;
  PrimExpr flattened = 0;
245
  bool has_fragment_ = false;
246
247
248
  Array<IterVar> loop_vars_;
};

249
250
// Plans a thread layout for the parallel loops in `op` over `num_thread`
// threads, accessing `vectorize_size` consecutive elements at a time. The
// returned fragment is not bound to any thread range.
Fragment PlanLoopPartition(const For &op, size_t num_thread,
                           int vectorize_size) {
  return LoopPartitioner().Partition(op, static_cast<int>(num_thread),
                                     vectorize_size);
}
// Plans a thread layout for the parallel loops in `op`. The thread count is
// taken from `thread_range`'s extent, which must be a positive integer
// constant. The result is bound to `thread_range`.
Fragment PlanLoopPartition(const For &op, int vectorize_size,
                           const Range &thread_range) {
  ICHECK(thread_range.defined()) << "PlanLoopPartition needs a thread range";
  const auto *extent = as_const_int(thread_range->extent);
  ICHECK(extent != nullptr)
      << "PlanLoopPartition expects a constant thread extent, but got "
      << thread_range->extent;
  ICHECK_GT(*extent, 0) << "Thread extent must be positive";
  size_t num_thread = *extent;
  LoopPartitioner partitioner;
  Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size);
  return fragment->BindThreadRange(thread_range);
}
// Runs LoopPramaUnroller over `stmt`: serial loops whose extent simplifies to
// a constant become unrolled loops. Returns the rewritten loop.
For LoopPragmaUnroll(For stmt) {
  return Downcast<For>(LoopPramaUnroller()(std::move(stmt)));
}
} // namespace tl
} // namespace tvm