/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file vectorize_loop.cc
 */
// Loop vectorizer as in Halide pipeline.
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "arith/scalable_expression.h"
#include "tir/analysis/check_contains.h"

namespace tvm {
namespace tl {

using namespace tir;
using namespace ffi;

/*!
 * \brief Recompute the result dtype of a BufferLoadNode from its buffer
 *        element type and its indices.
 *
 * Performs the same legalization as BufferLoadNode::LegalizeDType, but
 * operates directly on a node pointer.
 *
 * \param n A pointer to a writable BufferLoadNode; its `dtype` field is
 *          overwritten.
 */
static void LegalizeBufferLoadDType(BufferLoadNode *n) {
  const size_t num_indices = n->indices.size();

  // Every index in front of the last one must be scalar.
  for (size_t pos = 0; pos + 1 < num_indices; ++pos) {
    ICHECK(n->indices[pos].dtype().is_scalar())
        << "Only the last index of a buffer access may be a vector type.";
  }

  // Without any index the load yields exactly the buffer's element type.
  if (num_indices == 0) {
    n->dtype = n->buffer->dtype;
    return;
  }

  const auto last_index_dtype = n->indices.back().dtype();
  const bool buffer_scalable = n->buffer->dtype.is_scalable_vector();
  const bool index_scalable = last_index_dtype.is_scalable_vector();

  // At most one of the two dtypes may be a scalable vector.
  ICHECK(!(index_scalable && buffer_scalable))
      << "Index dtype and buffer dtype cannot both be scalable.";

  if (index_scalable) {
    // Scalable index, fixed-length buffer element: scale the index's vscale
    // factor by the element lanes.
    const int factor = last_index_dtype.vscale_factor() * n->buffer->dtype.lanes();
    n->dtype = n->buffer->dtype.with_scalable_vscale_factor(factor);
    return;
  }

  if (buffer_scalable) {
    // Scalable buffer element, fixed-length index: scale the element's vscale
    // factor by the index lanes.
    const int factor = n->buffer->dtype.vscale_factor() * last_index_dtype.lanes();
    n->dtype = n->buffer->dtype.with_scalable_vscale_factor(factor);
    return;
  }

  // Both fixed-length: the lane counts multiply.
  n->dtype = n->buffer->dtype.with_lanes(last_index_dtype.lanes() *
                                         n->buffer->dtype.lanes());
}

/*!
 * \brief Build the lane-count expression for a vector.
 * \param is_scalable Whether the vector is scalable.
 * \param lanes_or_vscale_factor The lane count (fixed-length) or the vscale
 *        multiplier (scalable).
 * \return `vscale() * lanes_or_vscale_factor` when scalable, otherwise the
 *         plain integer.
 */
inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
  if (!is_scalable) {
    return lanes_or_vscale_factor;
  }
  PrimExpr vscale = Call(DataType::Int(32), builtin::vscale(), {});
  return Mul(vscale, lanes_or_vscale_factor);
}

/*!
 * \brief Widen an expression to the requested number of lanes.
 * \param e The expression to widen.
 * \param lanes The target lane count (or vscale factor when scalable).
 * \param is_scalable Whether the target is a scalable vector.
 * \return `e` itself when it already has the requested form, otherwise a
 *         Broadcast of `e` (or of the value inside an existing Broadcast).
 */
inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
  // Nothing to do when the expression already has the requested form.
  const bool lanes_match = e.dtype().get_lanes_or_vscale_factor() == lanes;
  if (lanes_match && e.dtype().is_scalable_vector() == is_scalable) {
    return e;
  }

  // An existing broadcast can be re-broadcast from its scalar value as long
  // as the new lane count is a multiple of the old one.
  const BroadcastNode *bcast = e.as<BroadcastNode>();
  if (bcast != nullptr) {
    ICHECK(bcast->dtype.is_scalable_vector() == is_scalable)
        << "Can't broadcast between scalable and fixed length vectors.";
    const int current_lanes = bcast->dtype.get_lanes_or_vscale_factor();
    if (lanes % current_lanes == 0) {
      return Broadcast(bcast->value, CreateNewLanes(is_scalable, lanes));
    }
  }

  // Anything else must be a scalar to be broadcast.
  ICHECK(e.dtype().is_scalar())
      << "Cannot broadcast lanes=" << e.dtype().get_lanes_or_vscale_factor()
      << " is_scalable=" << e.dtype().is_scalable_vector() << " to " << lanes;

  return Broadcast(e, CreateNewLanes(is_scalable, lanes));
}

