/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_vectorize.cc
 * \brief A tool to automatically vectorize a for loop
 */

#include "loop_vectorize.h"
#include "../op/builtin.h"
#include "../target/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"
#include "tvm/tir/analysis.h"
#include "tvm/tir/var.h"
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace tl {

using namespace tir;

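/*!
 * \brief Result of a vectorization plan: the chosen vector size, whether the
 * shapes involved are dynamic, and the predicate under which vectorization is
 * valid.
 */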
struct VectorizePlanResult {
  int vector_size;
  bool dynamic;
  PrimExpr condition;
};

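/*!
 * \brief Visitor that detects whether a statement reads or writes any buffer
 * in the "global" memory scope.
 */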
class VectorizeFindGlobalAccess : public StmtExprVisitor {
public:
  VectorizeFindGlobalAccess() = default;

  bool HasGlobalAccess(const Stmt &stmt) {
    this->operator()(stmt);
    return has_global_access_;
  }

private:
  bool has_global_access_ = false;

  void VisitStmt_(const BufferStoreNode *node) final {
    if (node->buffer.scope() == "global")
      has_global_access_ = true;
    return StmtExprVisitor::VisitStmt_(node);
  }

  void VisitExpr_(const BufferLoadNode *node) final {
    if (node->buffer.scope() == "global")
      has_global_access_ = true;
    return StmtExprVisitor::VisitExpr_(node);
  }
};

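/*!
 * \brief Plans the vectorization width for the innermost loop of a loop nest.
 *
 * Starting from the maximum vector access width, the planner shrinks the
 * vector size based on the loop extent, buffer index patterns, data types,
 * and constructs that cannot be vectorized (e.g. extern calls).
 */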
class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
public:
  explicit VectorizePlanner(arith::Analyzer *analyzer)
      : arith::IRMutatorWithAnalyzer(analyzer) {}

  int Plan(const For &node) {
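    // Prefer 256-bit vector accesses on SM100 targets when the loop touches
    // global memory and kDisableVectorize256 is not set; otherwise cap the
    // vector width at 128 bits.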
    tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
    Optional<Bool> opt_disable_vectorize_256 =
        ctxt->GetConfig(kDisableVectorize256, Optional<Bool>());
    bool disable_vectorize_256 =
        opt_disable_vectorize_256.value_or(Bool(false));
    if (tvm::tl::TargetIsSm100(Target::Current(false)) &&
        !disable_vectorize_256 &&
        VectorizeFindGlobalAccess().HasGlobalAccess(node)) {
      vector_load_bits_max_ = vector_size_ = 256;
    } else {
      vector_load_bits_max_ = vector_size_ = 128;
    }
    this->operator()(node);
    return vector_size_;
  }

private:
  Stmt VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    bool contains_nested_for = false;
    // Vectorization must be analyzed on the innermost loop
    PostOrderVisit(Downcast<Stmt>(node->body), [&](const ObjectRef &obj) {
      if (obj.as<ForNode>()) {
        contains_nested_for = true;
      }
    });

    if (!contains_nested_for) {
      auto extent_ptr = as_const_int(analyzer_->Simplify(node->extent));
      // Dynamic extents are disabled completely for now. To support them,
      // the planner would need to accept an analyzer carrying outside
      // arithmetic information to prove divisibility by the vector size.
      if (!extent_ptr) {
        vector_size_ = 1;
        return ffi::GetRef<Stmt>(node);
      }
      vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
    }
    return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
  }

  PrimExpr VisitExpr_(const BufferLoadNode *node) final {
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    if (node->buffer->shape.size() == 1) {
      // TODO(lei): This should be improved; a single-element constant buffer
      // is a tl hack used as a local register.
      auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
      if (boundary_check && boundary_check->value == 1) {
        return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
      }
    }
    UpdateVectorSize(node->indices, node->buffer);
    return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
  }

  Stmt VisitStmt_(const BufferStoreNode *node) final {
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    UpdateVectorSize(node->indices, node->buffer);
    return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
  }

  Stmt VisitStmt_(const IfThenElseNode *node) final {
    CheckConditionVectorized(node->condition);
    return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
  }

  PrimExpr VisitExpr_(const CallNode *node) final {
    if (node->op == builtin::if_then_else()) {
      CheckConditionVectorized(node->args[0]);
    } else if (node->op == builtin::call_extern()) {
      // do not vectorize extern calls
      vector_size_ = 1;
    }
    return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
  }

  void CheckConditionVectorized(const PrimExpr &cond) {
    // TODO: perform some checks here
  }

  PrimExpr VisitExpr_(const CastNode *node) final {
    vector_size_ = arith::ZeroAwareGCD(
        vector_load_bits_max_ / node->dtype.bits(), vector_size_);
    return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
  }

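  // Narrow vector_size_ so that the access buffer[indices] can be vectorized
  // over the innermost loop variable.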
  void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
    if (!inner_for_)
      return;
    // 1. Compute raw element offset
    auto strides = buffer->strides;
    if (buffer->strides.empty()) {
      PrimExpr stride = 1;
      for (int i = indices.size() - 1; i >= 0; --i) {
        strides.push_back(stride);
        stride = stride * buffer->shape[i];
      }
      strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
    }
    PrimExpr elem_offset = 0;
    for (int i = 0; i < indices.size(); ++i) {
      elem_offset += indices[i] * strides[i];
    }
    // 2. If the element offset is independent of the loop variable, ignore it
    if (CanProveIndependent(elem_offset, inner_for_->loop_var, analyzer_)) {
      return;
    }
    // 3. Check if the current vector_size_ passes the invariant boundary check
    if (!IsExprInvariantInVectorBoundary(elem_offset, inner_for_->loop_var,
                                         vector_size_, analyzer_)) {
      // If not, tighten the vectorization bound using the buffer dtype
      // constraint
      vector_size_ = arith::ZeroAwareGCD(
          vector_size_, vector_load_bits_max_ / buffer->dtype.bits());
    }
    // 4. Shrink vector_size_ until the index expression can be vectorized
    while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
                               inner_for_->extent, vector_size_, analyzer_)) {
      vector_size_ /= 2;
    }
  }

  int vector_load_bits_max_;

  const ForNode *inner_for_{};
  bool has_nonlocal_memory_access_ = false;
  int vector_size_ = 128;
};

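/*!
 * \brief Rewrites the innermost loop according to the planned vector size.
 *
 * A loop of extent N with N % vector_size == 0 becomes an outer loop of
 * N / vector_size iterations wrapping a vectorized inner loop of vector_size
 * iterations; if N == vector_size, the loop itself is marked vectorized.
 */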
class VectorizeRewriter : public StmtExprMutator {
public:
  explicit VectorizeRewriter(int vector_size) : vector_size_(vector_size) {}

private:
  Stmt VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
    if (inner_for_ == node) { // rewrite the innermost loop
      For fnode = ret.as<For>().value();
      auto old_var = fnode->loop_var;
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = *extent_ptr;
      ICHECK(extent % vector_size_ == 0)
          << "extent: " << extent << " vector_size_: " << vector_size_;
      ICHECK(is_zero(fnode->min));
      if (extent == vector_size_) {
        fnode.CopyOnWrite()->kind = ForKind::kVectorized;
        return fnode;
      } else {
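        // Split the loop: old_var -> outer_var * vector_size_ + inner_var,
        // and vectorize the inner loop.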
        Var inner_var = Var("vec");
        Var outer_var = Var(old_var->name_hint);
        Map<Var, PrimExpr> vmap;
        vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
        Stmt body = Substitute(fnode->body, vmap);
        body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
        body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                   fnode->thread_binding, fnode->annotations, fnode->step,
                   fnode->span);
        return body;
      }
    } else {
      return ret;
    }
  }

  const ForNode *inner_for_{};
  const int vector_size_;
};

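/*!
 * \brief Plan the vectorization size for \p loop; the overload taking an
 * analyzer reuses its existing arithmetic context.
 */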
int GetVectorizeSize(const For &loop) {
  arith::Analyzer analyzer;
  return VectorizePlanner(&analyzer).Plan(loop);
}

int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer) {
  return VectorizePlanner(analyzer).Plan(loop);
}

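/*!
 * \brief Returns true if \p expr can be proven not to depend on \p var.
 */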
bool CanProveIndependent(const PrimExpr &expr, Var var,
                         arith::Analyzer *analyzer) {
  // 1. If var does not appear in expr, it is trivially independent
  bool used_var = UsesVar(expr, [&](const VarNode *v) {
    return tvm::ffi::GetRef<Var>(v).same_as(var);
  });
  if (!used_var) {
    return true;
  }
  // 2. If f(v_1) == f(v_2) for all v_1, v_2, then f is independent of v
  Var var_1("_t", var.dtype());
  auto expr_1 = Substitute(expr, {{var, var_1}});
  if (analyzer->CanProveEqual(expr, expr_1)) {
    return true;
  }
  return false;
}

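/*!
 * \brief Returns true if \p expr can be proven constant within each aligned
 * block of \p target_vectorized_size consecutive values of \p var.
 */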
bool IsExprInvariantInVectorBoundary(const PrimExpr &expr, Var var,
                                     int target_vectorized_size,
                                     arith::Analyzer *analyzer) {
  // Check if expr is invariant within vector boundaries
  // We're trying to prove the access expression A[f(var)] depends only on
  // floor(var/vecsize), not on var%vecsize
  // Mathematically:
  // \forall var, f(floor(var/vecsize)*vecsize + var%vecsize) ==
  // f(floor(var/vecsize)*vecsize + 0)
  // Example: for i in T.vectorized(8):
  //     A[i] = B[i] * C[i//4]
  // with vecsize=4, f(i) = i//4 is constant within each group of 4 iterations
  // Therefore A[i] = B[i] * C[i//4] can be vectorized with vecsize=4
  PrimExpr var_aligned =
      floordiv(var, target_vectorized_size) * target_vectorized_size;
  PrimExpr expr_aligned = Substitute(expr, {{var, var_aligned}});
  if (analyzer->CanProveEqual(expr, expr_aligned)) {
    return true;
  }
  return false;
}

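/*!
 * \brief Returns true if the index expression \p expr over \p var can be
 * accessed with vectors of \p target_vectorized_size elements: the loop
 * extent must be divisible by the width and the access must be either
 * invariant within each vector or contiguous (a unit-stride ramp with a
 * divisible base offset).
 */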
bool IndiceCanVectorize(const PrimExpr &expr, Var var,
                        const PrimExpr &iter_var_size,
                        int target_vectorized_size, arith::Analyzer *analyzer) {
  ICHECK(target_vectorized_size >= 1);
  if (target_vectorized_size == 1)
    return true;

  // Extent must be divisible
  PrimExpr target_size_for_iter =
      make_const(iter_var_size.dtype(), target_vectorized_size);
  PrimExpr target_size_for_expr =
      make_const(expr.dtype(), target_vectorized_size);
  PrimExpr target_size_for_var =
      make_const(var.dtype(), target_vectorized_size);
  PrimExpr zero = make_const(var.dtype(), 0);

  if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter),
                               0))
    return false;

  if (IsExprInvariantInVectorBoundary(expr, var, target_vectorized_size,
                                      analyzer)) {
    return true;
  }

  auto simplified_expr = analyzer->Simplify(Substitute(expr, {{var, zero}}));
  // The base offset must be divisible
  if (!analyzer->CanProveEqual(FloorMod(simplified_expr, target_size_for_expr),
                               zero)) {
    return false;
  }

  // Decompose var into v1 * vecsize + v0 and bind the ranges of both parts
  Var v0("v0", var.dtype()), v1("v1", var.dtype());
  analyzer->Bind(v0, Range(zero, target_size_for_var));
  analyzer->Bind(v1, Range(zero, analyzer->Simplify(FloorDiv(
                                     iter_var_size, target_size_for_iter))));
  PrimExpr expr_transformed = analyzer->Simplify(
      Substitute(expr, {{var, v0 + v1 * target_size_for_var}}));
  Vectorizer vectorizer(v0, target_size_for_var);
  PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);

  // This simplification is necessary for optimizations that rely on the
  // thread/iteration ranges bound in the analyzer.
  expr_vectorized = analyzer->Simplify(expr_vectorized);
  auto ramp_node = expr_vectorized.as<RampNode>();
  if (!ramp_node) {
    // Broadcast value
    if (expr_vectorized.dtype().lanes() == 1)
      return true;
    else
      return false;
  } else {
    return is_one(ramp_node->stride);
  }
}

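/*!
 * \brief Vectorize \p loop. A non-positive \p vectorize_hint triggers
 * automatic planning; a planned size of 1 leaves the loop unchanged.
 */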
For VectorizeLoop(const For &loop, int vectorize_hint) {
  if (vectorize_hint <= 0) {
    arith::Analyzer analyzer;
    VectorizePlanner planner(&analyzer);
    vectorize_hint = planner.Plan(loop);
  }
  if (vectorize_hint == 1)
    return loop;
  auto rewriter = VectorizeRewriter(vectorize_hint);
  return Downcast<For>(rewriter(loop));
}

For VectorizeLoop(const For &loop, arith::Analyzer *analyzer,
                  int vectorize_hint) {
  if (vectorize_hint <= 0) {
    VectorizePlanner planner(analyzer);
    vectorize_hint = planner.Plan(loop);
  }
  if (vectorize_hint == 1)
    return loop;
  auto rewriter = VectorizeRewriter(vectorize_hint);
  return Downcast<For>(rewriter(loop));
}

} // namespace tl
} // namespace tvm