/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_partition.cc
 * \brief Partition parallel loops onto threads
 */

#include "loop_partition.h"

#include <tvm/tir/stmt_functor.h>

#include <utility>

namespace tvm {
namespace tl {

using namespace tir;

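// Simplifies every buffer access index with the bound analyzer. After loop
// partitioning substitutes the new loop/thread variables into the indices,
// expressions such as (illustrative example only)
//   A[((tx * 4 + v) floordiv 4) * 4 + (tx * 4 + v) floormod 4]
// fold back to A[tx * 4 + v] once tx and v are bound to their ranges.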
class BufferIndicesSimplifier : public StmtExprMutator {
public:
  explicit BufferIndicesSimplifier(arith::Analyzer *analyzer)
      : analyzer_(analyzer) {}

private:
  PrimExpr VisitExpr_(const BufferLoadNode *node) final {
    auto visited = StmtExprMutator::VisitExpr_(node);
    auto n = Downcast<BufferLoad>(visited);
    auto nptr = n.CopyOnWrite();
    nptr->indices = nptr->indices.Map(
        [&](const auto &e) { return analyzer_->Simplify(e); });
    return n;
  }
  Stmt VisitStmt_(const BufferStoreNode *node) final {
    auto visited = StmtExprMutator::VisitStmt_(node);
    auto n = Downcast<BufferStore>(visited);
    auto nptr = n.CopyOnWrite();
    nptr->indices = nptr->indices.Map(
        [&](const auto &e) { return analyzer_->Simplify(e); });
    return n;
  }
  arith::Analyzer *analyzer_;
};

// Rewrite the parallel loop into a common loop, which is mapped to threads
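//
// Illustrative sketch (shapes hypothetical; the real mapping comes from
// loop_layout and its inverse):
//
//   before:  for (i, 0, 8) "parallel"        after:  for (u, 0, 2) "unroll"
//              for (j, 0, 32) "parallel"               body[(i, j) := f(u, tx)]
//                body(i, j)
//
// where tx is thread_var over 128 threads and f = loop_layout->Inverse().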
For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
                  const Fragment &loop_layout) {
  ICHECK(loop_layout.defined());
  ICHECK(thread_var.defined());
  int old_loop_depth = loop_layout->InputDim();
  int new_loop_depth = loop_layout->OutputDim();

  // Create the new loop iteration variables (named i, j, k, ...); the
  // thread variable is appended as the last input of the inverse layout
  Array<Var> vars;
  for (int i = 0; i < new_loop_depth; i++) {
    Var var = Var(std::string{char('i' + i)});
    vars.push_back(var);
  }
  vars.push_back(thread_var);
  // Build the substitution map while peeling off the old parallel loop nest
  Map<Var, PrimExpr> vmap;
  Stmt body = std::move(op);
  auto inv_loop = loop_layout->Inverse();
  auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end()));
  for (int i = 0; i < old_loop_depth; i++) {
    const ForNode *loop = body.as<ForNode>();
    ICHECK(loop != nullptr);
    vmap.Set(loop->loop_var, indices[i]);
    body = loop->body;
  }

  // Substitute the old loop variables and reconstruct the nest as serial loops
  body = Substitute(body, vmap);
  for (int i = new_loop_depth - 1; i >= 0; i--) {
    body = For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i],
               ForKind::kSerial, body);
    analyzer->Bind(vars[i], Range(0, inv_loop->InputShape()[i]));
  }

  body = BufferIndicesSimplifier(analyzer)(body);

  auto for_node = LoopPragmaUnroll(Downcast<For>(body));
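  // If the layout binds a thread sub-range [min, min + extent), rebase
  // thread_var so the loop body sees thread indices starting at zero
  // (illustrative: with range [64, 128), thread 70 becomes offset 6).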
  if (loop_layout->ThreadRange().defined()) {
    auto range = loop_layout->ThreadRange();
    auto thread_var_with_offset = thread_var - range->min;
    for_node.CopyOnWrite()->body =
        Substitute(for_node->body, {{thread_var, thread_var_with_offset}});
  }
  return for_node;
}

// Converts every serial loop into an explicitly unrolled loop. Setting
// pragma_unroll_explicit to false lets the backend emit a plain
// "#pragma unroll" and leave the unroll factor to the target compiler.
class LoopPragmaUnroller : public StmtExprMutator {
public:
  LoopPragmaUnroller() = default;

private:
  Stmt VisitStmt_(const ForNode *node) final {
    if (node->kind == ForKind::kSerial) {
      For new_for = GetRef<For>(node);
      auto for_ptr = new_for.CopyOnWrite();
      for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false));
      for_ptr->kind = ForKind::kUnrolled;
      return new_for;
    }
    return StmtExprMutator::VisitStmt_(node);
  }
};

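// Builds a Fragment that maps the flattened parallel iteration space onto
// threads, with vectorize_size contiguous elements per access.
//
// Worked example (hypothetical numbers): loops of extent 8 x 32 give
// flattened = i * 32 + j in [0, 256). With vectorize_size = 4 and
// num_thread = 16:
//   access_idx = flattened / 4                          // one per vector access
//   thd        = access_idx % 16                        // owning thread
//   idx        = (access_idx / 16) * 4 + flattened % 4  // in [0, 16)
// so each of the 16 threads performs 4 vector accesses of 4 elements.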
class LoopPartitioner : public StmtExprVisitor {
public:
  LoopPartitioner() = default;

  Fragment Partition(const For &op, int num_thread, int vectorize_size) {
    this->VisitStmt(op);
    int loop_size_full = 1;
    PrimExpr flattened = 0;
    for (size_t i = 0; i < loop_vars_.size(); i++) {
      auto ext_ptr = as_const_int(loop_vars_[i]->dom->extent);
      ICHECK(ext_ptr)
          << "Loop partitioner only works with constant loop sizes, but got "
          << loop_vars_[i]->dom->extent;
      int extent = *ext_ptr;
      loop_size_full *= extent;
      flattened = flattened * extent + loop_vars_[i]->var;
    }
    ICHECK(loop_size_full % vectorize_size == 0)
        << "Flattened loop size " << loop_size_full
        << " must be divisible by the vectorize size " << vectorize_size;
    PrimExpr access_idx = FloorDiv(flattened, vectorize_size);
    PrimExpr thd = FloorMod(access_idx, num_thread);
    PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size +
                   FloorMod(flattened, vectorize_size);
    auto fragment = Fragment(loop_vars_, {idx}, {thd}, {});
    if (has_fragment_) {
      // The loop touches a fragment buffer, so replicate the layout until its
      // thread extent covers all num_thread threads.
      auto thread_extent = *as_const_int(fragment->ThreadExtent());
      auto num_thread_fragment = num_thread / thread_extent;
      fragment = fragment->Replicate(num_thread_fragment);
    }
    return fragment;
  }

private:
  // Detect accesses to fragment-scoped buffers; they require the planned
  // layout to span the full thread count.
  void VisitExpr_(const BufferLoadNode *op) final {
    if (op->buffer.scope() == "local.fragment") {
      has_fragment_ = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode *op) final {
    if (op->buffer.scope() == "local.fragment") {
      has_fragment_ = true;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  // Collect the loop variables of the parallel nest, outermost first.
  void VisitStmt_(const ForNode *node) final {
    if (node->kind == ForKind::kParallel) {
      body_ = node->body;
      loop_vars_.push_back(
          IterVar(Range::FromMinExtent(node->min, node->extent), node->loop_var,
                  IterVarType::kDataPar));
    }
    StmtExprVisitor::VisitStmt_(node);
  }

  Stmt body_;
  bool has_fragment_ = false;
  Array<IterVar> loop_vars_;
};

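// Plans the thread/vector layout for a parallel loop nest. Typical pipeline
// (sketch only; the surrounding variable names are hypothetical):
//   Fragment layout = PlanLoopPartition(loop, /*num_thread=*/128,
//                                       /*vectorize_size=*/4);
//   For lowered = PartitionLoop(loop, thread_var, &analyzer, layout);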
Fragment PlanLoopPartition(const For &op, size_t num_thread,
                           int vectorize_size) {
  LoopPartitioner partitioner;
  return partitioner.Partition(op, num_thread, vectorize_size);
}

Fragment PlanLoopPartition(const For &op, int vectorize_size,
                           const Range &thread_range) {
  const int64_t *extent_ptr = as_const_int(thread_range->extent);
  ICHECK(extent_ptr) << "Thread range extent must be a constant, but got "
                     << thread_range->extent;
  size_t num_thread = *extent_ptr;
  LoopPartitioner partitioner;
  Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size);
  return fragment->BindThreadRange(thread_range);
}

For LoopPragmaUnroll(For stmt) {
  LoopPragmaUnroller unroller;
  For unrolled = Downcast<For>(unroller(std::move(stmt)));
  return unrolled;
}

} // namespace tl
} // namespace tvm