loop_vectorization_utils.h 28.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file loop_vectorization_utils.h
 * \brief Loop vectorization utilities for TL transforms
 */

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>

#include <queue>
32
#include <utility>
33
34
35
36

#include "../../op/parallel.h"
#include "../loop_partition.h"
#include "../loop_vectorize.h"
37
#include "arith/ir_mutator_with_analyzer.h"
38
39
40
41
42
43
44
45
46
47

namespace tvm {
namespace tl {

using namespace tir;

// Vectorize Part
// Use the same code as tir.transform.vectorize_loop
/*!
 * \brief Build the lane-count expression used when constructing a vector.
 * \param is_scalable Whether the resulting vector is scalable.
 * \param lanes_or_vscale_factor Lane count for a fixed-length vector, or the
 *        factor multiplied with vscale for a scalable vector.
 * \return `vscale() * lanes_or_vscale_factor` for scalable vectors, otherwise
 *         the integer lane count converted to a PrimExpr.
 */
inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
  // Fixed-length vectors carry their lane count directly.
  if (!is_scalable) {
    return lanes_or_vscale_factor;
  }
  // Scalable vectors express their length as a multiple of the vscale builtin.
  PrimExpr vscale = Call(DataType::Int(32), builtin::vscale(), {});
  return Mul(vscale, lanes_or_vscale_factor);
}

/*!
 * \brief Widen an expression to a vector with the requested lane count.
 * \param e The expression to widen.
 * \param lanes Target lane count (or vscale factor when \p is_scalable).
 * \param is_scalable Whether the target vector is scalable.
 * \return \p e unchanged when it already has the requested form; otherwise a
 *         Broadcast of the underlying scalar.
 *
 * An existing Broadcast is re-broadcast from its scalar value when the target
 * lane count is a multiple of its current one. Any other input must be a
 * scalar, which is enforced with ICHECK.
 */
inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
  const DataType src_dtype = e.dtype();

  // Already in the expected form: nothing to do.
  if (src_dtype.get_lanes_or_vscale_factor() == lanes &&
      src_dtype.is_scalable_vector() == is_scalable) {
    return e;
  }

  // A narrower broadcast can be widened by broadcasting its scalar again.
  const BroadcastNode *bcast = e.as<BroadcastNode>();
  if (bcast != nullptr) {
    ICHECK(bcast->dtype.is_scalable_vector() == is_scalable)
        << "Can't broadcast between scalable and fixed length vectors.";
    const int src_lanes = bcast->dtype.get_lanes_or_vscale_factor();
    if (lanes % src_lanes == 0) {
      return Broadcast(bcast->value, CreateNewLanes(is_scalable, lanes));
    }
  }

  // Everything else must be a scalar to be broadcastable.
  ICHECK(src_dtype.is_scalar())
      << "Cannot broadcast lanes=" << src_dtype.get_lanes_or_vscale_factor()
      << " is_scalable=" << src_dtype.is_scalable_vector() << " to " << lanes;

  return Broadcast(e, CreateNewLanes(is_scalable, lanes));
}

// Rewrite vectorized allocation access
79
80
// This is necessary for making each vector component containing its own
// workspace. Originates from Halide's loop vectorizer
81
82
83
//
// s[i] = s[i * lanes + var]
//
84
85
// The same principle applies when using one thread to simulate multiple
// context.
86
87
//
class VecAllocAccess : public StmtExprMutator {
88
89
public:
  VecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
90
      : buf_(buf), var_(std::move(var)), var_lanes_(std::move(var_lanes)) {}
91

92
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
93
94
95
96
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return UpdateBufferAccess(load);
  }

97
  Stmt VisitStmt_(const BufferStoreNode *op) final {
98
99
100
101
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return UpdateBufferAccess(store);
  }