// Rewrite vectorized allocation access
// This is necessary for making each vector component containing its own
// workspace. Originates from Halide's loop vectorizer
//
// s[i] = s[i * lanes + var]
//
// The same principle applies when using one thread to simulate multiple
// context.
//
Lei Wang's avatar
Lei Wang committed
130
class TLVecAllocAccess : public StmtExprMutator {
131
public:
Lei Wang's avatar
Lei Wang committed
132
  TLVecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
133
      : buf_(buf), var_(std::move(var)), var_lanes_(std::move(var_lanes)) {}
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return UpdateBufferAccess(load);
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return UpdateBufferAccess(store);
  }

private:
  template <typename Node> Node UpdateBufferAccess(Node node) {
    // Only update the buffer that's being replaced.
    if (node->buffer->data.get() != buf_) {
      return node;
    }

    // Find/make a Buffer object with the correct updated shape.
    Buffer buf;
    auto it = buffer_map_.find(node->buffer.get());
    if (it != buffer_map_.end()) {
      buf = it->second;
    } else {
      // Extend the least significant dimension by a factor of
      // var_lanes_.  Typically, this will be a 1-d index into a flat
      // memory space.
      Array<PrimExpr> shape = node->buffer->shape;
      shape.Set(shape.size() - 1,
                analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_));

      // TODO(Lunderberg): Move this pass to be prior to
      // StorageFlatten/FlattenBuffer, implement by appending a
      // dimension to the buffer.  Since it is currently after the
      // flattening, the strides are not technically necessary, but
      // are updated for consistency.

      // Update strides if defined.
      Array<PrimExpr> strides;
      for (size_t i = 0; i < strides.size(); i++) {
        PrimExpr stride = strides[i];
        if (i != strides.size() - 1) {
          stride *= var_lanes_;
        }
        strides.push_back(analyzer_.Simplify(stride));
      }

      // Copy everything into the new buffer.
      buf = node->buffer;
      auto buf_writer = buf.CopyOnWrite();
      buf_writer->shape = shape;
      buf_writer->strides = strides;
      buffer_map_[buf.get()] = buf;
    }

    return node;
  }

  // buffer var
  const VarNode *buf_;
  // Updated buffer objects.
  std::unordered_map<const BufferNode *, Buffer> buffer_map_;
  // variable to be replaced
  Var var_;
  // the lanes.
  PrimExpr var_lanes_;
  // Analyzer for simplifications
  arith::Analyzer analyzer_;
};

