/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_vectorize.cc
 * \brief A tool to automatically vectorize a for loop
 */

#include "loop_vectorize.h"
26
27
#include "../op/builtin.h"
#include "../target/utils.h"
28
29
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
30
#include "common/loop_vectorization_utils.h"
31
32
33
34
35
#include "tvm/tir/analysis.h"
#include "tvm/tir/var.h"
#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
36
37
38
39
40
41
42
43
44
45
46
47

namespace tvm {
namespace tl {

using namespace tir;

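// Result of vectorization planning: the chosen number of lanes, whether the
// loop extent is dynamic, and the predicate under which the vectorized
// access is valid.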
struct VectorizePlanResult {
  int vector_size;
  bool dynamic;
  PrimExpr condition;
};

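// Detects whether a statement contains any load from or store to a buffer
// in the "global" scope. The planner uses this to decide whether the wider
// 256-bit vector path applies.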
class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer {
public:
  VectorizeFindGlobalAccess() = default;

  bool HasGlobalAccess(const Stmt &stmt) {
    this->operator()(stmt);
    return has_global_access_;
  }

private:
  bool has_global_access_ = false;

  void VisitStmt_(const BufferStoreNode *node) final {
    if (node->buffer.scope() == "global")
      has_global_access_ = true;
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

  void VisitExpr_(const BufferLoadNode *node) final {
    if (node->buffer.scope() == "global")
      has_global_access_ = true;
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }
};

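// Plans the number of vector lanes for the innermost loop of a loop nest.
// Starting from the maximum vector width, it narrows the lane count by GCD
// with the loop extent, with the element width of every buffer access, and
// with the largest provably contiguous span of each access pattern.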
class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
public:
  VectorizePlanner() = default;

  int Plan(const For &node) {
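    // Allow 256-bit vector accesses when targeting sm_100 and the loop
    // touches global memory, unless disabled through the
    // kDisableVectorize256 pass config; otherwise cap accesses at 128 bits.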
    tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
    Optional<Bool> opt_disable_vectorize_256 =
        ctxt->GetConfig(kDisableVectorize256, Optional<Bool>());
    bool disable_vectorize_256 =
        opt_disable_vectorize_256.value_or(Bool(false));
    if (tvm::tl::TargetIsSm100(Target::Current(false)) &&
        !disable_vectorize_256 &&
        VectorizeFindGlobalAccess().HasGlobalAccess(node)) {
      vector_load_bits_max_ = vector_size_ = 256;
    } else {
      vector_load_bits_max_ = vector_size_ = 128;
    }
    this->operator()(node);
    return vector_size_;
  }

private:
  void VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    auto extent_ptr = as_const_int(analyzer_.Simplify(node->extent));
    // Dynamic loop extents are disabled entirely for now. To support them,
    // the planner would need to accept an analyzer carrying arithmetic
    // facts from outside so it can prove divisibility by the vector size.
    if (!extent_ptr) {
      vector_size_ = 1;
      return;
    }
    vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
    arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

  void VisitExpr_(const BufferLoadNode *node) final {
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    if (node->buffer->shape.size() == 1) {
      // TODO(lei): Improve this. A shape-[1] buffer is a constant buffer
      // that tl hacks in to serve as a local register, so skip it.
      auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
      if (boundary_check && boundary_check->value == 1) {
        return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
      }
    }
    UpdateVectorSize(node->indices, node->buffer);
  }

  void VisitStmt_(const BufferStoreNode *node) final {
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    UpdateVectorSize(node->indices, node->buffer);
    return arith::IRVisitorWithAnalyzer::VisitExpr(node->value);
  }

  void VisitStmt_(const IfThenElseNode *node) final {
    CheckConditionVectorized(node->condition);
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

  void VisitExpr_(const CallNode *node) final {
    if (node->op == builtin::if_then_else()) {
      CheckConditionVectorized(node->args[0]);
    } else if (node->op == builtin::call_extern()) {
      // do not vectorize extern calls
      vector_size_ = 1;
    }
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

  void CheckConditionVectorized(const PrimExpr &cond) {
    // TODO: verify that the condition is safe to evaluate per vectorized
    // lane.
  }

  void VisitExpr_(const CastNode *node) final {
    vector_size_ = arith::ZeroAwareGCD(
        vector_load_bits_max_ / node->dtype.bits(), vector_size_);
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

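  // Narrows vector_size_ so that a vectorized access to `buffer` at
  // `indices` stays provably contiguous and aligned along the innermost
  // loop (steps 1-4 below).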
  void UpdateVectorSize(const Array<PrimExpr> &indices,
                        const Buffer &buffer) {
    if (!inner_for_)
      return;
    // 1. Compute the raw element offset of the access.
    auto strides = buffer->strides;
    if (buffer->strides.empty()) {
      // Densely packed layout: derive strides from the buffer shape.
      PrimExpr stride = 1;
      for (int i = static_cast<int>(indices.size()) - 1; i >= 0; --i) {
        strides.push_back(stride);
        stride = stride * buffer->shape[i];
      }
      strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
    }
    PrimExpr elem_offset = 0;
    for (size_t i = 0; i < indices.size(); ++i) {
      elem_offset += indices[i] * strides[i];
    }

    // 2. If the element offset is independent of the loop var, nothing to do.
    if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) {
      return;
    }

    // 3. Tighten the vectorization bound by the element width.
    vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ /
                                                         buffer->dtype.bits());

    // 4. Halve the vector size until the access is provably vectorizable.
    while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
                               inner_for_->extent, vector_size_, &analyzer_)) {
      vector_size_ /= 2;
    }
  }

  int vector_load_bits_max_;

  const ForNode *inner_for_{};
  bool has_nonlocal_memory_access_ = false;
  int vector_size_ = 128;
};

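// Rewrites the innermost loop according to the planned vector size: the
// loop is split into an outer loop (keeping the original kind) and an inner
// loop marked kVectorized. A sketch of the transformation, assuming a
// planned vector size of 4:
//
//   for (i, 0, 16) { B[i] = A[i]; }
//
// becomes
//
//   for (i, 0, 4)
//     for (vec, 0, 4, kVectorized) { B[i * 4 + vec] = A[i * 4 + vec]; }
//
// When the extent equals the vector size, the loop is marked kVectorized
// in place.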
class VectorizeRewriter : public StmtExprMutator {
public:
  explicit VectorizeRewriter(int vector_size) : vector_size_(vector_size) {}

private:
  Stmt VisitStmt_(const ForNode *node) final {
    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
    if (inner_for_ == node) { // rewrite the innermost loop
      For fnode = ret.as<For>().value();
      auto old_var = fnode->loop_var;
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = static_cast<int>(*extent_ptr);
      ICHECK(extent % vector_size_ == 0)
          << "extent: " << extent << " vector_size_: " << vector_size_;
      ICHECK(is_zero(fnode->min));
      if (extent == vector_size_) {
        fnode.CopyOnWrite()->kind = ForKind::kVectorized;
        return fnode;
      } else {
        Var inner_var = Var("vec");
        Var outer_var = Var(old_var->name_hint);
        Map<Var, PrimExpr> vmap;
        vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
        Stmt body = Substitute(fnode->body, vmap);
        body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
        body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                   fnode->thread_binding, fnode->annotations, fnode->span);
        return body;
      }
    } else {
      return ret;
    }
  }

  const ForNode *inner_for_{};
  const int vector_size_;
};

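// Plans and returns the number of vector lanes for `loop` without
// rewriting it.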
int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }

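// Returns true when `expr` provably does not depend on `var`: either `var`
// does not occur in `expr` at all, or substituting a fresh variable for
// `var` yields an expression the analyzer can prove equal to the original.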
bool CanProveIndependent(const PrimExpr &expr, Var var,
                         arith::Analyzer *analyzer) {
  // 1. If var does not occur in expr, it is trivially independent.
  bool used_var = UsesVar(
      expr, [&](const VarNode *v) { return GetRef<Var>(v).same_as(var); });
  if (!used_var) {
    return true;
  }
  // 2. If f(v_1) == f(v_2) for all v_1, v_2, then f is independent of var.
  Var var_1("_t", var.dtype());
  auto expr_1 = Substitute(expr, {{var, var_1}});
  if (analyzer->CanProveEqual(expr, expr_1)) {
    return true;
  }
  return false;
}

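// Decides whether `expr` (an element offset in terms of `var`) can be
// accessed with `target_vectorized_size` lanes: the loop extent and the
// base offset must be divisible by the lane count, and decomposing
// var = v0 + v1 * lanes must produce a unit-stride Ramp (or a
// lane-invariant scalar) over the innermost variable v0.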
bool IndiceCanVectorize(const PrimExpr &expr, Var var,
                        const PrimExpr &iter_var_size,
                        int target_vectorized_size, arith::Analyzer *analyzer) {
  ICHECK(target_vectorized_size >= 1);
  if (target_vectorized_size == 1)
    return true;

  // The loop extent must be divisible by the vector size.
  if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
                               0))
    return false;

  // The base offset (expr evaluated at var = 0) must also be divisible.
  if (!analyzer->CanProveEqual(
          FloorMod(Substitute(expr, {{var, 0}}), target_vectorized_size), 0)) {
    return false;
  }

  // Bind ranges for the decomposition var = v0 + v1 * lanes.
  Var v0("v0"), v1("v1");
  analyzer->Bind(v0, Range(0, target_vectorized_size));
  analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv(
                                  iter_var_size, target_vectorized_size))));
  PrimExpr expr_transformed = analyzer->Simplify(
      Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
  Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
  PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);

  // This simplification is necessary for thread-region-specific
  // optimizations.
  expr_vectorized = analyzer->Simplify(expr_vectorized);
  auto ramp_node = expr_vectorized.as<RampNode>();
  if (!ramp_node) {
    // A scalar result is lane-invariant (a broadcast), which vectorizes
    // trivially; any other non-Ramp pattern does not.
    return expr_vectorized.dtype().lanes() == 1;
  }
  // A unit-stride Ramp corresponds to a contiguous vector access.
  return is_one(ramp_node->stride);
}

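// Vectorizes `loop`. When no positive hint is given, a vector size is
// planned first; a resulting size of 1 leaves the loop unchanged.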
For VectorizeLoop(const For &loop, int vectorize_hint) {
  if (vectorize_hint <= 0) {
    VectorizePlanner planner;
    vectorize_hint = planner.Plan(loop);
  }
  if (vectorize_hint == 1)
    return loop;
  auto rewriter = VectorizeRewriter(vectorize_hint);
  return Downcast<For>(rewriter(loop));
}

} // namespace tl
} // namespace tvm