102
103
private:
  template <typename Node> Node UpdateBufferAccess(Node node) {
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
    // Only update the buffer that's being replaced.
    if (node->buffer->data.get() != buf_) {
      return node;
    }

    // Find/make a Buffer object with the correct updated shape.
    Buffer buf;
    auto it = buffer_map_.find(node->buffer.get());
    if (it != buffer_map_.end()) {
      buf = it->second;
    } else {
      // Extend the least significant dimension by a factor of
      // var_lanes_.  Typically, this will be a 1-d index into a flat
      // memory space.
      Array<PrimExpr> shape = node->buffer->shape;
119
120
      shape.Set(shape.size() - 1,
                analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_));
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148

      // TODO(Lunderberg): Move this pass to be prior to
      // StorageFlatten/FlattenBuffer, implement by appending a
      // dimension to the buffer.  Since it is currently after the
      // flattening, the strides are not technically necessary, but
      // are updated for consistency.

      // Update strides if defined.
      Array<PrimExpr> strides;
      for (size_t i = 0; i < strides.size(); i++) {
        PrimExpr stride = strides[i];
        if (i != strides.size() - 1) {
          stride *= var_lanes_;
        }
        strides.push_back(analyzer_.Simplify(stride));
      }

      // Copy everything into the new buffer.
      buf = node->buffer;
      auto buf_writer = buf.CopyOnWrite();
      buf_writer->shape = shape;
      buf_writer->strides = strides;
      buffer_map_[buf.get()] = buf;
    }

    // Extend the last index by the number of lanes in the vectorized
    // variable.
    Array<PrimExpr> indices = node->indices;
149
150
151
    indices.Set(
        indices.size() - 1,
        analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_));
152
153
154
155
156
157
158
159

    auto writer = node.CopyOnWrite();
    writer->buffer = buf;
    writer->indices = indices;
    return node;
  }

  // buffer var
160
  const VarNode *buf_;
161
  // Updated buffer objects.
162
  std::unordered_map<const BufferNode *, Buffer> buffer_map_;
163
164
165
166
167
168
169
170
171
172
173
  // variable to be replaced
  Var var_;
  // the lanes.
  PrimExpr var_lanes_;
  // Analyzer for simplifications
  arith::Analyzer analyzer_;
};

// We use ExprFunctor directly instead of StmtExprMutator
// This is because the transformation can change the dtype of the Expr
// The existing ExprMutator transformation rules may not be well defined.
174
175
176
class Vectorizer : public StmtMutator,
                   public ExprFunctor<PrimExpr(const PrimExpr &)> {
public:
177
178
179
  using ExprFunctor::VisitExpr;
  using StmtMutator::operator();

180
181
  Vectorizer(const Var &var, const PrimExpr &var_lanes)
      : var_(var), var_lanes_(var_lanes) {
182
183
184
    ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
  }

185
  Stmt VisitStmt(const Stmt &stmt) final {
186
187
188
189
190
191
192
193
194
195
    ICHECK(!need_scalarize_);
    Stmt ret = StmtMutator::VisitStmt(stmt);
    if (need_scalarize_) {
      need_scalarize_ = false;
      return Scalarize(stmt);
    } else {
      return ret;
    }
  }

196
197
198
  PrimExpr VisitExpr(const PrimExpr &e) final {
    return ExprFunctor::VisitExpr(e);
  }
199

200
  PrimExpr VisitExpr_(const AddNode *op) final {
201
202
    return AddSubVec(
        op, [](PrimExpr a, PrimExpr b) { return std::move(a) + std::move(b); });
203
204
  }

205
  PrimExpr VisitExpr_(const SubNode *op) final {
206
207
    return AddSubVec(
        op, [](PrimExpr a, PrimExpr b) { return std::move(a) - std::move(b); });
208
209
  }

210
  PrimExpr VisitExpr_(const MulNode *op) final {
211
212
213
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
214
      return tvm::ffi::GetRef<PrimExpr>(op);
215
216
217
218
219
220
    } else {
      bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
      bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
      if (is_vec_a && is_vec_b) {
        // Let's not multiply scalable and fixed length vectors
        ICHECK(a.dtype().is_scalable_vector() == b.dtype().is_scalable_vector())
221
222
            << "Fixed length and scalable vectors can't be mixed in "
               "multiplication.";
223
224
      }
      if (is_vec_a || is_vec_b) {
225
226
        const RampNode *b_ramp = b.as<RampNode>();
        const RampNode *a_ramp = a.as<RampNode>();
227
228
229
230
231
232
233
234
235
236
237
        if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) {
          PrimExpr lanes = a_ramp->lanes;
          return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes);
        }
        if (b_ramp && a.dtype().is_scalar() && analyzer_.CanProve(a > 0)) {
          PrimExpr lanes = b_ramp->lanes;
          return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes);
        }
        int a_lanes = a.dtype().get_lanes_or_vscale_factor();
        int b_lanes = b.dtype().get_lanes_or_vscale_factor();
        int max_lanes = std::max(a_lanes, b_lanes);
