loop_vectorize.cc 9.69 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_vectorize.cc
 * \brief A tool to automatically vectorize a for loop
 */

#include "loop_vectorize.h"

#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>

#include <numeric>

#include "../layout/layout.h"
#include "../layout/utils.h"
35
36
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include "common/loop_vectorization_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*! \brief Outcome of planning vectorization for a loop nest. */
struct VectorizePlanResult {
  int vector_size;     // number of lanes chosen for the innermost loop
  bool dynamic;        // true when vectorization needs a runtime guard
  PrimExpr condition;  // the runtime predicate; only meaningful when dynamic
};

class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
51
public:
52
53
  VectorizePlanner() = default;

54
  int Plan(const For &node) {
55
56
57
58
59
60
61
62
    this->operator()(node);
    return vector_size_;
  }

  bool GetDynamic() { return dynamic_; }

  PrimExpr GetCondition() { return condition_; }

63
64
private:
  void VisitStmt_(const ForNode *node) final {
65
66
    inner_for_ = node;
    iter_map_.Set(node->loop_var, Range(node->min, node->extent));
67

68
69
70
    arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

71
  void VisitExpr_(const BufferLoadNode *node) final {
72
73
74
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
75
    if (node->buffer->shape.size() == 1) {
76
77
      // TODO(lei): This should be improved as
      // constant buffer that tl hack to use as local register.
78
79
80
81
      auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
      if (boundary_check && boundary_check->value == 1) {
        return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
      }
82
83
84
85
    }
    UpdateVectorSize(node->indices, node->buffer);
  }

86
  void VisitStmt_(const BufferStoreNode *node) final {
87
88
89
90
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    UpdateVectorSize(node->indices, node->buffer);
91
    return arith::IRVisitorWithAnalyzer::VisitExpr(node->value);
92
93
  }

94
  void VisitStmt_(const IfThenElseNode *node) final {
95
96
97
98
    CheckConditionVectorized(node->condition);
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

99
  void VisitExpr_(const CallNode *node) final {
100
101
102
103
104
105
106
107
108
    if (node->op == builtin::if_then_else()) {
      CheckConditionVectorized(node->args[0]);
    } else if (node->op == builtin::call_extern()) {
      // do not vectorize extern calls
      vector_size_ = 1;
    }
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

109
  void CheckConditionVectorized(const PrimExpr &cond) {
110
111
112
    // TODO: perform some checks here
  }

113
114
115
  void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
    if (!inner_for_)
      return;
116
    auto extent_ptr = inner_for_->extent.as<IntImmNode>();
117
118
    if (!extent_ptr)
      return;
119

120
    const DataType &access_type = buffer->dtype;
121
    // i // 2, i % 8 can also be vectorized as factor 16
122
    int max_vector_size = vector_load_bits_max_ / access_type.bits();
123
124
125
126
    // so we should disable this GCD optimization
    max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
    auto last_dim = buffer->shape.back();
    auto mod_set = analyzer_.modular_set(last_dim);
127
128
    // when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
    // conditionally tail vectorize
129
130
131
132
133
134
135
136
137
138
    if (buffer->shape.back().as<IntImmNode>()) {
      max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
      auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
      // If gcd_base is equal to the last dimension,
      // we should analyze the second-to-last dimension
      // in relation to the last dimension.
      if (gcd_base < Downcast<IntImm>(last_dim)->value) {
        max_vector_size = gcd_base;
      }
      vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
139
140
141
142
143
144
145
146
147
148
149
150
151
152

      // Generate strides if not existed
      auto strides = buffer->strides;
      if (buffer->strides.size() == 0) {
        PrimExpr stride = 1;
        for (int i = indices.size() - 1; i >= 0; --i) {
          strides.push_back(stride);
          stride = stride * buffer->shape[i];
        }
        strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
      }

      // Generate and check element offset expression
      ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
153
      PrimExpr elem_offset = 0;
154
155
      for (int i = 0; i < indices.size(); ++i) {
        elem_offset += indices[i] * strides[i];
156
      }
157
158
159
      while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
                                 inner_for_->extent, vector_size_,
                                 &analyzer_)) {
160
161
162
163
164
165
166
167
168
169
        vector_size_ /= 2;
      }
    } else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
      // dynamic shape load: get the vectorization condition
      dynamic_ = true;
      PrimExpr offset = buffer.OffsetOf(indices).back();
      condition_ = (FloorMod(offset, vector_size_) == 0);
    }
  }

