/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_vectorize.cc
 * \brief A tool to automatically vectorize a for loop
 */

#include "loop_vectorize.h"

#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>

#include <numeric>

#include "../layout/layout.h"
#include "../layout/utils.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"

namespace tvm {
namespace tl {

using namespace tir;

/*!
 * \brief Outcome of planning the vectorization of a loop.
 *
 * Kept as a plain aggregate so it can still be brace-initialized in
 * member order: {vector_size, dynamic, condition}.
 */
struct VectorizePlanResult {
  /*! \brief Number of lanes chosen for the loop; 1 means "leave as is". */
  int vector_size = 1;
  /*! \brief True when the planner met a buffer whose last dimension is not a
   *  compile-time constant. */
  bool dynamic = false;
  /*! \brief Predicate recorded by the planner for the dynamic case;
   *  undefined when none was recorded. */
  PrimExpr condition;
};

class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
51
public:
52
53
  VectorizePlanner() = default;

54
  int Plan(const For &node) {
55
56
57
58
59
60
61
62
    this->operator()(node);
    return vector_size_;
  }

  bool GetDynamic() { return dynamic_; }

  PrimExpr GetCondition() { return condition_; }

63
64
private:
  void VisitStmt_(const ForNode *node) final {
65
66
67
68
69
    inner_for_ = node;
    iter_map_.Set(node->loop_var, Range(node->min, node->extent));
    arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

70
  void VisitExpr_(const BufferLoadNode *node) final {
71
72
73
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
74
    if (node->buffer->shape.size() == 1) {
75
76
      // TODO(lei): This should be improved as
      // constant buffer that tl hack to use as local register.
77
78
79
80
      auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
      if (boundary_check && boundary_check->value == 1) {
        return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
      }
81
82
83
84
    }
    UpdateVectorSize(node->indices, node->buffer);
  }

85
  void VisitStmt_(const BufferStoreNode *node) final {
86
87
88
89
    if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
        node->buffer.scope() == "shared.dyn")
      has_nonlocal_memory_access_ = true;
    UpdateVectorSize(node->indices, node->buffer);
90
    return arith::IRVisitorWithAnalyzer::VisitExpr(node->value);
91
92
  }

93
  void VisitStmt_(const IfThenElseNode *node) final {
94
95
96
97
    CheckConditionVectorized(node->condition);
    return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
  }

98
  void VisitExpr_(const CallNode *node) final {
99
100
101
102
103
104
105
106
107
    if (node->op == builtin::if_then_else()) {
      CheckConditionVectorized(node->args[0]);
    } else if (node->op == builtin::call_extern()) {
      // do not vectorize extern calls
      vector_size_ = 1;
    }
    return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
  }

108
  void CheckConditionVectorized(const PrimExpr &cond) {
109
110
111
    // TODO: perform some checks here
  }

112
113
114
  void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
    if (!inner_for_)
      return;
115
    auto extent_ptr = inner_for_->extent.as<IntImmNode>();
116
117
    if (!extent_ptr)
      return;
118

119
    const DataType &access_type = buffer->dtype;
120
    // i // 2, i % 8 can also be vectorized as factor 16
121
    int max_vector_size = vector_load_bits_max_ / access_type.bits();
122
123
124
    if (access_type.is_e4m3_float8() or access_type.is_e5m2_float8()) {
      max_vector_size = 1; // [temporarily] do not vectorize float8
    }
125
126
127
128
    // so we should disable this GCD optimization
    max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
    auto last_dim = buffer->shape.back();
    auto mod_set = analyzer_.modular_set(last_dim);
129
130
    // when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
    // conditionally tail vectorize
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
    if (buffer->shape.back().as<IntImmNode>()) {
      max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
      auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
      // If gcd_base is equal to the last dimension,
      // we should analyze the second-to-last dimension
      // in relation to the last dimension.
      if (gcd_base < Downcast<IntImm>(last_dim)->value) {
        max_vector_size = gcd_base;
      }
      vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);

      PrimExpr elem_offset = 0;
      PrimExpr stride = 1;
      for (int i = indices.size() - 1; i >= 0; --i) {
        elem_offset = elem_offset + indices[i] * stride;
        stride = stride * buffer->shape[i];
      }
148
149
150
      while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
                                 inner_for_->extent, vector_size_,
                                 &analyzer_)) {
151
152
153
154
155
156
157
158
159
160
        vector_size_ /= 2;
      }
    } else if (vector_size_ <= vector_load_bits_max_ / buffer->dtype.bits()) {
      // dynamic shape load: get the vectorization condition
      dynamic_ = true;
      PrimExpr offset = buffer.OffsetOf(indices).back();
      condition_ = (FloorMod(offset, vector_size_) == 0);
    }
  }

161
  const int vector_load_bits_max_ = 128;
162

163
  const ForNode *inner_for_;
164
165
166
167
168
169
170
171
172
  Map<Var, Range> iter_map_;
  bool has_nonlocal_memory_access_ = false;
  int vector_size_ = 128;
  // conditionally vectorize
  bool dynamic_ = false;
  PrimExpr condition_;
};