238
239
240
241
        bool is_scalable =
            a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
        return Mul(BroadcastTo(a, max_lanes, is_scalable),
                   BroadcastTo(b, max_lanes, is_scalable));
242
243
244
245
      }
    }
    return BinaryVec<Mul>(op);
  }
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
  PrimExpr VisitExpr_(const DivNode *op) final { return BinaryVec<Div>(op); }
  PrimExpr VisitExpr_(const ModNode *op) final { return BinaryVec<Mod>(op); }
  PrimExpr VisitExpr_(const FloorDivNode *op) final {
    return BinaryVec<FloorDiv>(op);
  }
  PrimExpr VisitExpr_(const FloorModNode *op) final {
    return BinaryVec<FloorMod>(op);
  }
  PrimExpr VisitExpr_(const MinNode *op) final { return BinaryVec<Min>(op); }
  PrimExpr VisitExpr_(const MaxNode *op) final { return BinaryVec<Max>(op); }
  PrimExpr VisitExpr_(const EQNode *op) final { return BinaryVec<EQ>(op); }
  PrimExpr VisitExpr_(const NENode *op) final { return BinaryVec<NE>(op); }
  PrimExpr VisitExpr_(const LTNode *op) final { return BinaryVec<LT>(op); }
  PrimExpr VisitExpr_(const LENode *op) final { return BinaryVec<LE>(op); }
  PrimExpr VisitExpr_(const GTNode *op) final { return BinaryVec<GT>(op); }
  PrimExpr VisitExpr_(const GENode *op) final { return BinaryVec<GE>(op); }
  PrimExpr VisitExpr_(const AndNode *op) final { return BinaryVec<And>(op); }
  PrimExpr VisitExpr_(const OrNode *op) final { return BinaryVec<Or>(op); }

  PrimExpr VisitExpr_(const NotNode *op) final {
266
267
    PrimExpr a = this->VisitExpr(op->a);
    if (a.same_as(op->a)) {
268
      return tvm::ffi::GetRef<PrimExpr>(op);
269
270
271
272
273
    } else {
      return !(a);
    }
  }

274
  PrimExpr VisitExpr_(const RampNode *op) final {
275
276
277
278
279
280
281
282
283
    PrimExpr base = this->VisitExpr(op->base);
    PrimExpr stride = this->VisitExpr(op->stride);
    ICHECK(!base.dtype().is_scalable_vector())
        << "Creating scalable vectors from existing vectors is not supported.";
    ICHECK(!stride.dtype().is_scalable_vector())
        << "Ramp stride with scalable dtype is not supported";
    if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) {
      ICHECK(op->lanes->IsInstance<IntImmNode>())
          << "Vectorizing over existing scalable vectors is not supported.";
284
      const RampNode *base_ramp = base.as<RampNode>();
285
      int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
286
287
      int base_ramp_lanes =
          static_cast<int>(Downcast<IntImm>(base_ramp->lanes)->value);
288
      if (analyzer_.CanProve(base_ramp->stride ==
289
290
                             stride *
                                 make_const(stride.dtype(), base_ramp_lanes))) {
291
292
293
294
295
296
297
298
        return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes);
      }
    }
    int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
    base = BroadcastTo(base, lanes, false);
    stride = BroadcastTo(stride, lanes, false);
    Array<PrimExpr> elems;
    for (int i = 0; i < lanes; ++i) {
299
300
      elems.push_back(Ramp(Shuffle::ExtractElement(base, i),
                           Shuffle::ExtractElement(stride, i), op->lanes));
301
302
303
304
    }
    return Shuffle::Concat(elems);
  }