// We use ExprFunctor directly instead of StmtExprMutator
// This is because the transformation can change the dtype of the Expr
// The existing ExprMutator transformation rules may not be well defined.
//
// Every use of the loop variable `var_` is replaced by `ramp_` (a Ramp with
// base 0, stride 1 and `var_lanes_` lanes) and the expressions built on top of
// it are widened accordingly. When a construct cannot be widened,
// `need_scalarize_` is raised and VisitStmt() re-emits the enclosing statement
// as a serial loop through Scalarize().
class TLVectorizer : public StmtMutator,
                     public ExprFunctor<PrimExpr(const PrimExpr &)> {
public:
  using ExprFunctor::VisitExpr;
  using StmtMutator::operator();

  // Convenience entry to vectorize a loop body without exposing
  // the mutator invocation pattern at call sites.
  static Stmt Vectorize(const Var &var, const PrimExpr &var_lanes, Stmt body) {
    TLVectorizer vec{var, var_lanes};
    auto vec_stmt = vec(std::move(body));
    return vec_stmt;
  }

  /*!
   * \param var The loop variable being vectorized.
   * \param var_lanes The loop extent, used as the number of lanes.
   */
  TLVectorizer(const Var &var, const PrimExpr &var_lanes)
      : var_(var), var_lanes_(var_lanes) {
    ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
  }

  // Visit a statement; if any expression inside requested scalarization,
  // discard the mutated result and scalarize the ORIGINAL statement instead.
  Stmt VisitStmt(const Stmt &stmt) final {
    ICHECK(!need_scalarize_);
    Stmt ret = StmtMutator::VisitStmt(stmt);
    if (need_scalarize_) {
      auto scalarized_stmt = Scalarize(stmt);
      need_scalarize_ = false;
      return scalarized_stmt;
    } else {
      return ret;
    }
  }

  PrimExpr VisitExpr(const PrimExpr &e) final {
    return ExprFunctor::VisitExpr(e);
  }

  PrimExpr VisitExpr_(const AddNode *op) final {
    return AddSubVec(
        op, [](PrimExpr a, PrimExpr b) { return std::move(a) + std::move(b); });
  }

  PrimExpr VisitExpr_(const SubNode *op) final {
    return AddSubVec(
        op, [](PrimExpr a, PrimExpr b) { return std::move(a) - std::move(b); });
  }

  PrimExpr VisitExpr_(const MulNode *op) final {
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
      return tvm::ffi::GetRef<PrimExpr>(op);
    }

    bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
    bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
    if (!is_vec_a && !is_vec_b) {
      // Both operands changed but stayed scalar. Rebuild from the operands we
      // already have instead of delegating to BinaryVec, which would visit
      // both children a second time (exponential in the nesting depth and
      // able to re-trigger the Let re-binding check).
      return Mul(a, b);
    }
    if (is_vec_a && is_vec_b) {
      // Let's not multiply scalable and fixed length vectors
      ICHECK(a.dtype().is_scalable_vector() == b.dtype().is_scalable_vector())
          << "Fixed length and scalable vectors can't be mixed in "
             "multiplication.";
    }

    // ramp * positive scalar stays a ramp with scaled base and stride.
    const RampNode *b_ramp = b.as<RampNode>();
    const RampNode *a_ramp = a.as<RampNode>();
    if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) {
      PrimExpr lanes = a_ramp->lanes;
      return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes);
    }
    if (b_ramp && a.dtype().is_scalar() && analyzer_.CanProve(a > 0)) {
      PrimExpr lanes = b_ramp->lanes;
      return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes);
    }
    int a_lanes = a.dtype().get_lanes_or_vscale_factor();
    int b_lanes = b.dtype().get_lanes_or_vscale_factor();
    int max_lanes = std::max(a_lanes, b_lanes);
    bool is_scalable =
        a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
    return Mul(BroadcastTo(a, max_lanes, is_scalable),
               BroadcastTo(b, max_lanes, is_scalable));
  }
  PrimExpr VisitExpr_(const DivNode *op) final { return BinaryVec<Div>(op); }
  PrimExpr VisitExpr_(const ModNode *op) final { return BinaryVec<Mod>(op); }
  PrimExpr VisitExpr_(const FloorDivNode *op) final {
    return BinaryVec<FloorDiv>(op);
  }
  PrimExpr VisitExpr_(const FloorModNode *op) final {
    return BinaryVec<FloorMod>(op);
  }
  PrimExpr VisitExpr_(const MinNode *op) final { return BinaryVec<Min>(op); }
  PrimExpr VisitExpr_(const MaxNode *op) final { return BinaryVec<Max>(op); }
  PrimExpr VisitExpr_(const EQNode *op) final { return BinaryVec<EQ>(op); }
  PrimExpr VisitExpr_(const NENode *op) final { return BinaryVec<NE>(op); }
  PrimExpr VisitExpr_(const LTNode *op) final { return BinaryVec<LT>(op); }
  PrimExpr VisitExpr_(const LENode *op) final { return BinaryVec<LE>(op); }
  PrimExpr VisitExpr_(const GTNode *op) final { return BinaryVec<GT>(op); }
  PrimExpr VisitExpr_(const GENode *op) final { return BinaryVec<GE>(op); }
  PrimExpr VisitExpr_(const AndNode *op) final { return BinaryVec<And>(op); }
  PrimExpr VisitExpr_(const OrNode *op) final { return BinaryVec<Or>(op); }

  PrimExpr VisitExpr_(const NotNode *op) final {
    PrimExpr a = this->VisitExpr(op->a);
    if (a.same_as(op->a)) {
      return tvm::ffi::GetRef<PrimExpr>(op);
    } else {
      return !(a);
    }
  }

  PrimExpr VisitExpr_(const RampNode *op) final {
    PrimExpr base = this->VisitExpr(op->base);
    PrimExpr stride = this->VisitExpr(op->stride);
    ICHECK(!base.dtype().is_scalable_vector())
        << "Creating scalable vectors from existing vectors is not supported.";
    ICHECK(!stride.dtype().is_scalable_vector())
        << "Ramp stride with scalable dtype is not supported";
    if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) {
      ICHECK(op->lanes->IsInstance<IntImmNode>())
          << "Vectorizing over existing scalable vectors is not supported.";
      // The vectorized base is not necessarily a Ramp (it can be any
      // fixed-length vector expression), so only try to fuse the two ramps
      // when it really is one; otherwise use the generic path below.
      if (const RampNode *base_ramp = base.as<RampNode>()) {
        int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
        int base_ramp_lanes =
            static_cast<int>(Downcast<IntImm>(base_ramp->lanes)->value);
        if (analyzer_.CanProve(
                base_ramp->stride ==
                stride * make_const(stride.dtype(), base_ramp_lanes))) {
          return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes);
        }
      }
    }
    // Generic path: build one ramp per lane of base/stride and concatenate.
    int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
    base = BroadcastTo(base, lanes, false);
    stride = BroadcastTo(stride, lanes, false);
    Array<PrimExpr> elems;
    for (int i = 0; i < lanes; ++i) {
      elems.push_back(Ramp(Shuffle::ExtractElement(base, i),
                           Shuffle::ExtractElement(stride, i), op->lanes));
    }
    return Shuffle::Concat(elems);
  }

  PrimExpr VisitExpr_(const BroadcastNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (value.dtype().is_scalable_or_fixed_length_vector()) {
      // A broadcast of a vector cannot be expressed; fall back to a loop.
      need_scalarize_ = true;
      return tvm::ffi::GetRef<PrimExpr>(op);
    }
    if (value.same_as(op->value)) {
      return tvm::ffi::GetRef<PrimExpr>(op);
    } else {
      // Rebuild from the MUTATED value; using op->value here would silently
      // drop whatever the visit rewrote (e.g. a remapped let variable).
      return Broadcast(value, op->lanes);
    }
  }

  PrimExpr VisitExpr_(const SelectNode *op) final {
    PrimExpr cond = this->VisitExpr(op->condition);
    PrimExpr t = this->VisitExpr(op->true_value);
    PrimExpr f = this->VisitExpr(op->false_value);
    if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
        f.same_as(op->false_value)) {
      return tvm::ffi::GetRef<PrimExpr>(op);
    } else {
      int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
      int t_lanes = t.dtype().get_lanes_or_vscale_factor();
      int f_lanes = f.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes);
      bool is_scalable = cond.dtype().is_scalable_vector() ||
                         t.dtype().is_scalable_vector() ||
                         f.dtype().is_scalable_vector();
      return Select(BroadcastTo(cond, lanes, is_scalable),
                    BroadcastTo(t, lanes, is_scalable),
                    BroadcastTo(f, lanes, is_scalable));
    }
  }

  PrimExpr VisitExpr_(const CastNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (value.same_as(op->value)) {
      return tvm::ffi::GetRef<PrimExpr>(op);
    } else {
      // Keep the target element type, adopt the lane layout of the value.
      if (value.dtype().is_scalable_vector()) {
        return Cast(op->dtype.with_scalable_vscale_factor(
                        value.dtype().vscale_factor()),
                    value);
      } else {
        return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
      }
    }
  }

  PrimExpr VisitExpr_(const FloatImmNode *op) final {
    return tvm::ffi::GetRef<PrimExpr>(op);
  }

  PrimExpr VisitExpr_(const IntImmNode *op) final {
    return tvm::ffi::GetRef<PrimExpr>(op);
  }

  PrimExpr VisitExpr_(const StringImmNode *op) final {
    return tvm::ffi::GetRef<PrimExpr>(op);
  }

  // Variable
  PrimExpr VisitExpr_(const VarNode *op) final {
    Var var = tvm::ffi::GetRef<Var>(op);

    // The vectorized loop variable becomes the ramp.
    if (var.same_as(var_)) {
      return ramp_;
    }
    // Let-bound variables are replaced by their (possibly widened) stand-in.
    auto it = let_var_map_.find(var);
    if (it != let_var_map_.end()) {
      return it->second;
    } else {
      return std::move(var);
    }
  }
  // IfThenElse expr
  PrimExpr MutateIfThenElseExpr_(const CallNode *op) {
    PrimExpr cond = this->VisitExpr(op->args[0]);
    if (cond.dtype().is_scalable_or_fixed_length_vector()) {
      // A per-lane condition cannot be expressed with if_then_else.
      need_scalarize_ = true;
      return tvm::ffi::GetRef<PrimExpr>(op);
    }
    PrimExpr t = this->VisitExpr(op->args[1]);
    PrimExpr f = this->VisitExpr(op->args[2]);
    if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
        f.same_as(op->args[2])) {
      return tvm::ffi::GetRef<PrimExpr>(op);
    } else {
      int t_lanes = t.dtype().get_lanes_or_vscale_factor();
      int f_lanes = f.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(t_lanes, f_lanes);
      bool is_scalable =
          t.dtype().is_scalable_vector() || f.dtype().is_scalable_vector();
      t = BroadcastTo(t, lanes, is_scalable);
      f = BroadcastTo(f, lanes, is_scalable);
      if (is_scalable) {
        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
                    {cond, t, f});
      } else {
        return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
      }
    }
  }
  // Reinterpret expr
  PrimExpr MutateReinterpretExpr_(const CallNode *op) {
    ICHECK(op->op.same_as(builtin::reinterpret()));
    PrimExpr value = this->VisitExpr(op->args[0]);
    if (value.same_as(op->args[0])) {
      return tvm::ffi::GetRef<PrimExpr>(op);
    } else {
      int lanes = value.dtype().get_lanes_or_vscale_factor();
      if (value.dtype().is_scalable_vector()) {
        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
                    {value});
      } else {
        return Call(op->dtype.with_lanes(lanes), op->op, {value});
      }
    }
  }
  // Call
  PrimExpr VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(builtin::if_then_else())) {
      return MutateIfThenElseExpr_(op);
    } else if (op->op.same_as(builtin::texture2d_load())) {
      int lane = 0;
      // Only the last argument is vectorized; the result has 4 lanes.
      Array<PrimExpr> fcd = MutateArray({op->args.back()}, &lane);
      auto new_args = op->args;
      new_args.pop_back();
      new_args.push_back(fcd[0]);
      return Call(op->dtype.with_lanes(4), op->op, new_args);
    } else if (op->op.same_as(builtin::texture2d_store())) {
      int lane = 0;
      // Vectorize the value to store
      Array<PrimExpr> value{op->args.back()};
      Array<PrimExpr> mutated_value = MutateArray(value, &lane);
      Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2],
                               mutated_value[0]};
      return Call(op->dtype.with_lanes(lane), op->op, new_args);
    } else if (op->op.same_as(builtin::reinterpret())) {
      return MutateReinterpretExpr_(op);
    }
    auto optional_op = op->op.as<Op>();
    bool vectorizable = optional_op &&
                        op_vectorizable_.get(optional_op.value(), false) &&
                        !op->dtype.is_scalable_vector();
    if (!vectorizable) {
      // Cannot vectorize this op
      Array<PrimExpr> new_args;
      // Track changes explicitly: `new_args` is a freshly built array, so an
      // identity comparison against op->args could never report "unchanged".
      bool changed = false;
      for (auto arg : op->args) {
        auto new_arg = this->VisitExpr(arg);
        if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
          need_scalarize_ = true;
          return tvm::ffi::GetRef<PrimExpr>(op);
        }
        if (!new_arg.same_as(arg)) {
          changed = true;
        }
        new_args.push_back(new_arg);
      }
      if (!changed) {
        return tvm::ffi::GetRef<PrimExpr>(op);
      } else {
        return Call(op->dtype, op->op, new_args);
      }
    } else {
      int lane = 0;
      Array<PrimExpr> new_args = MutateArray(op->args, &lane);
      // normal code path.
      if (op->args.same_as(new_args)) {
        return tvm::ffi::GetRef<PrimExpr>(op);
      } else {
        return Call(op->dtype.with_lanes(lane), op->op, new_args);
      }
    }
  }
  // BufferLoad
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = tvm::ffi::GetRef<BufferLoad>(op);

    auto fmutate = [this](const PrimExpr &index) {
      return this->VisitExpr(index);
    };
    Array<PrimExpr> indices = op->indices.Map(fmutate);

    if (!indices.same_as(op->indices)) {
      BufferLoadNode *writer = load.CopyOnWrite();
      writer->indices = indices;
      // The result dtype depends on the (possibly widened) last index.
      LegalizeBufferLoadDType(writer);
    }

    return std::move(load);
  }
  // Let
  PrimExpr VisitExpr_(const LetNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    // Weaker SSA condition
    // A single var can be binded in multiple lets
    // but they have to bind to the same value.
    // This is used to allow cases when we reuse a single let
    // expression to construct a nested expr.
    // (let x = 1 in x + 1) * (let x = 1 in x + 1)
    //
    // The check compares the ORIGINAL bound values. let_var_map_ stores the
    // stand-in variable, not the value, so comparing against it would reject
    // exactly the re-binding case described above.
    auto seen = let_source_value_.find(op->var);
    if (seen != let_source_value_.end()) {
      ICHECK(deep_equal_(seen->second, op->value))
          << "Let cannot bind the same var to two different values";
    } else {
      let_source_value_[op->var] = op->value;
    }
    if (value.dtype().get_lanes_or_vscale_factor() !=
        op->value.dtype().get_lanes_or_vscale_factor()) {
      // The value was widened: bind it to a fresh var of the new dtype.
      Var new_var(op->var->name_hint, value.dtype());
      let_var_map_[op->var] = new_var;
      // Record mapping from the new var to its bound value
      let_value_binding_[new_var] = value;
      return Let(new_var, value, this->VisitExpr(op->body));
    } else {
      let_var_map_[op->var] = op->var;
      PrimExpr body = this->VisitExpr(op->body);
      if (value.same_as(op->value) && body.same_as(op->body)) {
        return tvm::ffi::GetRef<PrimExpr>(op);
      } else {
        return Let(op->var, value, body);
      }
    }
  }
  // BufferStore
  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = tvm::ffi::GetRef<BufferStore>(op);

    auto fmutate = [this](const PrimExpr &index) {
      return this->VisitExpr(index);
    };
    Array<PrimExpr> indices = op->indices.Map(fmutate);

    PrimExpr value = this->VisitExpr(op->value);

    if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
      ICHECK(!op->buffer->dtype.is_scalable_vector())
          << "Vectorizing over scalable buffer elements is not supported in "
             "vectorizer.";
      // The code below widens the last index, so there has to be one.
      ICHECK(!indices.empty())
          << "Cannot vectorize a store to buffer " << op->buffer->name
          << " that has no indices.";
      // How many lanes of indexing are present in the index and
      // buffer element type, excluding the last index.
      int other_index_lanes = op->buffer->dtype.lanes();
      for (size_t i = 0; i + 1 < indices.size(); i++) {
        other_index_lanes *= indices[i].dtype().lanes();
        // Only allow the last index to be scalable
        ICHECK(!indices[i].dtype().is_scalable_vector())
            << "Only the last index can be scalable.";
      }

      // The total number of lanes of indexing, including the last index.
      auto last_index_dtype = indices[indices.size() - 1].dtype();
      int lanes_in_last_index = last_index_dtype.get_lanes_or_vscale_factor();
      int index_lanes = other_index_lanes * lanes_in_last_index;

      // The total number of lanes in this store operation.  Either
      // the index or the value will be broadcast out to this number
      // of lanes, depending on which has more lanes.
      int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor();
      bool is_last_index_scalable = last_index_dtype.is_scalable_vector();
      int total_lanes = std::max(index_lanes, value_dtype_lanes);

      ICHECK_EQ(total_lanes % other_index_lanes, 0)
          << "When storing to buffer " << op->buffer->name
          << ", cannot produce " << total_lanes
          << " lanes of storage location by changing the last index.";
      int last_index_lanes = total_lanes / other_index_lanes;

      // Broadcast the last index such that the total number of index
      // lanes matches the desired number.
      indices.Set(indices.size() - 1,
                  BroadcastTo(indices[indices.size() - 1], last_index_lanes,
                              is_last_index_scalable));

      auto writer = store.CopyOnWrite();
      writer->indices = indices;
      writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable);
    }

    return std::move(store);
  }
  // For
  Stmt VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kVectorized) {
      LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
    }
    ICHECK(is_zero(op->min));
    ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
    PrimExpr extent = this->VisitExpr(op->extent);
    if (extent.dtype().is_scalable_or_fixed_length_vector()) {
      // A per-lane extent cannot be expressed; emit a serial loop instead.
      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
    }
    Stmt body = this->VisitStmt(op->body);
    if (extent.same_as(op->extent) && body.same_as(op->body)) {
      return tvm::ffi::GetRef<Stmt>(op);
    } else {
      return For(op->loop_var, op->min, extent, op->kind, body,
                 op->thread_binding, op->annotations);
    }
  }
  // IfThenElse
  Stmt VisitStmt_(const IfThenElseNode *op) final {
    ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
    PrimExpr condition = this->VisitExpr(op->condition);
    if (condition.dtype().is_scalable_or_fixed_length_vector()) {
      // A per-lane branch condition requires a serial loop.
      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
    }
    Stmt then_case = this->VisitStmt(op->then_case);
    Optional<Stmt> else_case = std::nullopt;
    if (op->else_case) {
      else_case = this->VisitStmt(op->else_case.value());
    }
    if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
      return tvm::ffi::GetRef<Stmt>(op);
    } else {
      return IfThenElse(condition, then_case, else_case);
    }
  }
  // While
  Stmt VisitStmt_(const WhileNode *op) final {
    LOG(FATAL) << "A while loop inside a vectorized loop not supported.";
  }
  // LetStmt
  Stmt VisitStmt_(const LetStmtNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    ICHECK(!let_var_map_.count(op->var))
        << "SSA violation, a single var is binded twice";
    // Remember the original value so a later Let expression re-binding the
    // same var can be validated against it.
    let_source_value_[op->var] = op->value;
    if (value.dtype().get_lanes_or_vscale_factor() !=
        op->value.dtype().get_lanes_or_vscale_factor()) {
      // The value was widened: bind it to a fresh var of the new dtype.
      Var new_var(op->var->name_hint, value.dtype());
      let_var_map_[op->var] = new_var;
      // Record mapping from the new var to its bound value
      let_value_binding_[op->var] = op->value;
      let_value_binding_[new_var] = value;

      return LetStmt(new_var, value, this->VisitStmt(op->body));
    } else {
      let_var_map_[op->var] = op->var;
      let_value_binding_[op->var] = value;
      Stmt body = this->VisitStmt(op->body);
      if (value.same_as(op->value) && body.same_as(op->body)) {
        return tvm::ffi::GetRef<Stmt>(op);
      } else {
        return LetStmt(op->var, value, body);
      }
    }
  }

  // Allocate
  Stmt VisitStmt_(const AllocateNode *op) final {
    // Mutate the condition
    PrimExpr condition = this->VisitExpr(op->condition);
    if (condition.dtype().is_scalable_or_fixed_length_vector()) {
      LOG(WARNING) << "Cannot handle vector extent in alloc of "
                   << op->buffer_var->name_hint;
      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
    }

    return StmtMutator::VisitStmt_(op);
  }

  // scalarize the statement
  //
  // Wraps `stmt` in a serial loop over a fresh variable `<var>_s` that replaces
  // the vectorized loop variable. Let-bound variables used inside `stmt` are
  // re-bound around it, since their original LetStmt may have been rewritten
  // to bind a widened stand-in variable instead.
  Stmt Scalarize(Stmt stmt) {
    Var idx(var_->name_hint + "_s", var_->dtype);
    // Find all Vars in stmt that are keys in let_value_binding_
    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_let_bound_vars;
    PostOrderVisit(stmt, [this, &used_let_bound_vars](const ObjectRef &node) {
      if (const auto *v = node.as<VarNode>()) {
        Var var = GetRef<Var>(v);
        if (let_value_binding_.count(var)) {
          used_let_bound_vars.insert(var);
        }
      }
    });
    stmt = Substitute(stmt, {{var_, idx}});

    // NOTE(review): the nesting order of the emitted LetStmts follows the
    // iteration order of the unordered_set — confirm the bound values never
    // depend on one another.
    if (!used_let_bound_vars.empty()) {
      for (const auto &v : used_let_bound_vars) {
        // Bind the existing var v to its value around the stmt scope
        auto new_value = Substitute(let_value_binding_.at(v), {{var_, idx}});
        stmt = LetStmt(v, new_value, stmt);
      }
    }

    return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
  }

