"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "3d57dd73703c4f9978b45ad82bab7f9a71b01cc9"
loop_vectorize.cc 10.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_vectorize.cc
 * \brief A tool to automatically vectorize a for loop
 */

#include "loop_vectorize.h"

#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>

#include <numeric>

#include "../layout/layout.h"
#include "../layout/utils.h"
35
36
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
37
38
39
40
41
42
43
44
45
46
47
48
49
50
#include "common/loop_vectorization_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Outcome of planning the vectorization of a loop.
 *
 * Kept as a plain aggregate (field order matters) so it can be built with
 * brace initialisation: {vector_size, dynamic, condition}.
 */
struct VectorizePlanResult {
  /*! \brief Number of consecutive loop iterations to fuse into one vector. */
  int vector_size;
  /*! \brief Whether the planner requested a runtime-guarded vectorization. */
  bool dynamic;
  /*! \brief Runtime predicate guarding the vectorized path when dynamic. */
  PrimExpr condition;
};

class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
51
public:
52
53
  VectorizePlanner() = default;

54
  int Plan(const For &node) {
55
56
57
58
59
60
61
62
63
64
    this->operator()(node);
    // Always Enable vectorization
    // if (!has_nonlocal_memory_access_) return 1;
    return vector_size_;
  }

  bool GetDynamic() { return dynamic_; }

  PrimExpr GetCondition() { return condition_; }

65
66
private:
  void VisitStmt_(const ForNode *node) final {
67
68
69
70
71
    inner_for_ = node;
    iter_map_.Set(node->loop_var, Range(node->min, node->extent));
    arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

72
  void VisitExpr_(const BufferLoadNode *node) final {
73
74
75
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
76
77
    if (node->buffer->shape.size() == 1 &&
        node->buffer->shape[0].as<IntImmNode>()->value == 1) {
78
79
80
81
82
83
84
85
      // TODO(lei): This should be improved as
      // constant buffer that tl hack to use as local register.
      return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
    }
    UpdateVectorSize(node->indices, node->buffer);
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

86
  void VisitStmt_(const BufferStoreNode *node) final {
87
88
89
90
91
92
93
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    UpdateVectorSize(node->indices, node->buffer);
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

94
  void VisitStmt_(const IfThenElseNode *node) final {
95
96
97
98
    CheckConditionVectorized(node->condition);
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

99
  void VisitExpr_(const CallNode *node) final {
100
101
102
103
104
105
106
107
108
    if (node->op == builtin::if_then_else()) {
      CheckConditionVectorized(node->args[0]);
    } else if (node->op == builtin::call_extern()) {
      // do not vectorize extern calls
      vector_size_ = 1;
    }
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

109
  void CheckConditionVectorized(const PrimExpr &cond) {
110
111
112
    // TODO: perform some checks here
  }

113
114
115
  void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
    if (!inner_for_)
      return;
116
    auto extent_ptr = inner_for_->extent.as<IntImmNode>();
117
118
    if (!extent_ptr)
      return;
119

120
    const DataType &access_type = buffer->dtype;
121
    // i // 2, i % 8 can also be vectorized as factor 16
122
    int max_vector_size = vector_load_bits_max_ / access_type.bits();
123
124
125
126
127
    // so we should disable this GCD optimization
    max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);

    auto last_dim = buffer->shape.back();
    auto mod_set = analyzer_.modular_set(last_dim);
128
129
    // when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
    // conditionally tail vectorize
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
    if (buffer->shape.back().as<IntImmNode>()) {
      max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);

      auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
      // If gcd_base is equal to the last dimension,
      // we should analyze the second-to-last dimension
      // in relation to the last dimension.
      if (gcd_base < Downcast<IntImm>(last_dim)->value) {
        max_vector_size = gcd_base;
      }

      vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

      PrimExpr elem_offset = 0;
      PrimExpr stride = 1;
      for (int i = indices.size() - 1; i >= 0; --i) {
        elem_offset = elem_offset + indices[i] * stride;
        stride = stride * buffer->shape[i];
      }
149
150
151
      while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
                                 inner_for_->extent, vector_size_,
                                 &analyzer_)) {
152
153
154
155
156
157
158
159
160
161
        vector_size_ /= 2;
      }
    } else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
      // dynamic shape load: get the vectorization condition
      dynamic_ = true;
      PrimExpr offset = buffer.OffsetOf(indices).back();
      condition_ = (FloorMod(offset, vector_size_) == 0);
    }
  }

162
  const int vector_load_bits_max_ = 128;
163

164
  const ForNode *inner_for_;
165
166
167
168
169
170
171
172
173
  Map<Var, Range> iter_map_;
  bool has_nonlocal_memory_access_ = false;
  int vector_size_ = 128;
  // conditionally vectorize
  bool dynamic_ = false;
  PrimExpr condition_;
};

class VectorizeDynamicCallRemover : public StmtExprMutator {
174
public:
175
176
177
  VectorizeDynamicCallRemover(Var inner_var, int vector_size)
      : inner_var_(inner_var), vector_size_(vector_size) {}

178
179
private:
  PrimExpr VisitExpr_(const CallNode *op) final {
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
    if (op->op.same_as(builtin::if_then_else())) {
      PrimExpr cond = this->VisitExpr(op->args[0]);
      Map<Var, PrimExpr> vmap;
      // Currently remove upper bound check
      vmap.Set(inner_var_, 0);
      cond = Substitute(cond, vmap);
      Array<PrimExpr> new_args{cond, op->args[1], op->args[2]};
      return Call(op->dtype, op->op, new_args, op->span);
    } else {
      // TODO: For other calls
      return GetRef<PrimExpr>(op);
    }
  }

  Var inner_var_;
  int vector_size_;
};