305
  PrimExpr VisitExpr_(const BroadcastNode *op) final {
306
307
308
    PrimExpr value = this->VisitExpr(op->value);
    if (value.dtype().is_scalable_or_fixed_length_vector()) {
      need_scalarize_ = true;
309
      return tvm::ffi::GetRef<PrimExpr>(op);
310
311
    }
    if (value.same_as(op->value)) {
312
      return tvm::ffi::GetRef<PrimExpr>(op);
313
314
315
316
317
    } else {
      return Broadcast(op->value, op->lanes);
    }
  }

318
  PrimExpr VisitExpr_(const SelectNode *op) final {
319
320
321
    PrimExpr cond = this->VisitExpr(op->condition);
    PrimExpr t = this->VisitExpr(op->true_value);
    PrimExpr f = this->VisitExpr(op->false_value);
322
323
    if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
        f.same_as(op->false_value)) {
324
      return tvm::ffi::GetRef<PrimExpr>(op);
325
326
327
328
329
    } else {
      int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
      int t_lanes = t.dtype().get_lanes_or_vscale_factor();
      int f_lanes = f.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes);
330
331
      bool is_scalable = cond.dtype().is_scalable_vector() ||
                         t.dtype().is_scalable_vector() ||
332
                         f.dtype().is_scalable_vector();
333
334
      return Select(BroadcastTo(cond, lanes, is_scalable),
                    BroadcastTo(t, lanes, is_scalable),
335
336
337
338
                    BroadcastTo(f, lanes, is_scalable));
    }
  }

339
  PrimExpr VisitExpr_(const CastNode *op) final {
340
341
    PrimExpr value = this->VisitExpr(op->value);
    if (value.same_as(op->value)) {
342
      return tvm::ffi::GetRef<PrimExpr>(op);
343
344
    } else {
      if (value.dtype().is_scalable_vector()) {
345
346
347
        return Cast(op->dtype.with_scalable_vscale_factor(
                        value.dtype().vscale_factor()),
                    value);
348
349
350
351
352
353
      } else {
        return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
      }
    }
  }

354
  PrimExpr VisitExpr_(const FloatImmNode *op) final {
355
    return tvm::ffi::GetRef<PrimExpr>(op);
356
  }
357

358
  PrimExpr VisitExpr_(const IntImmNode *op) final {
359
    return tvm::ffi::GetRef<PrimExpr>(op);
360
  }
361

362
  PrimExpr VisitExpr_(const StringImmNode *op) final {
363
    return tvm::ffi::GetRef<PrimExpr>(op);
364
  }
365
366

  // Variable
367
  PrimExpr VisitExpr_(const VarNode *op) final {
368
    Var var = tvm::ffi::GetRef<Var>(op);
369
370
371
372
373
374
375
376
377
378
379
380

    if (var.same_as(var_)) {
      return ramp_;
    }
    auto it = let_binding_.find(var);
    if (it != let_binding_.end()) {
      return it->second;
    } else {
      return std::move(var);
    }
  }
  // IfThenElse expr
