/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file common.h
 * \brief Common utilities for TL transforms
 */

#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>
#include <vector>

#include "../../op/parallel.h"
#include "../loop_partition.h"
#include "../loop_vectorize.h"
#include "arith/ir_mutator_with_analyzer.h"

namespace tvm {
namespace tl {

using namespace tir;

// Vectorize part.
// Uses the same code as tir.transform.vectorize_loop.
inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
  if (is_scalable) {
    return Mul(Call(DataType::Int(32), builtin::vscale(), {}),
               lanes_or_vscale_factor);
  } else {
    return lanes_or_vscale_factor;
  }
}

inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
  // Check if e is already in the expected form.
  if (e.dtype().get_lanes_or_vscale_factor() == lanes &&
      e.dtype().is_scalable_vector() == is_scalable)
    return e;

  if (const BroadcastNode *op = e.as<BroadcastNode>()) {
    ICHECK(op->dtype.is_scalable_vector() == is_scalable)
        << "Can't broadcast between scalable and fixed length vectors.";
    int e_lanes = op->dtype.get_lanes_or_vscale_factor();

    if (lanes % e_lanes == 0) {
      return Broadcast(op->value, CreateNewLanes(is_scalable, lanes));
    }
  }

  ICHECK(e.dtype().is_scalar())
      << "Cannot broadcast lanes=" << e.dtype().get_lanes_or_vscale_factor()
      << " is_scalable=" << e.dtype().is_scalable_vector() << " to " << lanes;
  return Broadcast(e, CreateNewLanes(is_scalable, lanes));
}
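// Illustration (ours, not from the upstream pass): under the rules above,
//
//   BroadcastTo(x, 4, /*is_scalable=*/false)               -> broadcast(x, 4)
//   BroadcastTo(broadcast(x, 2), 4, /*is_scalable=*/false) -> broadcast(x, 4)
//   BroadcastTo(ramp(0, 1, 4), 8, /*is_scalable=*/false)   -> ICHECK failure
//
// since a non-broadcast vector cannot be widened to a larger lane count.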
// Rewrite vectorized allocation access.
// This is necessary for making each vector component contain its own
// workspace. Originates from Halide's loop vectorizer.
//
//   s[i] = s[i * lanes + var]
//
// The same principle applies when using one thread to simulate multiple
// contexts.
class VecAllocAccess : public StmtExprMutator {
public:
  VecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
      : buf_(buf), var_(std::move(var)), var_lanes_(std::move(var_lanes)) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return UpdateBufferAccess(load);
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return UpdateBufferAccess(store);
  }

private:
  template <typename Node> Node UpdateBufferAccess(Node node) {
    // Only update the buffer that's being replaced.
    if (node->buffer->data.get() != buf_) {
      return node;
    }

    // Find/make a Buffer object with the correct updated shape.
    Buffer buf;
    auto it = buffer_map_.find(node->buffer.get());
    if (it != buffer_map_.end()) {
      buf = it->second;
    } else {
      // Extend the least significant dimension by a factor of
      // var_lanes_. Typically, this will be a 1-d index into a flat
      // memory space.
      Array<PrimExpr> shape = node->buffer->shape;
      shape.Set(shape.size() - 1,
                analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_));

      // TODO(Lunderberg): Move this pass to be prior to
      // StorageFlatten/FlattenBuffer, implement by appending a
      // dimension to the buffer. Since it is currently after the
      // flattening, the strides are not technically necessary, but
      // are updated for consistency.

      // Update strides if defined. All strides except the last are
      // scaled by var_lanes_, matching the widened last dimension.
      Array<PrimExpr> strides;
      for (size_t i = 0; i < node->buffer->strides.size(); i++) {
        PrimExpr stride = node->buffer->strides[i];
        if (i != node->buffer->strides.size() - 1) {
          stride *= var_lanes_;
        }
        strides.push_back(analyzer_.Simplify(stride));
      }

      // Copy everything into the new buffer.
      buf = node->buffer;
      auto buf_writer = buf.CopyOnWrite();
      buf_writer->shape = shape;
      buf_writer->strides = strides;
      buffer_map_[buf.get()] = buf;
    }

    // Extend the last index by the number of lanes in the vectorized
    // variable.
    Array<PrimExpr> indices = node->indices;
    indices.Set(
        indices.size() - 1,
        analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_));

    auto writer = node.CopyOnWrite();
    writer->buffer = buf;
    writer->indices = indices;
    return node;
  }

  // buffer var
  const VarNode *buf_;
  // Updated buffer objects.
  std::unordered_map<const BufferNode *, Buffer> buffer_map_;
  // variable to be replaced
  Var var_;
  // the lanes.
  PrimExpr var_lanes_;
  // Analyzer for simplifications
  arith::Analyzer analyzer_;
};
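// Worked example (sketch): with var_lanes_ == 4, a flat allocation
// `float s[16]` accessed as `s[i]` inside the vectorized loop is rewritten
// so that the access becomes `s[i * 4 + var]` (the Allocate visitor below
// widens the extent to 64), giving each of the 4 lanes a private slot.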
// We use ExprFunctor directly instead of StmtExprMutator because the
// transformation can change the dtype of the Expr, and the existing
// ExprMutator transformation rules may not be well defined under a
// changed dtype.
class Vectorizer : public StmtMutator,
                   public ExprFunctor<PrimExpr(const PrimExpr &)> {
public:
  using ExprFunctor::VisitExpr;
  using StmtMutator::operator();

  Vectorizer(const Var &var, const PrimExpr &var_lanes)
      : var_(var), var_lanes_(var_lanes) {
    ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
  }

  Stmt VisitStmt(const Stmt &stmt) final {
    ICHECK(!need_scalarize_);
    Stmt ret = StmtMutator::VisitStmt(stmt);
    if (need_scalarize_) {
      need_scalarize_ = false;
      return Scalarize(stmt);
    } else {
      return ret;
    }
  }

  PrimExpr VisitExpr(const PrimExpr &e) final {
    return ExprFunctor::VisitExpr(e);
  }

  PrimExpr VisitExpr_(const AddNode *op) final {
    return AddSubVec(op, [](PrimExpr a, PrimExpr b) {
      return std::move(a) + std::move(b);
    });
  }

  PrimExpr VisitExpr_(const SubNode *op) final {
    return AddSubVec(op, [](PrimExpr a, PrimExpr b) {
      return std::move(a) - std::move(b);
    });
  }

  PrimExpr VisitExpr_(const MulNode *op) final {
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
      return GetRef<PrimExpr>(op);
    } else {
      bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
      bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
      if (is_vec_a && is_vec_b) {
        // Do not multiply scalable and fixed length vectors.
        ICHECK(a.dtype().is_scalable_vector() ==
               b.dtype().is_scalable_vector())
            << "Fixed length and scalable vectors can't be mixed in "
               "multiplication.";
      }
      if (is_vec_a || is_vec_b) {
        const RampNode *b_ramp = b.as<RampNode>();
        const RampNode *a_ramp = a.as<RampNode>();
        if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) {
          PrimExpr lanes = a_ramp->lanes;
          return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes);
        }
        if (b_ramp && a.dtype().is_scalar() && analyzer_.CanProve(a > 0)) {
          PrimExpr lanes = b_ramp->lanes;
          return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes);
        }
        int a_lanes = a.dtype().get_lanes_or_vscale_factor();
        int b_lanes = b.dtype().get_lanes_or_vscale_factor();
        int max_lanes = std::max(a_lanes, b_lanes);
        bool is_scalable =
            a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
        return Mul(BroadcastTo(a, max_lanes, is_scalable),
                   BroadcastTo(b, max_lanes, is_scalable));
      }
    }
    return BinaryVec<Mul>(op);
  }
  PrimExpr VisitExpr_(const DivNode *op) final { return BinaryVec<Div>(op); }
  PrimExpr VisitExpr_(const ModNode *op) final { return BinaryVec<Mod>(op); }
  PrimExpr VisitExpr_(const FloorDivNode *op) final {
    return BinaryVec<FloorDiv>(op);
  }
  PrimExpr VisitExpr_(const FloorModNode *op) final {
    return BinaryVec<FloorMod>(op);
  }
  PrimExpr VisitExpr_(const MinNode *op) final { return BinaryVec<Min>(op); }
  PrimExpr VisitExpr_(const MaxNode *op) final { return BinaryVec<Max>(op); }
  PrimExpr VisitExpr_(const EQNode *op) final { return BinaryVec<EQ>(op); }
  PrimExpr VisitExpr_(const NENode *op) final { return BinaryVec<NE>(op); }
  PrimExpr VisitExpr_(const LTNode *op) final { return BinaryVec<LT>(op); }
  PrimExpr VisitExpr_(const LENode *op) final { return BinaryVec<LE>(op); }
  PrimExpr VisitExpr_(const GTNode *op) final { return BinaryVec<GT>(op); }
  PrimExpr VisitExpr_(const GENode *op) final { return BinaryVec<GE>(op); }
  PrimExpr VisitExpr_(const AndNode *op) final { return BinaryVec<And>(op); }
  PrimExpr VisitExpr_(const OrNode *op) final { return BinaryVec<Or>(op); }

  PrimExpr VisitExpr_(const NotNode *op) final {
    PrimExpr a = this->VisitExpr(op->a);
    if (a.same_as(op->a)) {
      return GetRef<PrimExpr>(op);
    } else {
      return !(a);
    }
  }

  PrimExpr VisitExpr_(const RampNode *op) final {
    PrimExpr base = this->VisitExpr(op->base);
    PrimExpr stride = this->VisitExpr(op->stride);
    ICHECK(!base.dtype().is_scalable_vector())
        << "Creating scalable vectors from existing vectors is not supported.";
    ICHECK(!stride.dtype().is_scalable_vector())
        << "Ramp stride with scalable dtype is not supported";
    if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) {
      ICHECK(op->lanes->IsInstance<IntImmNode>())
          << "Vectorizing over existing scalable vectors is not supported.";
      // Guard against a vectorized base that is not a Ramp (e.g. a
      // Broadcast); in that case fall through to the generic path.
      if (const RampNode *base_ramp = base.as<RampNode>()) {
        int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
        int base_ramp_lanes =
            static_cast<int>(Downcast<IntImm>(base_ramp->lanes)->value);
        if (analyzer_.CanProve(
                base_ramp->stride ==
                stride * make_const(stride.dtype(), base_ramp_lanes))) {
          return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes);
        }
      }
    }
    int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
    base = BroadcastTo(base, lanes, false);
    stride = BroadcastTo(stride, lanes, false);
    Array<PrimExpr> elems;
    for (int i = 0; i < lanes; ++i) {
      elems.push_back(Ramp(Shuffle::ExtractElement(base, i),
                           Shuffle::ExtractElement(stride, i), op->lanes));
    }
    return Shuffle::Concat(elems);
  }

  PrimExpr VisitExpr_(const BroadcastNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (value.dtype().is_scalable_or_fixed_length_vector()) {
      need_scalarize_ = true;
      return GetRef<PrimExpr>(op);
    }
    if (value.same_as(op->value)) {
      return GetRef<PrimExpr>(op);
    } else {
      return Broadcast(value, op->lanes);
    }
  }

  PrimExpr VisitExpr_(const SelectNode *op) final {
    PrimExpr cond = this->VisitExpr(op->condition);
    PrimExpr t = this->VisitExpr(op->true_value);
    PrimExpr f = this->VisitExpr(op->false_value);
    if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
        f.same_as(op->false_value)) {
      return GetRef<PrimExpr>(op);
    } else {
      int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
      int t_lanes = t.dtype().get_lanes_or_vscale_factor();
      int f_lanes = f.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes);
      bool is_scalable = cond.dtype().is_scalable_vector() ||
                         t.dtype().is_scalable_vector() ||
                         f.dtype().is_scalable_vector();
      return Select(BroadcastTo(cond, lanes, is_scalable),
                    BroadcastTo(t, lanes, is_scalable),
                    BroadcastTo(f, lanes, is_scalable));
    }
  }
  PrimExpr VisitExpr_(const CastNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (value.same_as(op->value)) {
      return GetRef<PrimExpr>(op);
    } else {
      if (value.dtype().is_scalable_vector()) {
        return Cast(op->dtype.with_scalable_vscale_factor(
                        value.dtype().vscale_factor()),
                    value);
      } else {
        return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
      }
    }
  }

  PrimExpr VisitExpr_(const FloatImmNode *op) final {
    return GetRef<PrimExpr>(op);
  }
  PrimExpr VisitExpr_(const IntImmNode *op) final {
    return GetRef<PrimExpr>(op);
  }
  PrimExpr VisitExpr_(const StringImmNode *op) final {
    return GetRef<PrimExpr>(op);
  }

  // Variable
  PrimExpr VisitExpr_(const VarNode *op) final {
    Var var = GetRef<Var>(op);
    if (var.same_as(var_)) {
      return ramp_;
    }
    auto it = let_binding_.find(var);
    if (it != let_binding_.end()) {
      return it->second;
    } else {
      return std::move(var);
    }
  }

  // IfThenElse expr
  PrimExpr MutateIfThenElseExpr_(const CallNode *op) {
    PrimExpr cond = this->VisitExpr(op->args[0]);
    if (cond.dtype().is_scalable_or_fixed_length_vector()) {
      need_scalarize_ = true;
      return GetRef<PrimExpr>(op);
    }
    PrimExpr t = this->VisitExpr(op->args[1]);
    PrimExpr f = this->VisitExpr(op->args[2]);
    if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
        f.same_as(op->args[2])) {
      return GetRef<PrimExpr>(op);
    } else {
      int t_lanes = t.dtype().get_lanes_or_vscale_factor();
      int f_lanes = f.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(t_lanes, f_lanes);
      bool is_scalable =
          t.dtype().is_scalable_vector() || f.dtype().is_scalable_vector();
      t = BroadcastTo(t, lanes, is_scalable);
      f = BroadcastTo(f, lanes, is_scalable);
      if (is_scalable) {
        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
                    {cond, t, f});
      } else {
        return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
      }
    }
  }
  // Reinterpret expr
  PrimExpr MutateReinterpretExpr_(const CallNode *op) {
    ICHECK(op->op.same_as(builtin::reinterpret()));
    PrimExpr value = this->VisitExpr(op->args[0]);
    if (value.same_as(op->args[0])) {
      return GetRef<PrimExpr>(op);
    } else {
      int lanes = value.dtype().get_lanes_or_vscale_factor();
      if (value.dtype().is_scalable_vector()) {
        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
                    {value});
      } else {
        return Call(op->dtype.with_lanes(lanes), op->op, {value});
      }
    }
  }
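  // Illustration (sketch): for `if_then_else(cond, t, f)`, a vector `cond`
  // forces scalarization of the enclosing statement, while vector `t`/`f`
  // are broadcast to a common lane count that is also stamped onto the
  // call dtype; `reinterpret(x)` simply inherits the lane count of its
  // mutated operand.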
  // Call
  PrimExpr VisitExpr_(const CallNode *op) final {
    if (op->op.same_as(builtin::if_then_else())) {
      return MutateIfThenElseExpr_(op);
    } else if (op->op.same_as(builtin::texture2d_load())) {
      int lane = 0;
      Array<PrimExpr> fcd = MutateArray({op->args.back()}, &lane);
      auto new_args = op->args;
      new_args.pop_back();
      new_args.push_back(fcd[0]);
      return Call(op->dtype.with_lanes(4), op->op, new_args);
    } else if (op->op.same_as(builtin::texture2d_store())) {
      int lane = 0;
      // Vectorize the value to store
      Array<PrimExpr> value{op->args.back()};
      Array<PrimExpr> mutated_value = MutateArray(value, &lane);
      Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2],
                               mutated_value[0]};
      return Call(op->dtype.with_lanes(lane), op->op, new_args);
    } else if (op->op.same_as(builtin::reinterpret())) {
      return MutateReinterpretExpr_(op);
    }
    auto optional_op = op->op.as<Op>();
    bool vectorizable = optional_op &&
                        op_vectorizable_.get(optional_op.value(), false) &&
                        !op->dtype.is_scalable_vector();

    if (!vectorizable) {
      // Cannot vectorize this op: visit the arguments and scalarize
      // if any of them becomes a vector.
      Array<PrimExpr> new_args;
      for (auto arg : op->args) {
        auto new_arg = this->VisitExpr(arg);
        if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
          need_scalarize_ = true;
          return GetRef<PrimExpr>(op);
        }
        new_args.push_back(new_arg);
      }
      if (op->args.same_as(new_args)) {
        return GetRef<PrimExpr>(op);
      } else {
        return Call(op->dtype, op->op, new_args);
      }
    } else {
      // Normal code path: vectorize the arguments and retag the call
      // dtype with the resulting lane count.
      int lane = 0;
      Array<PrimExpr> new_args = MutateArray(op->args, &lane);
      if (op->args.same_as(new_args)) {
        return GetRef<PrimExpr>(op);
      } else {
        return Call(op->dtype.with_lanes(lane), op->op, new_args);
      }
    }
  }
  // BufferLoad
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto load = GetRef<BufferLoad>(op);

    auto fmutate = [this](const PrimExpr &index) {
      return this->VisitExpr(index);
    };
    Array<PrimExpr> indices = op->indices.Map(fmutate);

    if (!indices.same_as(op->indices)) {
      auto writer = load.CopyOnWrite();
      writer->indices = indices;
    }

    return std::move(load);
  }
  // Let
  PrimExpr VisitExpr_(const LetNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    // Weaker SSA condition:
    // a single var can be bound in multiple lets,
    // but they have to bind to the same value.
    // This is used to allow cases when we reuse a single let
    // expression to construct a nested expr.
    // (let x = 1 in x + 1) * (let x = 1 in x + 1)
    auto it = let_binding_.find(op->var);
    if (it != let_binding_.end()) {
      ICHECK(deep_equal_(it->second, value))
          << "Let cannot bind the same var to two different values";
    }
    if (value.dtype().get_lanes_or_vscale_factor() !=
        op->value.dtype().get_lanes_or_vscale_factor()) {
      Var new_var(op->var->name_hint, value.dtype());
      let_binding_[op->var] = new_var;
      return Let(new_var, value, this->VisitExpr(op->body));
    } else {
      let_binding_[op->var] = op->var;
      PrimExpr body = this->VisitExpr(op->body);
      if (value.same_as(op->value) && body.same_as(op->body)) {
        return GetRef<PrimExpr>(op);
      } else {
        return Let(op->var, value, body);
      }
    }
  }
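  // Illustration (sketch): in `let x = a[i] in x + x`, vectorizing `i` into
  // `ramp(0, 1, 4)` changes the bound value from scalar to a 4-lane vector,
  // so a fresh 4-lane `x` is created and the body is revisited under that
  // binding; when the dtype is unchanged, the original var is reused.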
  // BufferStore
  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = GetRef<BufferStore>(op);

    auto fmutate = [this](const PrimExpr &index) {
      return this->VisitExpr(index);
    };
    Array<PrimExpr> indices = op->indices.Map(fmutate);

    PrimExpr value = this->VisitExpr(op->value);

    if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
      ICHECK(!op->buffer->dtype.is_scalable_vector())
          << "Vectorizing over scalable buffer elements is not supported in "
             "vectorizer.";
      // How many lanes of indexing are present in the index and
      // buffer element type, excluding the last index.
      int other_index_lanes = op->buffer->dtype.lanes();
      for (size_t i = 0; i < indices.size() - 1; i++) {
        other_index_lanes *= indices[i].dtype().lanes();
        // Only allow the last index to be scalable.
        ICHECK(!indices[i].dtype().is_scalable_vector())
            << "Only the last index can be scalable.";
      }

      // The total number of lanes of indexing, including the last index.
      auto last_index_dtype = indices[indices.size() - 1].dtype();
      int lanes_in_last_index = last_index_dtype.get_lanes_or_vscale_factor();
      int index_lanes = other_index_lanes * lanes_in_last_index;

      // The total number of lanes in this store operation. Either
      // the index or the value will be broadcast out to this number
      // of lanes, depending on which has more lanes.
      int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor();
      bool is_last_index_scalable = last_index_dtype.is_scalable_vector();
      int total_lanes = std::max(index_lanes, value_dtype_lanes);

      ICHECK_EQ(total_lanes % other_index_lanes, 0)
          << "When storing to buffer " << op->buffer->name
          << ", cannot produce " << total_lanes
          << " lanes of storage location by changing the last index.";
      int last_index_lanes = total_lanes / other_index_lanes;

      // Broadcast the last index such that the total number of index
      // lanes matches the desired number.
      indices.Set(indices.size() - 1,
                  BroadcastTo(indices[indices.size() - 1], last_index_lanes,
                              is_last_index_scalable));

      auto writer = store.CopyOnWrite();
      writer->indices = indices;
      writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable);
    }

    return std::move(store);
  }
  // For
  Stmt VisitStmt_(const ForNode *op) final {
    if (op->kind == ForKind::kVectorized) {
      LOG(WARNING) << "Detected vectorize inside vectorized loop, ignoring...";
    }
    ICHECK(is_zero(op->min));
    ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
    PrimExpr extent = this->VisitExpr(op->extent);
    if (extent.dtype().is_scalable_or_fixed_length_vector()) {
      return Scalarize(GetRef<Stmt>(op));
    }
    Stmt body = this->VisitStmt(op->body);
    if (extent.same_as(op->extent) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      return For(op->loop_var, op->min, extent, op->kind, body,
                 op->thread_binding, op->annotations);
    }
  }
  // IfThenElse
  Stmt VisitStmt_(const IfThenElseNode *op) final {
    ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
    PrimExpr condition = this->VisitExpr(op->condition);
    if (condition.dtype().is_scalable_or_fixed_length_vector()) {
      return Scalarize(GetRef<Stmt>(op));
    }
    Stmt then_case = this->VisitStmt(op->then_case);
    Optional<Stmt> else_case = std::nullopt;
    if (op->else_case) {
      else_case = this->VisitStmt(op->else_case.value());
    }
    if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
      return GetRef<Stmt>(op);
    } else {
      return IfThenElse(condition, then_case, else_case);
    }
  }
  // While
  Stmt VisitStmt_(const WhileNode *op) final {
    LOG(FATAL) << "A while loop inside a vectorized loop is not supported.";
  }
  // LetStmt
  Stmt VisitStmt_(const LetStmtNode *op) final {
    PrimExpr value = this->VisitExpr(op->value);
    ICHECK(!let_binding_.count(op->var))
        << "SSA violation, a single var is bound twice";
    let_binding_[op->var] = value;

    if (value.dtype().get_lanes_or_vscale_factor() !=
        op->value.dtype().get_lanes_or_vscale_factor()) {
      Var new_var(op->var->name_hint, value.dtype());
      let_binding_[op->var] = new_var;
      return LetStmt(new_var, value, this->VisitStmt(op->body));
    } else {
      let_binding_[op->var] = op->var;
      Stmt body = this->VisitStmt(op->body);
      if (value.same_as(op->value) && body.same_as(op->body)) {
        return GetRef<Stmt>(op);
      } else {
        return LetStmt(op->var, value, body);
      }
    }
  }
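  // Illustration (sketch): statements that cannot stay vectorized fall back
  // to Scalarize (defined below). For example, an inner `if` whose condition
  // vectorizes is replayed as `for (var.s, 0, var_lanes_) { ... }` with
  // `var_` substituted by the fresh serial loop variable `var.s`.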
  // Allocate
  Stmt VisitStmt_(const AllocateNode *op) final {
    // Mutate the condition.
    PrimExpr condition = this->VisitExpr(op->condition);
    if (condition.dtype().is_scalable_or_fixed_length_vector()) {
      LOG(WARNING) << "Cannot handle vector extent in alloc of "
                   << op->buffer_var->name_hint;
      return Scalarize(GetRef<Stmt>(op));
    }
    // Mutate the extents.
    Array<PrimExpr> extents;
    for (const auto &extent : op->extents) {
      PrimExpr new_ext = this->VisitExpr(extent);
      if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
        LOG(WARNING) << "Cannot handle vector extent in alloc of "
                     << op->buffer_var->name_hint;
        return Scalarize(GetRef<Stmt>(op));
      }
      extents.push_back(new_ext);
    }
    // TODO(Lunderberg): Move this pass to be prior to
    // StorageFlatten/FlattenBuffer. That will allow this pass to be
    // implemented as adding a new buffer dimension, which is later
    // flattened.

    // Extend the least significant dimension by a factor of
    // var_lanes_. Typically, this will be a 1-d index into a flat
    // memory space.
    extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_);

    // Rewrite access to the buffer in the body.
    Stmt body =
        VecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
    body = this->VisitStmt(body);
    return Allocate(op->buffer_var, op->dtype, extents, condition, body);
  }

  // scalarize the statement
  Stmt Scalarize(Stmt stmt) {
    Var idx(var_->name_hint + ".s", var_->dtype);
    stmt = Substitute(stmt, {{var_, idx}});
    return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial,
               stmt);
  }

private:
  // analyzer
  arith::Analyzer analyzer_;
  // deep equal
  ExprDeepEqual deep_equal_;
  // variable to be replaced
  Var var_;
  // the lanes.
  PrimExpr var_lanes_;
  // ramp representing the var.
  PrimExpr ramp_;
  // flag to mark requirement of scalarization.
  bool need_scalarize_{false};
  // Let binding
  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>
      let_binding_;
  // vectorizable property
  OpAttrMap<bool> op_vectorizable_ = Op::GetAttrMap<bool>("TVectorizable");

  // Mutate an array of expressions with a given lane requirement;
  // when finished, *p_lanes holds the resulting lane requirement.
  Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int *p_lanes) {
    if (arr.empty())
      return arr;
    int &lanes = *p_lanes;
    bool changed = false;
    std::vector<PrimExpr> new_arr(arr.size());
    for (size_t i = 0; i < arr.size(); i++) {
      PrimExpr old_elem = arr[i];
      PrimExpr new_elem = this->VisitExpr(old_elem);
      if (!new_elem.same_as(old_elem))
        changed = true;
      new_arr[i] = new_elem;
      lanes = std::max(lanes, new_elem.dtype().lanes());
    }

    for (size_t i = 0; i < arr.size(); ++i) {
      if (new_arr[i].dtype().lanes() != lanes) {
        new_arr[i] = BroadcastTo(new_arr[i], lanes, false);
        changed = true;
      }
    }
    if (!changed)
      return arr;
    return Array<PrimExpr>(new_arr);
  }
  template <typename TOp, typename T> PrimExpr BinaryVec(const T *op) {
    static_assert(std::is_same<typename TOp::ContainerType, T>::value,
                  "constraint");
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
      return GetRef<PrimExpr>(op);
    } else {
      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(a_lanes, b_lanes);
      bool is_scalable =
          a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
      return TOp(BroadcastTo(a, lanes, is_scalable),
                 BroadcastTo(b, lanes, is_scalable));
    }
  }
  template <typename T, typename FCompute>
  PrimExpr AddSubVec(const T *op, FCompute fcompute) {
    PrimExpr a = this->VisitExpr(op->a);
    PrimExpr b = this->VisitExpr(op->b);
    if (a.same_as(op->a) && b.same_as(op->b)) {
      return GetRef<PrimExpr>(op);
    } else {
      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
      int lanes = std::max(a_lanes, b_lanes);
      if (lanes != 1) {
        const RampNode *b_ramp = b.as<RampNode>();
        const RampNode *a_ramp = a.as<RampNode>();
        if (a.dtype().is_scalar() && b_ramp) {
          return Ramp(
              fcompute(a, b_ramp->base),
              fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride),
              b_ramp->lanes);
        }
        if (b.dtype().is_scalar() && a_ramp) {
          return Ramp(fcompute(a_ramp->base, b), a_ramp->stride,
                      a_ramp->lanes);
        }
      }
      bool is_scalable =
          a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
      return fcompute(BroadcastTo(a, lanes, is_scalable),
                      BroadcastTo(b, lanes, is_scalable));
    }
  }
};

} // namespace tl
} // namespace tvm
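// Usage sketch (assumption: a driver analogous to tir.transform.VectorizeLoop;
// the mutator shown here is illustrative and not part of this header):
//
//   Stmt VisitStmt_(const ForNode *op) {
//     if (op->kind == ForKind::kVectorized) {
//       // Map op->loop_var to ramp(0, 1, extent) inside the body and
//       // rewrite loads/stores lane-wise.
//       return tvm::tl::Vectorizer(op->loop_var, op->extent)(op->body);
//     }
//     return StmtMutator::VisitStmt_(op);
//   }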