private:
  // analyzer
  arith::Analyzer analyzer_;
  // deep equal
  ExprDeepEqual deep_equal_;
  // variable to be replaced
  Var var_;
  // the lanes.
  PrimExpr var_lanes_;
  // ramp representing the var.
  PrimExpr ramp_;
  // flag to mark requirement of scalarization.
  bool need_scalarize_{false};
  // Let var mapping: original let var -> the var standing in for it
  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_var_map_;
  // Let value binding: map new_var -> value
  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
      let_value_binding_;
  // Original (unvisited) value each let var was first bound to. Only used to
  // validate that a repeated Let binding of the same var binds the same value.
  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
      let_source_value_;
  // vectorizable property
  OpAttrMap<TVectorizable> op_vectorizable_ =
      Op::GetAttrMap<TVectorizable>("TVectorizable");

  // mutate array, with given lane requirement
  // when finished, p_lane updates the lane requirement.
  Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int *p_lanes) {
    if (arr.empty())
      return arr;
    int &lanes = *p_lanes;
    bool changed = false;
    std::vector<PrimExpr> new_arr(arr.size());
    // First pass: visit every element and find the widest lane count.
    for (size_t i = 0; i < arr.size(); i++) {
      PrimExpr old_elem = arr[i];
      PrimExpr new_elem = this->VisitExpr(old_elem);
      if (!new_elem.same_as(old_elem))
        changed = true;
      new_arr[i] = new_elem;
      lanes = std::max(lanes, new_elem.dtype().lanes());
    }

    // Second pass: broadcast the narrower elements to that lane count.
    for (size_t i = 0; i < arr.size(); ++i) {
      if (new_arr[i].dtype().lanes() != lanes) {
        new_arr[i] = BroadcastTo(new_arr[i], lanes, false);
        changed = true;
      }
    }
    if (!changed)
      return arr;
    return Array<PrimExpr>(new_arr);
  }
  // Generic binary op: visit both operands and broadcast them to a common
  // lane count before rebuilding the node.
  template <typename TOp, typename T> PrimExpr BinaryVec(const T *op) {
    static_assert(std::is_same<typename TOp::ContainerType, T>::value,
                  "constraint");
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
      return tvm::ffi::GetRef<PrimExpr>(op);
    } else {
      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(a_lanes, b_lanes);
      bool is_scalable =
          a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
      return TOp(BroadcastTo(a, lanes, is_scalable),
                 BroadcastTo(b, lanes, is_scalable));
    }
  }
  // Add/Sub: like BinaryVec, but keeps `scalar (+|-) ramp` and
  // `ramp (+|-) scalar` in ramp form by folding the scalar into the base.
  template <typename T, typename FCompute>
  PrimExpr AddSubVec(const T *op, FCompute fcompute) {
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
      return tvm::ffi::GetRef<PrimExpr>(op);
    } else {
      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(a_lanes, b_lanes);
      if (lanes != 1) {
        const RampNode *b_ramp = b.as<RampNode>();
        const RampNode *a_ramp = a.as<RampNode>();
        if (a.dtype().is_scalar() && b_ramp) {
          // The stride goes through fcompute(0, stride) so that subtraction
          // negates it.
          return Ramp(
              fcompute(a, b_ramp->base),
              fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride),
              b_ramp->lanes);
        }
        if (b.dtype().is_scalar() && a_ramp) {
          return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
        }
      }
      bool is_scalable =
          a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
      return fcompute(BroadcastTo(a, lanes, is_scalable),
                      BroadcastTo(b, lanes, is_scalable));
    }
  }
};

