/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_partition.cc
 * \brief Partition parallel loops onto threads
 */

#include "loop_partition.h"

#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace tl {

using namespace tir;

class BufferIndiceSimplify : public StmtExprMutator {
35
36
public:
  BufferIndiceSimplify(arith::Analyzer *analyzer) : analyzer_(analyzer) {}
37

38
39
private:
  PrimExpr VisitExpr_(const BufferLoadNode *node) final {
40
    auto visited = StmtExprMutator::VisitExpr_(node);
41
    auto n = Downcast<BufferLoad>(visited);
42
    auto nptr = n.CopyOnWrite();
43
44
    nptr->indices = nptr->indices.Map(
        [&](const auto &e) { return analyzer_->Simplify(e); });
45
46
    return n;
  }
47
  Stmt VisitStmt_(const BufferStoreNode *node) final {
48
    auto visited = StmtExprMutator::VisitStmt_(node);
49
    auto n = Downcast<BufferStore>(visited);
50
    auto nptr = n.CopyOnWrite();
51
52
    nptr->indices = nptr->indices.Map(
        [&](const auto &e) { return analyzer_->Simplify(e); });
53
54
    return n;
  }
55
  arith::Analyzer *analyzer_;
56
57
58
};

// Rewrite the parallel loop into a common loop, which is mapped to threads
//
// `loop_layout` describes a mapping from the original (old) loop iteration
// space to a new iteration space whose last coordinate is the thread index.
// This function inverts that mapping, substitutes the original loop variables
// with expressions in the new loop variables + `thread_var`, and rebuilds the
// loop nest as serial (later unrolled) loops over the new space.
For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
                  Fragment loop_layout) {
  ICHECK(loop_layout.defined());
  ICHECK(thread_var.defined());
  // InputDim = rank of the original loop nest, OutputDim = rank of the
  // partitioned (per-thread) loop nest.
  int old_loop_depth = loop_layout->InputDim();
  int new_loop_depth = loop_layout->OutputDim();

  // Create the new loop iter var
  // Names are single letters 'i', 'j', 'k', ... (assumes a small loop depth).
  Array<Var> vars;
  for (int i = 0; i < new_loop_depth; i++) {
    Var var = Var(std::string{char('i' + i)});
    vars.push_back(var);
  }
  // The thread index is appended as the last input of the inverse layout.
  vars.push_back(thread_var);
  // create the substitute map, and the loop body
  Map<Var, PrimExpr> vmap;
  Stmt body = op;
  // Inverse maps (new loop vars, thread) back to the original loop indices.
  auto inv_loop = loop_layout->Inverse();
  auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end()));
  // Walk down the original nest, recording old_var -> new index expression,
  // and strip the old For wrappers. Requires the nest to be exactly
  // `old_loop_depth` directly-nested For statements.
  for (int i = 0; i < old_loop_depth; i++) {
    const ForNode *loop = body.as<ForNode>();
    ICHECK(loop != nullptr);
    vmap.Set(loop->loop_var, indices[i]);
    body = loop->body;
  }

  // substitute and re-construct the serial loop
  body = Substitute(body, vmap);
  // Rebuild innermost-first; also bind each new var's range so downstream
  // simplification can use the bounds.
  for (int i = new_loop_depth - 1; i >= 0; i--) {
    body = For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i],
               ForKind::kSerial, body);
    analyzer->Bind(vars[i], Range(0, inv_loop->InputShape()[i]));
  }

  // Clean up the index expressions produced by the substitution.
  body = BufferIndiceSimplify(analyzer)(body);

  // Mark the rebuilt serial loops for explicit unrolling.
  auto for_node = LoopPragmaUnroll(Downcast<For>(body));
  // If the layout is bound to a sub-range of threads, shift the thread index
  // so the loop body sees a zero-based thread coordinate.
  if (loop_layout->ThreadRange().defined()) {
    auto range = loop_layout->ThreadRange();
    auto thread_var_with_offset = thread_var - range->min;
    for_node.CopyOnWrite()->body =
        Substitute(for_node->body, {{thread_var, thread_var_with_offset}});
  }
  return for_node;
}