381
  PrimExpr MutateIfThenElseExpr_(const CallNode *op) {
382
383
384
    PrimExpr cond = this->VisitExpr(op->args[0]);
    if (cond.dtype().is_scalable_or_fixed_length_vector()) {
      need_scalarize_ = true;
385
      return tvm::ffi::GetRef<PrimExpr>(op);
386
387
388
    }
    PrimExpr t = this->VisitExpr(op->args[1]);
    PrimExpr f = this->VisitExpr(op->args[2]);
389
390
    if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
        f.same_as(op->args[2])) {
391
      return tvm::ffi::GetRef<PrimExpr>(op);
392
393
394
395
    } else {
      int t_lanes = t.dtype().get_lanes_or_vscale_factor();
      int f_lanes = f.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(t_lanes, f_lanes);
396
397
      bool is_scalable =
          t.dtype().is_scalable_vector() || f.dtype().is_scalable_vector();
398
399
400
      t = BroadcastTo(t, lanes, is_scalable);
      f = BroadcastTo(f, lanes, is_scalable);
      if (is_scalable) {
401
402
        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
                    {cond, t, f});
403
404
405
406
407
408
      } else {
        return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
      }
    }
  }
  // Reinterpret expr
409
  PrimExpr MutateReinterpretExpr_(const CallNode *op) {
410
411
412
    ICHECK(op->op.same_as(builtin::reinterpret()));
    PrimExpr value = this->VisitExpr(op->args[0]);
    if (value.same_as(op->args[0])) {
413
      return tvm::ffi::GetRef<PrimExpr>(op);
414
415
416
    } else {
      int lanes = value.dtype().get_lanes_or_vscale_factor();
      if (value.dtype().is_scalable_vector()) {
417
418
        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
                    {value});
419
420
421
422
423
424
      } else {
        return Call(op->dtype.with_lanes(lanes), op->op, {value});
      }
    }
  }
  // Call
425
  PrimExpr VisitExpr_(const CallNode *op) final {
426
427
428
429
430
431
432
433
434
435
436
437
438
439
    if (op->op.same_as(builtin::if_then_else())) {
      return MutateIfThenElseExpr_(op);
    } else if (op->op.same_as(builtin::texture2d_load())) {
      int lane = 0;
      Array<PrimExpr> fcd = MutateArray({op->args.back()}, &lane);
      auto new_args = op->args;
      new_args.pop_back();
      new_args.push_back(fcd[0]);
      return Call(op->dtype.with_lanes(4), op->op, new_args);
    } else if (op->op.same_as(builtin::texture2d_store())) {
      int lane = 0;
      // Vectorize the value to store
      Array<PrimExpr> value{op->args.back()};
      Array<PrimExpr> mutated_value = MutateArray(value, &lane);
440
441
      Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2],
                               mutated_value[0]};
442
443
444
445
446
      return Call(op->dtype.with_lanes(lane), op->op, new_args);
    } else if (op->op.same_as(builtin::reinterpret())) {
      return MutateReinterpretExpr_(op);
    }
    auto optional_op = op->op.as<Op>();
447
448
    bool vectorizable = optional_op &&
                        op_vectorizable_.get(optional_op.value(), false) &&
449
450
451
452
453
454
455
456
457
                        !op->dtype.is_scalable_vector();

    if (!vectorizable) {
      // Cannot vectorize this op
      Array<PrimExpr> new_args;
      for (auto arg : op->args) {
        auto new_arg = this->VisitExpr(arg);
        if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
          need_scalarize_ = true;
458
          return tvm::ffi::GetRef<PrimExpr>(op);
459
460
461
462
        }
        new_args.push_back(new_arg);
      }
      if (op->args.same_as(new_args)) {
463
        return tvm::ffi::GetRef<PrimExpr>(op);
464
465
466
467
468
469
470
471
      } else {
        return Call(op->dtype, op->op, new_args);
      }
    } else {
      int lane = 0;
      Array<PrimExpr> new_args = MutateArray(op->args, &lane);
      // normal code path.
      if (op->args.same_as(new_args)) {
472
        return tvm::ffi::GetRef<PrimExpr>(op);
473
474
475
476
477
478
      } else {
        return Call(op->dtype.with_lanes(lane), op->op, new_args);
      }
    }
  }
  // BufferLoad
479
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
480
    auto load = tvm::ffi::GetRef<BufferLoad>(op);
