"test/vscode:/vscode.git/clone" did not exist on "52abc2f37112d49f85f31aa343a14bd92a83b07c"
loop_vectorize.cc 10.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_vectorize.cc
 * \brief A tool to automatically vectorize a for loop
 */
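/*
 * Example (illustrative sketch, not taken from a test): with float32 elements
 * and a 128-bit maximum vector width, a loop such as
 *
 *   for (i, 0, 16) { B[i] = A[i] }
 *
 * is planned with vector_size = 4 and rewritten into an outer serial loop of
 * extent 4 wrapping an inner vectorized loop of extent 4 (i = i_o * 4 + i_v).
 */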

#include "loop_vectorize.h"

#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>

#include <numeric>

#include "../layout/layout.h"
#include "../layout/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

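// Result of vectorization planning: the chosen vector length, whether the
// access pattern depends on a dynamic shape, and, for the dynamic case, the
// runtime condition under which the vectorized loop is valid.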
struct VectorizePlanResult {
  int vector_size;
  bool dynamic;
  PrimExpr condition;
};

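// Analyzes the innermost loop and its buffer accesses to decide the largest
// safe vector length (and, for dynamic shapes, a runtime validity condition).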
class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
public:
  VectorizePlanner() = default;

  int Plan(const For &node) {
    this->operator()(node);
    // Always enable vectorization
    // if (!has_nonlocal_memory_access_) return 1;
    return vector_size_;
  }

  bool GetDynamic() { return dynamic_; }

  PrimExpr GetCondition() { return condition_; }

private:
  void VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    iter_map_.Set(node->loop_var, Range(node->min, node->extent));
    arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

  void VisitExpr_(const BufferLoadNode *node) final {
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    if (node->buffer->shape.size() == 1) {
      // TODO(lei): This should be improved; a shape-[1] buffer is a constant
      // buffer that TL hacks in to use as a local register.
      auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
      if (boundary_check && boundary_check->value == 1) {
        return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
      }
    }
    UpdateVectorSize(node->indices, node->buffer);
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

  void VisitStmt_(const BufferStoreNode *node) final {
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    UpdateVectorSize(node->indices, node->buffer);
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

  void VisitStmt_(const IfThenElseNode *node) final {
    CheckConditionVectorized(node->condition);
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

  void VisitExpr_(const CallNode *node) final {
    if (node->op == builtin::if_then_else()) {
      CheckConditionVectorized(node->args[0]);
    } else if (node->op == builtin::call_extern()) {
      // do not vectorize extern calls
      vector_size_ = 1;
    }
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

  void CheckConditionVectorized(const PrimExpr &cond) {
    // TODO: perform some checks here
  }

  void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
    if (!inner_for_)
      return;
    auto extent_ptr = inner_for_->extent.as<IntImmNode>();
    if (!extent_ptr)
      return;

    const DataType &access_type = buffer->dtype;
    // Indices like i // 2 and i % 8 can also be vectorized with factor 16, so
    // the GCD optimization over the index expressions is disabled here; the
    // factor is bounded only by the hardware load width and the loop extent.
    int max_vector_size = vector_load_bits_max_ / access_type.bits();
    max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);

    auto last_dim = buffer->shape.back();
    auto mod_set = analyzer_.modular_set(last_dim);
    // For a dynamic shape like [m, k] the modular set is coeff=1, base=0, so
    // the GCD below would block vectorization; in that case we conditionally
    // vectorize the tail instead (see the dynamic branch below).
    if (buffer->shape.back().as<IntImmNode>()) {
      max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);

      auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
      // If gcd_base is equal to the last dimension,
      // we should analyze the second-to-last dimension
      // in relation to the last dimension.
      if (gcd_base < Downcast<IntImm>(last_dim)->value) {
        max_vector_size = gcd_base;
      }

      vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

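      // Flatten the multi-dimensional indices into a single row-major element
      // offset so contiguity can be checked against the innermost loop var.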
      PrimExpr elem_offset = 0;
      PrimExpr stride = 1;
      for (int i = indices.size() - 1; i >= 0; --i) {
        elem_offset = elem_offset + indices[i] * stride;
        stride = stride * buffer->shape[i];
      }
      while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
                                 inner_for_->extent, vector_size_,
                                 &analyzer_)) {
        vector_size_ /= 2;
      }
    } else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
      // dynamic shape load: get the vectorization condition
      dynamic_ = true;
      PrimExpr offset = buffer.OffsetOf(indices).back();
      condition_ = (FloorMod(offset, vector_size_) == 0);
    }
  }

  const int vector_load_bits_max_ = 128;

  const ForNode *inner_for_ = nullptr;
  Map<Var, Range> iter_map_;
  bool has_nonlocal_memory_access_ = false;
  int vector_size_ = 128;
  // conditionally vectorize
  bool dynamic_ = false;
  PrimExpr condition_;
};

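// Rewrites builtin::if_then_else calls inside a dynamically guarded vectorized
// body so that their condition no longer depends on the inner (vectorized)
// loop variable.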
class VectorizeDynamicCallRemover : public StmtExprMutator {
public:
  VectorizeDynamicCallRemover(Var inner_var, int vector_size)
      : inner_var_(inner_var), vector_size_(vector_size) {}

private:
  PrimExpr VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(builtin::if_then_else())) {
      PrimExpr cond = this->VisitExpr(op->args[0]);
      Map<Var, PrimExpr> vmap;
      // Currently remove the upper bound check by substituting the inner
      // (vectorized) loop var with 0 in the condition.
      vmap.Set(inner_var_, 0);
      cond = Substitute(cond, vmap);
      Array<PrimExpr> new_args{cond, op->args[1], op->args[2]};
      return Call(op->dtype, op->op, new_args, op->span);
    } else {
      // TODO: handle other calls
      return GetRef<PrimExpr>(op);
    }
  }

  Var inner_var_;
  int vector_size_;
};

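// Rewrites the innermost loop according to the plan: mark it as vectorized
// directly, or split it into an outer loop plus an inner vectorized loop
// (guarded by a runtime condition in the dynamic case).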
class VectorizeRewriter : public StmtExprMutator {
public:
  VectorizeRewriter(VectorizePlanResult plan)
      : vector_size_(plan.vector_size), condition_(plan.condition),
        dynamic_(plan.dynamic) {}

private:
  Stmt VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
    if (inner_for_ == node) { // rewrite the innermost loop
      For fnode = ret.as<For>().value();
      auto old_var = fnode->loop_var;
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = *extent_ptr;
      ICHECK(extent % vector_size_ == 0)
          << "extent: " << extent << " vector_size_: " << vector_size_;
      ICHECK(is_zero(fnode->min));
      if (!dynamic_) { // static shape
        if (extent == vector_size_) {
          fnode.CopyOnWrite()->kind = ForKind::kVectorized;
          return fnode;
        } else {
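          // Split the loop: i = outer * vector_size_ + inner, with the inner
          // loop of extent vector_size_ marked as vectorized.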
          Var inner_var = Var("vec");
          Var outer_var = Var(old_var->name_hint);
          Map<Var, PrimExpr> vmap;
          vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
          Stmt body = Substitute(fnode->body, vmap);
          body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
          body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                     fnode->thread_binding, fnode->annotations, fnode->span);
          return body;
        }
      } else {
        Var inner_var = Var("vec");
        Var outer_var = Var(old_var->name_hint);
        Map<Var, PrimExpr> vmap;
        vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
        Stmt body = Substitute(fnode->body, vmap);
        // Build the runtime alignment condition for this vector chunk
        // (wrapped in an IfThenElse below).
        Map<Var, PrimExpr> vmap_condition;
        vmap_condition.Set(fnode->loop_var, outer_var * vector_size_);
        PrimExpr condition = Substitute(condition_, vmap_condition);

        VectorizeDynamicCallRemover remover(inner_var, vector_size_);
        body = remover(body);

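        // Emit both a vectorized and a serial version of the body and select
        // between them at runtime based on the alignment condition.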
        For vectorize_for =
            For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
        For serial_for =
            For(inner_var, 0, vector_size_, ForKind::kSerial, body);
        body = IfThenElse(condition, vectorize_for, serial_for);
        body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                   fnode->thread_binding, fnode->annotations, fnode->span);
        return body;
      }
    } else {
      return ret;
    }
  }

  const ForNode *inner_for_;
  const int vector_size_;
  const PrimExpr condition_;
  const bool dynamic_;
};

int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }

VectorizePlanResult GetVectorizePlanResult(const For &loop) {
  VectorizePlanner planner;
  int vector_size = planner.Plan(loop);
  bool dynamic = planner.GetDynamic();
  PrimExpr condition = planner.GetCondition();
  return {vector_size, dynamic, condition};
}

bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
                        int target_vectorized_size, arith::Analyzer *analyzer) {
  ICHECK(target_vectorized_size >= 1);
  if (target_vectorized_size == 1)
    return true;
  if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
                               0))
    return false;
  Var v0("v0"), v1("v1");
  analyzer->Bind(v0, Range(0, target_vectorized_size));
  analyzer->Bind(v1, Range(0, FloorDiv(iter_var_size, target_vectorized_size)));
  PrimExpr expr_transformed = analyzer->Simplify(
      Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));

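  // Vectorize the transformed index over v0: the access is contiguous iff the
  // result is a Ramp with stride 1; a scalar result means the index does not
  // depend on v0 (a broadcast), which is also vectorizable.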
  Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
  PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
  auto ramp_node = expr_vectorized.as<RampNode>();
  if (!ramp_node) {
    // Broadcast value
    if (expr_vectorized.dtype().lanes() == 1)
      return true;
    else
      return false;
  } else {
    return is_one(ramp_node->stride);
  }
}

For VectorizeLoop(const For &loop, int vectorize_hint) {
  VectorizePlanResult res{128, false, 0};
  if (vectorize_hint <= 0) {
    res = GetVectorizePlanResult(loop);
    vectorize_hint = res.vector_size;
  }
  if (vectorize_hint == 1)
    return loop;
  auto rewriter = VectorizeRewriter(res);
  return Downcast<For>(rewriter(loop));
}

} // namespace tl
} // namespace tvm