class VectorizeRewriter : public StmtExprMutator {
199
public:
200
  VectorizeRewriter(VectorizePlanResult plan)
201
202
      : vector_size_(plan.vector_size), condition_(plan.condition),
        dynamic_(plan.dynamic) {}
203

204
205
private:
  Stmt VisitStmt_(const ForNode *node) final {
206
207
    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
208
    if (inner_for_ == node) { // rewrite the innermost loop
209
210
211
212
213
214
215
216
      For fnode = ret.as<For>().value();
      auto old_var = fnode->loop_var;
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = *extent_ptr;
      ICHECK(extent % vector_size_ == 0)
          << "extent: " << extent << " vector_size_: " << vector_size_;
      ICHECK(is_zero(fnode->min));
217
      if (!dynamic_) { // check dynamic shape
218
219
220
221
222
223
224
225
226
227
        if (extent == vector_size_) {
          fnode.CopyOnWrite()->kind = ForKind::kVectorized;
          return fnode;
        } else {
          Var inner_var = Var("vec");
          Var outer_var = Var(old_var->name_hint);
          Map<Var, PrimExpr> vmap;
          vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
          Stmt body = Substitute(fnode->body, vmap);
          body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
228
229
          body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                     fnode->thread_binding, fnode->annotations, fnode->span);
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
          return body;
        }
      } else {
        Var inner_var = Var("vec");
        Var outer_var = Var(old_var->name_hint);
        Map<Var, PrimExpr> vmap;
        vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
        Stmt body = Substitute(fnode->body, vmap);
        // add condition ifthenelse here
        Map<Var, PrimExpr> vmap_condition;
        vmap_condition.Set(fnode->loop_var, outer_var * vector_size_);
        PrimExpr condition = Substitute(condition_, vmap_condition);

        VectorizeDynamicCallRemover remover(inner_var, vector_size_);
        body = remover(body);

246
247
248
249
        For vectorize_for =
            For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
        For serial_for =
            For(inner_var, 0, vector_size_, ForKind::kSerial, body);
250
        body = IfThenElse(condition, vectorize_for, serial_for);
251
252
        body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                   fnode->thread_binding, fnode->annotations, fnode->span);
253
254
255
256
257
258
259
        return body;
      }
    } else {
      return ret;
    }
  }

260
  const ForNode *inner_for_;
261
262
263
264
265
  const int vector_size_;
  const PrimExpr condition_;
  const bool dynamic_;
};

266
int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
267

268
/*!
 * \brief Plan the vectorization of \p loop and return the full plan
 *        (vector size, dynamic flag and runtime condition).
 */
VectorizePlanResult GetVectorizePlanResult(const For &loop) {
  VectorizePlanner planner;
  // Plan() has to run before the getters: it is the call that populates the
  // planner's dynamic flag and condition.
  const int planned_size = planner.Plan(loop);
  const bool is_dynamic = planner.GetDynamic();
  return {planned_size, is_dynamic, planner.GetCondition()};
}

276
277
/*!
 * \brief Check whether index expression \p expr can be vectorized over \p var
 *        with the given vector size.
 * \param expr Index expression to test.
 * \param var Loop variable that would be vectorized.
 * \param iter_var_size Extent of \p var.
 * \param target_vectorized_size Candidate vector size; must be >= 1.
 * \param analyzer Analyzer used for proofs and simplification; two fresh
 *        variables are bound on it.
 * \return true if the size is 1, or if the vectorized index is a scalar or a
 *         ramp with stride one; false otherwise.
 */
bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
                        int target_vectorized_size, arith::Analyzer *analyzer) {
  ICHECK(target_vectorized_size >= 1);
  // A vector of one lane is always acceptable.
  if (target_vectorized_size == 1)
    return true;

  // The loop extent must be provably divisible by the vector size.
  const bool extent_divisible = analyzer->CanProveEqual(
      FloorMod(iter_var_size, target_vectorized_size), 0);
  if (!extent_divisible)
    return false;

  // Split var into lane + chunk * size and simplify under those bounds.
  Var lane("v0"), chunk("v1");
  analyzer->Bind(lane, Range(0, target_vectorized_size));
  analyzer->Bind(chunk,
                 Range(0, FloorDiv(iter_var_size, target_vectorized_size)));
  PrimExpr split_expr = analyzer->Simplify(
      Substitute(expr, {{var, lane + chunk * target_vectorized_size}}));

  Vectorizer vectorizer(lane, IntImm(lane->dtype, target_vectorized_size));
  PrimExpr vectorized = vectorizer.VisitExpr(split_expr);

  if (const auto *ramp = vectorized.as<RampNode>()) {
    return is_one(ramp->stride);
  }
  // Broadcast value
  return vectorized.dtype().lanes() == 1;
}

304
/*!
 * \brief Vectorize the innermost loop of \p loop.
 * \param loop The loop (nest) to rewrite.
 * \param vectorize_hint Requested vector size. A value <= 0 asks the planner
 *        to pick the size (and, if needed, a runtime-guarded plan).
 * \return The rewritten loop, or \p loop itself when the vector size is 1.
 */
For VectorizeLoop(const For &loop, int vectorize_hint) {
  // Placeholder plan; the vector size is always overwritten below.
  VectorizePlanResult res{128, false, 0};
  if (vectorize_hint <= 0) {
    res = GetVectorizePlanResult(loop);
    vectorize_hint = res.vector_size;
  } else {
    // Honor the caller-provided width. Without this the rewriter would run
    // with the placeholder size instead of the hint.
    res.vector_size = vectorize_hint;
  }
  if (vectorize_hint == 1)
    return loop;
  auto rewriter = VectorizeRewriter(res);
  return Downcast<For>(rewriter(loop));
}

316
317
} // namespace tl
} // namespace tvm