481

482
483
484
    auto fmutate = [this](const PrimExpr &index) {
      return this->VisitExpr(index);
    };
485
486
487
488
489
490
491
492
493
494
    Array<PrimExpr> indices = op->indices.Map(fmutate);

    if (!indices.same_as(op->indices)) {
      auto writer = load.CopyOnWrite();
      writer->indices = indices;
    }

    return std::move(load);
  }
  // Let
495
  PrimExpr VisitExpr_(const LetNode *op) final {
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
    PrimExpr value = this->VisitExpr(op->value);
    // Weaker SSA condition
    // A single var can be binded in multiple lets
    // but they have to bind to the same value.
    // This is used to allow cases when we reuse a single let
    // expression to construct a nested expr.
    // (let x = 1 in x + 1) * (let x = 1 in x + 1)
    auto it = let_binding_.find(op->var);
    if (it != let_binding_.end()) {
      ICHECK(deep_equal_(it->second, value))
          << "Let cannot bind the same var to two different values";
    }
    if (value.dtype().get_lanes_or_vscale_factor() !=
        op->value.dtype().get_lanes_or_vscale_factor()) {
      Var new_var(op->var->name_hint, value.dtype());
      let_binding_[op->var] = new_var;
      return Let(new_var, value, this->VisitExpr(op->body));
    } else {
      let_binding_[op->var] = op->var;
      PrimExpr body = this->VisitExpr(op->body);
      if (value.same_as(op->value) && body.same_as(op->body)) {
517
        return tvm::ffi::GetRef<PrimExpr>(op);
518
519
520
521
522
523
      } else {
        return Let(op->var, value, body);
      }
    }
  }
  // BufferStore
524
  Stmt VisitStmt_(const BufferStoreNode *op) final {
525
    auto store = tvm::ffi::GetRef<BufferStore>(op);
526

527
528
529
    auto fmutate = [this](const PrimExpr &index) {
      return this->VisitExpr(index);
    };
530
531
532
533
534
535
    Array<PrimExpr> indices = op->indices.Map(fmutate);

    PrimExpr value = this->VisitExpr(op->value);

    if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
      ICHECK(!op->buffer->dtype.is_scalable_vector())
536
537
          << "Vectorizing over scalable buffer elements is not supported in "
             "vectorizer.";
538
539
540
541
542
543
      // How many lanes of indexing are present in the index and
      // buffer element type, excluding the last index.
      int other_index_lanes = op->buffer->dtype.lanes();
      for (size_t i = 0; i < indices.size() - 1; i++) {
        other_index_lanes *= indices[i].dtype().lanes();
        // Only allow the last index to be scalable
544
545
        ICHECK(!indices[i].dtype().is_scalable_vector())
            << "Only the last index can be scalable.";
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
      }

      // The total number of lanes of indexing, including the last index.
      auto last_index_dtype = indices[indices.size() - 1].dtype();
      int lanes_in_last_index = last_index_dtype.get_lanes_or_vscale_factor();
      int index_lanes = other_index_lanes * lanes_in_last_index;

      // The total number of lanes in this store operation.  Either
      // the index or the value will be broadcast out to this number
      // of lanes, depending on which has more lanes.
      int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor();
      bool is_last_index_scalable = last_index_dtype.is_scalable_vector();
      int total_lanes = std::max(index_lanes, value_dtype_lanes);

      ICHECK_EQ(total_lanes % other_index_lanes, 0)
561
562
          << "When storing to buffer " << op->buffer->name
          << ", cannot produce " << total_lanes
563
564
565
566
567
          << " lanes of storage location by changing the last index.";
      int last_index_lanes = total_lanes / other_index_lanes;

      // Broadcast the last index such that the total number of index
      // lanes matches the desired number.
568
569
570
      indices.Set(indices.size() - 1,
                  BroadcastTo(indices[indices.size() - 1], last_index_lanes,
                              is_last_index_scalable));
571
572
573
574
575
576
577
578
579

      auto writer = store.CopyOnWrite();
      writer->indices = indices;
      writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable);
    }

    return std::move(store);
  }
  // For