170
  const int vector_load_bits_max_ = 128;
171

172
  const ForNode *inner_for_;
173
174
175
176
177
178
179
180
181
  Map<Var, Range> iter_map_;
  bool has_nonlocal_memory_access_ = false;
  int vector_size_ = 128;
  // conditionally vectorize
  bool dynamic_ = false;
  PrimExpr condition_;
};

class VectorizeRewriter : public StmtExprMutator {
182
public:
183
  VectorizeRewriter(VectorizePlanResult plan)
184
185
      : vector_size_(plan.vector_size), condition_(plan.condition),
        dynamic_(plan.dynamic) {}
186

187
188
private:
  Stmt VisitStmt_(const ForNode *node) final {
189
190
    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
191
    if (inner_for_ == node) { // rewrite the innermost loop
192
193
194
195
196
197
198
199
      For fnode = ret.as<For>().value();
      auto old_var = fnode->loop_var;
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = *extent_ptr;
      ICHECK(extent % vector_size_ == 0)
          << "extent: " << extent << " vector_size_: " << vector_size_;
      ICHECK(is_zero(fnode->min));
200
      if (!dynamic_) { // check dynamic shape
201
202
203
204
205
206
207
208
209
210
        if (extent == vector_size_) {
          fnode.CopyOnWrite()->kind = ForKind::kVectorized;
          return fnode;
        } else {
          Var inner_var = Var("vec");
          Var outer_var = Var(old_var->name_hint);
          Map<Var, PrimExpr> vmap;
          vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
          Stmt body = Substitute(fnode->body, vmap);
          body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
211
212
          body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                     fnode->thread_binding, fnode->annotations, fnode->span);
213
214
215
          return body;
        }
      } else {
216
        return fnode;
217
218
219
220
221
222
      }
    } else {
      return ret;
    }
  }

223
  const ForNode *inner_for_;
224
225
226
227
228
  const int vector_size_;
  const PrimExpr condition_;
  const bool dynamic_;
};

229
int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
230

231
VectorizePlanResult GetVectorizePlanResult(const For &loop) {
232
233
234
235
236
237
238
  VectorizePlanner planner;
  int vector_size = planner.Plan(loop);
  bool dynamic = planner.GetDynamic();
  PrimExpr condition = planner.GetCondition();
  return {vector_size, dynamic, condition};
}

239
240
bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
                        int target_vectorized_size, arith::Analyzer *analyzer) {
241
  ICHECK(target_vectorized_size >= 1);
242
243
  if (target_vectorized_size == 1)
    return true;
244
245

  // Extent must be divisible
246
247
248
  if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
                               0))
    return false;
249
250
251
252
253
254
255
256

  // The base offset must be divisible
  if (!analyzer->CanProveEqual(
          FloorMod(Substitute(expr, {{var, 0}}), target_vectorized_size), 0)) {
    return false;
  }

  // Bind thread range
257
258
  Var v0("v0"), v1("v1");
  analyzer->Bind(v0, Range(0, target_vectorized_size));
259
260
  analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv(
                                  iter_var_size, target_vectorized_size))));
261
262
  PrimExpr expr_transformed = analyzer->Simplify(
      Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
263
  Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
264
  PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
265
266

  // This simplify is necessary for thread region specified
267
268
  // optimizations.
  expr_vectorized = analyzer->Simplify(expr_vectorized);
269
270
271
272
273
274
275
276
277
278
279
280
  auto ramp_node = expr_vectorized.as<RampNode>();
  if (!ramp_node) {
    // Broadcast value
    if (expr_vectorized.dtype().lanes() == 1)
      return true;
    else
      return false;
  } else {
    return is_one(ramp_node->stride);
  }
}

281
For VectorizeLoop(const For &loop, int vectorize_hint) {
282
283
284
285
286
  VectorizePlanResult res{128, false, 0};
  if (vectorize_hint <= 0) {
    res = GetVectorizePlanResult(loop);
    vectorize_hint = res.vector_size;
  }
287
288
  if (vectorize_hint == 1)
    return loop;
289
290
291
292
  auto rewriter = VectorizeRewriter(res);
  return Downcast<For>(rewriter(loop));
}

293
294
} // namespace tl
} // namespace tvm