loop_vectorize.cc 8.61 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_vectorize.cc
 * \brief A tool to automatically vectorize a for loop
 */

#include "loop_vectorize.h"

27
28
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
29
#include "common/loop_vectorization_utils.h"
30
31
32
33
34
#include "tvm/tir/analysis.h"
#include "tvm/tir/var.h"
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
35
36
37
38
39
40
41
42
43
44
45
46
47

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Aggregate describing the outcome of vectorization planning.
 *
 * Every member has a default so that a default-initialized instance never
 * holds indeterminate values.
 * NOTE(review): this struct is not referenced elsewhere in this file.
 */
struct VectorizePlanResult {
  // Chosen vector width; 1 is the conservative "do not vectorize" value.
  int vector_size = 1;
  // presumably: whether the plan depends on a runtime condition — confirm.
  bool dynamic = false;
  // presumably: runtime guard for a dynamic plan; undefined when unused.
  PrimExpr condition;
};

class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
48
public:
49
50
  VectorizePlanner() = default;

51
  int Plan(const For &node) {
52
53
54
55
    this->operator()(node);
    return vector_size_;
  }

56
57
private:
  void VisitStmt_(const ForNode *node) final {
58
    inner_for_ = node;
59
60
61
62
63
64
65
66
67
    auto extent_ptr = as_const_int(node->extent);
    // Here I disable dynamic shape completely,
    //   In order to do it, the Planner should accept an analyzer with
    //   arithmetic info outside to prove the dividiblity of vector size
    if (!extent_ptr) {
      vector_size_ = 1;
      return;
    }
    vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
68
69
70
    arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

71
  void VisitExpr_(const BufferLoadNode *node) final {
72
73
74
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
75
    if (node->buffer->shape.size() == 1) {
76
77
      // TODO(lei): This should be improved as
      // constant buffer that tl hack to use as local register.
78
79
80
81
      auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
      if (boundary_check && boundary_check->value == 1) {
        return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
      }
82
83
84
85
    }
    UpdateVectorSize(node->indices, node->buffer);
  }

86
  void VisitStmt_(const BufferStoreNode *node) final {
87
88
89
90
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    UpdateVectorSize(node->indices, node->buffer);
91
    return arith::IRVisitorWithAnalyzer::VisitExpr(node->value);
92
93
  }

94
  void VisitStmt_(const IfThenElseNode *node) final {
95
96
97
98
    CheckConditionVectorized(node->condition);
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

99
  void VisitExpr_(const CallNode *node) final {
100
101
102
103
104
105
106
107
108
    if (node->op == builtin::if_then_else()) {
      CheckConditionVectorized(node->args[0]);
    } else if (node->op == builtin::call_extern()) {
      // do not vectorize extern calls
      vector_size_ = 1;
    }
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

109
  void CheckConditionVectorized(const PrimExpr &cond) {
110
111
112
    // TODO: perform some checks here
  }

113
  void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
114
115
    if (!inner_for_)
      return;
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    // 1. Compute raw element offset
    auto strides = buffer->strides;
    if (buffer->strides.empty()) {
      PrimExpr stride = 1;
      for (int i = indices.size() - 1; i >= 0; --i) {
        strides.push_back(stride);
        stride = stride * buffer->shape[i];
      }
      strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
    }
    PrimExpr elem_offset = 0;
    for (int i = 0; i < indices.size(); ++i) {
      elem_offset += indices[i] * strides[i];
    }

    // 2. If element offset is independent with loop_var, ignore it
    if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) {
133
      return;
134
    }
135

136
137
138
    // 3. Tight vectorize bound
    vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ /
                                                         buffer->dtype.bits());
139

140
141
142
143
    // 4. Try to vectorize buffer load
    while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
                               inner_for_->extent, vector_size_, &analyzer_)) {
      vector_size_ /= 2;
144
145
146
    }
  }

147
  const int vector_load_bits_max_ = 128;
148

149
  const ForNode *inner_for_{};
150
151
152
153
154
  bool has_nonlocal_memory_access_ = false;
  int vector_size_ = 128;
};