580
  Stmt VisitStmt_(const ForNode *op) final {
581
582
583
584
585
586
587
    if (op->kind == ForKind::kVectorized) {
      LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
    }
    ICHECK(is_zero(op->min));
    ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
    PrimExpr extent = this->VisitExpr(op->extent);
    if (extent.dtype().is_scalable_or_fixed_length_vector()) {
588
      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
589
590
591
    }
    Stmt body = this->VisitStmt(op->body);
    if (extent.same_as(op->extent) && body.same_as(op->body)) {
592
      return tvm::ffi::GetRef<Stmt>(op);
593
    } else {
594
595
      return For(op->loop_var, op->min, extent, op->kind, body,
                 op->thread_binding, op->annotations);
596
597
598
    }
  }
  // IfThenElse
599
  Stmt VisitStmt_(const IfThenElseNode *op) final {
600
601
602
    ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
    PrimExpr condition = this->VisitExpr(op->condition);
    if (condition.dtype().is_scalable_or_fixed_length_vector()) {
603
      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
604
605
    }
    Stmt then_case = this->VisitStmt(op->then_case);
606
    Optional<Stmt> else_case = std::nullopt;
607
608
609
610
611
    if (op->else_case) {
      else_case = this->VisitStmt(op->else_case.value());
    }
    if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
612
      return tvm::ffi::GetRef<Stmt>(op);
613
614
615
616
617
    } else {
      return IfThenElse(condition, then_case, else_case);
    }
  }
  // While
618
  Stmt VisitStmt_(const WhileNode *op) final {
619
620
621
    LOG(FATAL) << "A while loop inside a vectorized loop not supported.";
  }
  // LetStmt
622
  Stmt VisitStmt_(const LetStmtNode *op) final {
623
    PrimExpr value = this->VisitExpr(op->value);
624
625
    ICHECK(!let_binding_.count(op->var))
        << "SSA violation, a single var is binded twice";
626
627
628
629
630
631
632
633
634
635
636
    let_binding_[op->var] = value;

    if (value.dtype().get_lanes_or_vscale_factor() !=
        op->value.dtype().get_lanes_or_vscale_factor()) {
      Var new_var(op->var->name_hint, value.dtype());
      let_binding_[op->var] = new_var;
      return LetStmt(new_var, value, this->VisitStmt(op->body));
    } else {
      let_binding_[op->var] = op->var;
      Stmt body = this->VisitStmt(op->body);
      if (value.same_as(op->value) && body.same_as(op->body)) {
637
        return tvm::ffi::GetRef<Stmt>(op);
638
639
640
641
642
643
      } else {
        return LetStmt(op->var, value, body);
      }
    }
  }
  // Allocate
644
  Stmt VisitStmt_(const AllocateNode *op) final {
645
646
647
    // Mutate the condition
    PrimExpr condition = this->VisitExpr(op->condition);
    if (condition.dtype().is_scalable_or_fixed_length_vector()) {
648
649
      LOG(WARNING) << "Cannot handle vector extent in alloc of "
                   << op->buffer_var->name_hint;
650
      return Scalarize(tvm::ffi::GetRef<Stmt>(op));
651
652
653
654
    }

    // Mutate the extents
    Array<PrimExpr> extents;
655
    for (const auto &extent : op->extents) {
656
657
      PrimExpr new_ext = this->VisitExpr(extent);
      if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
658
659
        LOG(WARNING) << "Cannot handle vector extent in alloc of "
                     << op->buffer_var->name_hint;
660
        return Scalarize(tvm::ffi::GetRef<Stmt>(op));
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
      }
      extents.push_back(new_ext);
    }

    // TODO(Lunderberg): Move this pass to be prior to
    // StorageFlatten/FlattenBuffer.  That will allow this pass to be
    // implemented as adding a new buffer dimension, which is later
    // flattened.

    // Extend the least significant dimension by a factor of
    // var_lanes_.  Typically, this will be a 1-d index into a flat
    // memory space.
    extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_);

    // Rewrite access to the buffer in the body.