class LoopPramaUnroller : public StmtExprMutator {
106
public:
107
108
  LoopPramaUnroller() = default;

109
110
private:
  Stmt VisitStmt_(const ForNode *node) final {
111
112
113
114
115
116
117
118
119
120
121
122
    if (node->kind == ForKind::kSerial) {
      For new_for = GetRef<For>(node);
      auto for_ptr = new_for.CopyOnWrite();
      for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false));
      for_ptr->kind = ForKind::kUnrolled;
      return new_for;
    }
    return StmtExprMutator::VisitStmt_(node);
  }
};

class LoopPartitioner : public StmtExprVisitor {
123
public:
124
125
126
127
128
129
130
131
  LoopPartitioner() = default;

  Fragment Partition(For op, int num_thread, int vectorize_size) {
    this->VisitStmt(op);
    int loop_size_full = 1;
    PrimExpr flattened = 0;
    for (size_t i = 0; i < loop_vars_.size(); i++) {
      auto ext_ptr = as_const_int(loop_vars_[i]->dom->extent);
132
133
134
      ICHECK(ext_ptr)
          << "Loop partitioner only works with constant loop sizes, but got "
          << loop_vars_[i]->dom->extent;
135
136
137
138
139
140
141
      int extent = *ext_ptr;
      loop_size_full *= extent;
      flattened = flattened * extent + loop_vars_[i]->var;
    }
    ICHECK(loop_size_full % vectorize_size == 0);
    PrimExpr access_idx = FloorDiv(flattened, vectorize_size);
    PrimExpr thd = FloorMod(access_idx, num_thread);
142
143
    PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size +
                   FloorMod(flattened, vectorize_size);
144
145
146
147
148
149
150
151
    auto fragment = Fragment(loop_vars_, {idx}, {thd}, {});
    if (has_fragment_) {
      // for fragment buffer, we don't need to replicate the loop layout
      auto thread_extent = *as_const_int(fragment->ThreadExtent());
      auto num_thread_fragment = num_thread / thread_extent;
      fragment = fragment->Replicate(num_thread_fragment);
    }
    return fragment;
152
153
  }

154
private:
155
156
157
158
159
160
161
162
163
164
165
166
167
168
  void VisitExpr_(const BufferLoadNode *op) final {
    if (op->buffer.scope() == "local.fragment") {
      has_fragment_ = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode *op) final {
    if (op->buffer.scope() == "local.fragment") {
      has_fragment_ = true;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

169
  void VisitStmt_(const ForNode *node) final {
170
171
    if (node->kind == ForKind::kParallel) {
      body_ = node->body;
172
173
174
      loop_vars_.push_back(
          IterVar(Range::FromMinExtent(node->min, node->extent), node->loop_var,
                  IterVarType::kDataPar));
175
176
177
178
179
180
    }
    StmtExprVisitor::VisitStmt_(node);
  }

  Stmt body_;
  PrimExpr flattened = 0;
181
  bool has_fragment_ = false;
182
183
184
185
186
187
188
189
  Array<IterVar> loop_vars_;
};

// Plan the thread partition of `op` over `num_thread` threads with
// `vectorize_size`-wide vectorized accesses.
Fragment PlanLoopPartition(For op, size_t num_thread, int vectorize_size) {
  return LoopPartitioner().Partition(op, num_thread, vectorize_size);
}

190
Fragment PlanLoopPartition(For op, int vectorize_size, Range thread_range) {
191
  size_t num_thread = *as_const_int(thread_range->extent);
192
193
194
195
196
197
198
  LoopPartitioner partitioner;
  Fragment fragment = partitioner.Partition(op, num_thread, vectorize_size);
  auto node = make_object<FragmentNode>(*fragment.get());
  node->SetThreadRange(thread_range);
  return Fragment(node);
}

// Mark the serial loops of `stmt` as explicitly unrolled (see
// LoopPramaUnroller above).
For LoopPragmaUnroll(For stmt) {
  LoopPramaUnroller unroller;
  return Downcast<For>(unroller(stmt));
}

} // namespace tl
} // namespace tvm