inline bool TargetHasSVE() {
  return Target::Current()->GetFeature<Bool>("has_sve").value_or(false);
}

831
832
833
834
835
836
837
838
839
/*!
 * \brief Replaces every loop marked ForKind::kVectorized with the result of
 *        TLVectorizer::Vectorize on its body.
 */
class LoopVectorizer : public StmtMutator {
public:
  Stmt VisitStmt_(const ForNode *op) final {
    // Loops of any other kind are traversed as usual.
    if (op->kind != ForKind::kVectorized) {
      return StmtMutator::VisitStmt_(op);
    }

    // Anything but a constant extent >= 1 is only accepted when the extent is
    // a vscale expression and the target reports SVE.
    const auto *const_extent = op->extent.as<IntImmNode>();
    const bool has_positive_const_extent =
        const_extent != nullptr && const_extent->value >= 1;
    if (!has_positive_const_extent) {
      bool is_scalable_expr =
          CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
      ICHECK(is_scalable_expr && TargetHasSVE())
          << "Failed to vectorize loop with extent " << op->extent
          << " for target " << Target::Current();
    }

    ICHECK(is_zero(op->min));
    return TLVectorizer::Vectorize(op->loop_var, op->extent, op->body);
  }
};

/*!
 * \brief Turns every loop marked ForKind::kVectorized into a serial loop,
 *        leaving everything else about the loop unchanged.
 */