676
677
    Stmt body =
        VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
678
679
680
681
682
683
684
685
686
687
688
    body = this->VisitStmt(body);
    return Allocate(op->buffer_var, op->dtype, extents, condition, body);
  }

  // scalarize the statement
  Stmt Scalarize(Stmt stmt) {
    Var idx(var_->name_hint + ".s", var_->dtype);
    stmt = Substitute(stmt, {{var_, idx}});
    return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
  }

689
private:
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
  // analyzer
  arith::Analyzer analyzer_;
  // deep equal
  ExprDeepEqual deep_equal_;
  // variable to be replaced
  Var var_;
  // the lanes.
  PrimExpr var_lanes_;
  // ramp representing the var.
  PrimExpr ramp_;
  // flag to mark requirement of scalarization.
  bool need_scalarize_{false};
  // Let binding
  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
  // vectorizable property
705
706
  OpAttrMap<TVectorizable> op_vectorizable_ =
      Op::GetAttrMap<TVectorizable>("TVectorizable");
707
708
709

  // mutate array, with given lane requirement
  // when finished, p_lane updates the lane requirement.
710
  Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int *p_lanes) {
711
    if (arr.empty())
712
713
      return arr;
    int &lanes = *p_lanes;
714
715
716
717
718
    bool changed = false;
    std::vector<PrimExpr> new_arr(arr.size());
    for (size_t i = 0; i < arr.size(); i++) {
      PrimExpr old_elem = arr[i];
      PrimExpr new_elem = this->VisitExpr(old_elem);
719
720
      if (!new_elem.same_as(old_elem))
        changed = true;
721
722
723
724
725
726
727
728
729
730
      new_arr[i] = new_elem;
      lanes = std::max(lanes, new_elem.dtype().lanes());
    }

    for (size_t i = 0; i < arr.size(); ++i) {
      if (new_arr[i].dtype().lanes() != lanes) {
        new_arr[i] = BroadcastTo(new_arr[i], lanes, false);
        changed = true;
      }
    }
731
732
    if (!changed)
      return arr;
733
734
    return Array<PrimExpr>(new_arr);
  }
735
736
737
  template <typename TOp, typename T> PrimExpr BinaryVec(const T *op) {
    static_assert(std::is_same<typename TOp::ContainerType, T>::value,
                  "constraint");
738
739
740
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
741
      return tvm::ffi::GetRef<PrimExpr>(op);
742
743
744
745
    } else {
      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(a_lanes, b_lanes);
746
747
748
749
      bool is_scalable =
          a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
      return TOp(BroadcastTo(a, lanes, is_scalable),
                 BroadcastTo(b, lanes, is_scalable));
750
751
752
    }
  }
  template <typename T, typename FCompute>
753
  PrimExpr AddSubVec(const T *op, FCompute fcompute) {
754
755
756
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
757
      return tvm::ffi::GetRef<PrimExpr>(op);
758
759
760
761
762
    } else {
      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(a_lanes, b_lanes);
      if (lanes != 1) {
763
764
        const RampNode *b_ramp = b.as<RampNode>();
        const RampNode *a_ramp = a.as<RampNode>();
765
        if (a.dtype().is_scalar() && b_ramp) {
766
767
768
769
          return Ramp(
              fcompute(a, b_ramp->base),
              fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride),
              b_ramp->lanes);
770
771
772
773
774
        }
        if (b.dtype().is_scalar() && a_ramp) {
          return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
        }
      }
775
776
777
778
      bool is_scalable =
          a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
      return fcompute(BroadcastTo(a, lanes, is_scalable),
                      BroadcastTo(b, lanes, is_scalable));
779
780
781
782
    }
  }
};

783
784
} // namespace tl
} // namespace tvm