class VectorizeRewriter : public StmtExprMutator {
173
public:
174
  VectorizeRewriter(VectorizePlanResult plan)
175
176
      : vector_size_(plan.vector_size), condition_(plan.condition),
        dynamic_(plan.dynamic) {}
177

178
179
private:
  Stmt VisitStmt_(const ForNode *node) final {
180
181
    inner_for_ = node;
    auto ret = StmtExprMutator::VisitStmt_(node);
182
    if (inner_for_ == node) { // rewrite the innermost loop
183
184
185
186
187
188
189
190
      For fnode = ret.as<For>().value();
      auto old_var = fnode->loop_var;
      auto extent_ptr = as_const_int(fnode->extent);
      ICHECK(extent_ptr) << fnode->extent;
      int extent = *extent_ptr;
      ICHECK(extent % vector_size_ == 0)
          << "extent: " << extent << " vector_size_: " << vector_size_;
      ICHECK(is_zero(fnode->min));
191
      if (!dynamic_) { // check dynamic shape
192
193
194
195
196
197
198
199
200
201
        if (extent == vector_size_) {
          fnode.CopyOnWrite()->kind = ForKind::kVectorized;
          return fnode;
        } else {
          Var inner_var = Var("vec");
          Var outer_var = Var(old_var->name_hint);
          Map<Var, PrimExpr> vmap;
          vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
          Stmt body = Substitute(fnode->body, vmap);
          body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
202
203
          body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                     fnode->thread_binding, fnode->annotations, fnode->span);
204
205
206
          return body;
        }
      } else {
207
        return fnode;
208
209
210
211
212
213
      }
    } else {
      return ret;
    }
  }

214
  const ForNode *inner_for_;
215
216
217
218
219
  const int vector_size_;
  const PrimExpr condition_;
  const bool dynamic_;
};

220
int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
221

222
/*!
 * \brief Plan the vectorization of \p loop and bundle every planner output.
 * \param loop The loop to analyze.
 * \return The planned width together with the planner's dynamic flag and
 *         condition.
 */
VectorizePlanResult GetVectorizePlanResult(const For &loop) {
  VectorizePlanner planner;
  // Plan() must run first: the other getters report state it populates.
  const int planned_width = planner.Plan(loop);
  const bool is_dynamic = planner.GetDynamic();
  return VectorizePlanResult{planned_width, is_dynamic,
                             planner.GetCondition()};
}

/*!
 * \brief Decide whether \p expr can be accessed with
 *        \p target_vectorized_size lanes when \p var is vectorized.
 *
 * \p var is rewritten as v0 + v1 * target_vectorized_size, with v0 in
 * [0, target_vectorized_size) and v1 in
 * [0, iter_var_size / target_vectorized_size). The expression is then run
 * through Vectorizer over v0 and the shape of the result is inspected.
 *
 * \param expr Index expression to test.
 * \param var Loop variable appearing in \p expr.
 * \param iter_var_size Extent of \p var.
 * \param target_vectorized_size Candidate lane count; must be >= 1.
 * \param analyzer Analyzer used for proofs and simplification. Fresh
 *        variables v0 and v1 are bound on it as a side effect.
 * \return true when the width is 1, when the vectorized expression is a
 *         Ramp with stride one, or when it is not a Ramp but has a single
 *         lane; false otherwise, including when \p iter_var_size cannot be
 *         proven divisible by the width.
 */
bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
                        int target_vectorized_size, arith::Analyzer *analyzer) {
  ICHECK(target_vectorized_size >= 1);
  if (target_vectorized_size == 1)
    return true;
  // The loop extent must be provably divisible by the candidate width.
  if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
                               0))
    return false;

  Var v0("v0"), v1("v1");
  analyzer->Bind(v0, Range(0, target_vectorized_size));
  analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv(
                                  iter_var_size, target_vectorized_size))));
  PrimExpr expr_transformed = analyzer->Simplify(
      Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));

  Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
  PrimExpr expr_vectorized =
      analyzer->Simplify(vectorizer.VisitExpr(expr_transformed));

  auto ramp_node = expr_vectorized.as<RampNode>();
  if (!ramp_node) {
    // Not a ramp: accept only if the result stayed a single-lane value.
    return expr_vectorized.dtype().lanes() == 1;
  }
  // A ramp is accepted only when consecutive lanes are one element apart.
  return is_one(ramp_node->stride);
}

/*!
 * \brief Vectorize the innermost loop of \p loop.
 *
 * \param loop The loop (nest) to rewrite.
 * \param vectorize_hint Requested vector width. A value <= 0 asks the
 *        planner to choose the width; a value of 1 disables vectorization.
 * \return The rewritten loop, or \p loop itself when the effective width
 *         is 1.
 */
For VectorizeLoop(const For &loop, int vectorize_hint) {
  VectorizePlanResult res{128, false, 0};
  if (vectorize_hint <= 0) {
    res = GetVectorizePlanResult(loop);
    vectorize_hint = res.vector_size;
  } else {
    // Honor the caller-supplied width; otherwise the rewriter would run with
    // the placeholder width of 128 regardless of the hint.
    res.vector_size = vectorize_hint;
  }
  if (vectorize_hint == 1)
    return loop;
  auto rewriter = VectorizeRewriter(res);
  return Downcast<For>(rewriter(loop));
}

} // namespace tl
} // namespace tvm