class VectorizeRewriter : public StmtExprMutator {
155
public:
156
  VectorizeRewriter(int vector_size) : vector_size_(vector_size) {}
157

158
159
private:
  Stmt VisitStmt_(const ForNode *node) final {
160
161
    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
162
    if (inner_for_ == node) { // rewrite the innermost loop
163
164
165
166
167
168
169
170
      For fnode = ret.as<For>().value();
      auto old_var = fnode->loop_var;
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = *extent_ptr;
      ICHECK(extent % vector_size_ == 0)
          << "extent: " << extent << " vector_size_: " << vector_size_;
      ICHECK(is_zero(fnode->min));
171
172
      if (extent == vector_size_) {
        fnode.CopyOnWrite()->kind = ForKind::kVectorized;
173
        return fnode;
174
175
176
177
178
179
180
181
182
183
      } else {
        Var inner_var = Var("vec");
        Var outer_var = Var(old_var->name_hint);
        Map<Var, PrimExpr> vmap;
        vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
        Stmt body = Substitute(fnode->body, vmap);
        body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
        body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                   fnode->thread_binding, fnode->annotations, fnode->span);
        return body;
184
185
186
187
188
189
      }
    } else {
      return ret;
    }
  }

190
  const ForNode *inner_for_{};
191
192
193
  const int vector_size_;
};

194
int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
195

196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
bool CanProveIndependent(const PrimExpr &expr, Var var,
                         arith::Analyzer *analyzer) {
  // 1. if var doesn't exist, it is independent
  bool used_var = UsesVar(
      expr, [&](const VarNode *v) { return GetRef<Var>(v).same_as(var); });
  if (!used_var) {
    return true;
  }
  // 2. if \forall v_1, v_2, f(v_1) == f(v_2), f is independent with v
  Var var_1("_t", var.dtype());
  auto expr_1 = Substitute(expr, {{var, var_1}});
  if (analyzer->CanProveEqual(expr, expr_1)) {
    return true;
  }
  return false;
211
212
}

213
214
bool IndiceCanVectorize(const PrimExpr &expr, Var var,
                        const PrimExpr &iter_var_size,
215
                        int target_vectorized_size, arith::Analyzer *analyzer) {
216
  ICHECK(target_vectorized_size >= 1);
217
218
  if (target_vectorized_size == 1)
    return true;
219
220

  // Extent must be divisible
221
222
223
  if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
                               0))
    return false;
224
225
226
227
228
229
230
231

  // The base offset must be divisible
  if (!analyzer->CanProveEqual(
          FloorMod(Substitute(expr, {{var, 0}}), target_vectorized_size), 0)) {
    return false;
  }

  // Bind thread range
232
233
  Var v0("v0"), v1("v1");
  analyzer->Bind(v0, Range(0, target_vectorized_size));
234
235
  analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv(
                                  iter_var_size, target_vectorized_size))));
236
237
  PrimExpr expr_transformed = analyzer->Simplify(
      Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
238
  Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
239
  PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
240
241

  // This simplify is necessary for thread region specified
242
243
  // optimizations.
  expr_vectorized = analyzer->Simplify(expr_vectorized);
244
245
246
247
248
249
250
251
252
253
254
255
  auto ramp_node = expr_vectorized.as<RampNode>();
  if (!ramp_node) {
    // Broadcast value
    if (expr_vectorized.dtype().lanes() == 1)
      return true;
    else
      return false;
  } else {
    return is_one(ramp_node->stride);
  }
}

256
/*!
 * \brief Vectorize the innermost loop of \p loop.
 *
 * \param loop  Loop nest to transform.
 * \param vectorize_hint  Vector width to use; a value <= 0 asks
 *        VectorizePlanner to infer one.
 * \return \p loop itself when the chosen width is 1, otherwise the nest
 *         rewritten by VectorizeRewriter.
 */
For VectorizeLoop(const For &loop, int vectorize_hint) {
  const int vector_size =
      vectorize_hint > 0 ? vectorize_hint : VectorizePlanner().Plan(loop);
  if (vector_size == 1) {
    return loop;
  }
  VectorizeRewriter rewriter(vector_size);
  return Downcast<For>(rewriter(loop));
}

267
268
} // namespace tl
} // namespace tvm