class VectorizeSkipper : public StmtMutator {
public:
  Stmt VisitStmt_(const ForNode *op) final {
    // Rewrite nested loops first, then look at the resulting loop node.
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<ForNode>();
    if (op->kind == ForKind::kVectorized) {
      // Only the kind changes; carry the thread binding and annotations over
      // so that loop metadata is not silently lost.
      return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body,
                 op->thread_binding, op->annotations);
    } else {
      return stmt;
    }
  }
};

/*!
 * \brief Run VectorizeSkipper over a statement.
 * \param stmt The statement to rewrite.
 * \return The statement produced by the VectorizeSkipper mutator.
 */
Stmt SkipVectorize(Stmt stmt) {
  VectorizeSkipper skipper;
  return skipper(std::move(stmt));
}

/*!
 * \brief Create the "tl.VectorizeLoop" PrimFunc pass.
 * \param enable_vectorize When true the function body is rewritten by
 *        LoopVectorizer, otherwise by VectorizeSkipper.
 * \return The pass object.
 */
tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) {
  using namespace tir::transform;
  auto pass_func = [enable_vectorize](PrimFunc f, const IRModule &m,
                                      const PassContext &ctx) {
    auto *func_node = f.CopyOnWrite();
    Stmt body = std::move(func_node->body);
    func_node->body = enable_vectorize
                          ? tvm::tl::LoopVectorizer()(std::move(body))
                          : tvm::tl::VectorizeSkipper()(std::move(body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.VectorizeLoop", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
882
883
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.VectorizeLoop", VectorizeLoop);
884
}
885
886
887

} // namespace tl
} // namespace tvm