Commit 7111239d authored by Lei Wang, committed by GitHub

[Dev] Separate `LoopVectorize` Pass from upstream tvm (#62)

* [Enhancement] Add VectorizeLoop function and update imports for compatibility

* [CI][Test] Improve test cases for vectorization and fix typos in parser comments

* lint fix

* Fix incorrect module reference for VectorizeLoop transformation

* Refactor vectorize_loop transformation by removing unused extent mutation logic
parent ea612446
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file vectorize_loop.cc
*/
// Loop vectorizer as in Halide pipeline.
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <vector>
#include "arith/scalable_expression.h"
#include "tir/analysis/check_contains.h"
namespace tvm {
namespace tl {
using namespace tir;
/*!
* \brief Perform data type legalization on the given BufferLoadNode pointer.
* Equivalent to BufferLoadNode::LegalizeDType, but operates on a raw pointer.
* \param n A pointer to a writable BufferLoadNode.
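*
* For example, loading from a "float32" buffer with a final index of dtype
* "int32x4" legalizes the load's dtype to "float32x4"; with a scalable
* "int32xvscalex4" index it becomes "float32xvscalex4".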
*/
static void LegalizeBufferLoadDType(BufferLoadNode *n) {
// Check that all indices except the last one have a scalar dtype
for (int i = 0; i < static_cast<int>(n->indices.size()) - 1; i++) {
ICHECK(n->indices[i].dtype().is_scalar())
<< "Only the last index of a buffer access may be a vector type.";
}
// If there are no indices, set the dtype to the buffer's dtype
if (n->indices.empty()) {
n->dtype = n->buffer->dtype;
} else {
auto index_dtype = n->indices.back().dtype();
bool is_buffer_dtype_scalable = n->buffer->dtype.is_scalable_vector();
bool is_index_scalable = index_dtype.is_scalable_vector();
// Do not allow both index dtype and buffer dtype to be scalable vectors
ICHECK(!(is_index_scalable && is_buffer_dtype_scalable))
<< "Index dtype and buffer dtype cannot both be scalable.";
if (is_index_scalable) {
// Index is a scalable vector, while the buffer is not
n->dtype = n->buffer->dtype.with_scalable_vscale_factor(
index_dtype.vscale_factor() * n->buffer->dtype.lanes());
} else if (is_buffer_dtype_scalable) {
// The buffer is a scalable vector, while the index is not
n->dtype = n->buffer->dtype.with_scalable_vscale_factor(
n->buffer->dtype.vscale_factor() * index_dtype.lanes());
} else {
// Neither side is a scalable vector, multiply lanes
n->dtype = n->buffer->dtype.with_lanes(index_dtype.lanes() *
n->buffer->dtype.lanes());
}
}
}
inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
if (is_scalable) {
return Mul(Call(DataType::Int(32), builtin::vscale(), {}),
lanes_or_vscale_factor);
} else {
return lanes_or_vscale_factor;
}
}
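// Broadcast expression e to the requested number of lanes. For example,
// BroadcastTo(x, 4, /*is_scalable=*/false) on a scalar x yields
// Broadcast(x, 4), while is_scalable=true yields 4 * vscale lanes.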
inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
// Check if e is already in the expected form
if (e.dtype().get_lanes_or_vscale_factor() == lanes &&
e.dtype().is_scalable_vector() == is_scalable)
return e;
if (const BroadcastNode *op = e.as<BroadcastNode>()) {
ICHECK(op->dtype.is_scalable_vector() == is_scalable)
<< "Can't broadcast between scalable and fixed length vectors.";
int e_lanes = op->dtype.get_lanes_or_vscale_factor();
if (lanes % e_lanes == 0) {
return Broadcast(op->value, CreateNewLanes(is_scalable, lanes));
}
}
ICHECK(e.dtype().is_scalar())
<< "Cannot broadcast lanes=" << e.dtype().get_lanes_or_vscale_factor()
<< " is_scalable=" << e.dtype().is_scalable_vector() << " to " << lanes;
return Broadcast(e, CreateNewLanes(is_scalable, lanes));
}
// Rewrite vectorized allocation access
// This is necessary so that each vector lane gets its own
// workspace. Originates from Halide's loop vectorizer.
//
// s[i] = s[i * lanes + var]
//
// The same principle applies when using one thread to simulate
// multiple contexts.
//
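// For example, with var = v and var_lanes = 4, an access to a buffer of
// shape [16] is rewritten as
//
//   s[i]  ->  s[i * 4 + v]
//
// and the buffer shape is extended to [64].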
class VecAllocAccess : public StmtExprMutator {
public:
VecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return UpdateBufferAccess(load);
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
return UpdateBufferAccess(store);
}
private:
template <typename Node> Node UpdateBufferAccess(Node node) {
// Only update the buffer that's being replaced.
if (node->buffer->data.get() != buf_) {
return node;
}
// Find/make a Buffer object with the correct updated shape.
Buffer buf;
auto it = buffer_map_.find(node->buffer.get());
if (it != buffer_map_.end()) {
buf = it->second;
} else {
// Extend the least significant dimension by a factor of
// var_lanes_. Typically, this will be a 1-d index into a flat
// memory space.
Array<PrimExpr> shape = node->buffer->shape;
shape.Set(shape.size() - 1,
analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_));
// TODO(Lunderberg): Move this pass to be prior to
// StorageFlatten/FlattenBuffer, implement by appending a
// dimension to the buffer. Since it is currently after the
// flattening, the strides are not technically necessary, but
// are updated for consistency.
// Update strides if defined. Iterate over the original buffer's
// strides; all but the innermost stride are scaled by var_lanes_.
Array<PrimExpr> strides;
for (size_t i = 0; i < node->buffer->strides.size(); i++) {
PrimExpr stride = node->buffer->strides[i];
if (i != node->buffer->strides.size() - 1) {
stride *= var_lanes_;
}
strides.push_back(analyzer_.Simplify(stride));
}
// Copy everything into the new buffer.
buf = node->buffer;
auto buf_writer = buf.CopyOnWrite();
buf_writer->shape = shape;
buf_writer->strides = strides;
buffer_map_[buf.get()] = buf;
}
// Extend the last index by the number of lanes in the vectorized
// variable.
Array<PrimExpr> indices = node->indices;
indices.Set(
indices.size() - 1,
analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_));
auto writer = node.CopyOnWrite();
writer->buffer = buf;
writer->indices = indices;
return node;
}
// buffer var
const VarNode *buf_;
// Updated buffer objects.
std::unordered_map<const BufferNode *, Buffer> buffer_map_;
// variable to be replaced
Var var_;
// the lanes.
PrimExpr var_lanes_;
// Analyzer for simplifications
arith::Analyzer analyzer_;
};
// We use ExprFunctor directly instead of StmtExprMutator because the
// transformation can change the dtype of an expression, and the existing
// ExprMutator rewriting rules may not be well defined for such changes.
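// For example, vectorizing over var i with 4 lanes rewrites the scalar
// load A[i] (dtype float32) into A[Ramp(0, 1, 4)] (dtype float32x4).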
class TLVectorizer : public StmtMutator,
public ExprFunctor<PrimExpr(const PrimExpr &)> {
public:
using ExprFunctor::VisitExpr;
using StmtMutator::operator();
TLVectorizer(Var var, PrimExpr var_lanes) : var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
}
Stmt VisitStmt(const Stmt &stmt) final {
ICHECK(!need_scalarize_);
Stmt ret = StmtMutator::VisitStmt(stmt);
if (need_scalarize_) {
need_scalarize_ = false;
return Scalarize(stmt);
} else {
return ret;
}
}
PrimExpr VisitExpr(const PrimExpr &e) final {
return ExprFunctor::VisitExpr(e);
}
PrimExpr VisitExpr_(const AddNode *op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a + b; });
}
PrimExpr VisitExpr_(const SubNode *op) final {
return AddSubVec(op, [](PrimExpr a, PrimExpr b) { return a - b; });
}
PrimExpr VisitExpr_(const MulNode *op) final {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
if (is_vec_a && is_vec_b) {
// Let's not multiply scalable and fixed length vectors
ICHECK(a.dtype().is_scalable_vector() == b.dtype().is_scalable_vector())
<< "Fixed length and scalable vectors can't be mixed in "
"multiplication.";
}
if (is_vec_a || is_vec_b) {
const RampNode *b_ramp = b.as<RampNode>();
const RampNode *a_ramp = a.as<RampNode>();
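// Fold a ramp scaled by a provably positive scalar into a new ramp:
//   Ramp(base, stride, lanes) * b  =>  Ramp(base * b, stride * b, lanes)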
if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) {
PrimExpr lanes = a_ramp->lanes;
return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes);
}
if (b_ramp && a.dtype().is_scalar() && analyzer_.CanProve(a > 0)) {
PrimExpr lanes = b_ramp->lanes;
return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes);
}
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
int max_lanes = std::max(a_lanes, b_lanes);
bool is_scalable =
a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
return Mul(BroadcastTo(a, max_lanes, is_scalable),
BroadcastTo(b, max_lanes, is_scalable));
}
}
return BinaryVec<Mul>(op);
}
PrimExpr VisitExpr_(const DivNode *op) final { return BinaryVec<Div>(op); }
PrimExpr VisitExpr_(const ModNode *op) final { return BinaryVec<Mod>(op); }
PrimExpr VisitExpr_(const FloorDivNode *op) final {
return BinaryVec<FloorDiv>(op);
}
PrimExpr VisitExpr_(const FloorModNode *op) final {
return BinaryVec<FloorMod>(op);
}
PrimExpr VisitExpr_(const MinNode *op) final { return BinaryVec<Min>(op); }
PrimExpr VisitExpr_(const MaxNode *op) final { return BinaryVec<Max>(op); }
PrimExpr VisitExpr_(const EQNode *op) final { return BinaryVec<EQ>(op); }
PrimExpr VisitExpr_(const NENode *op) final { return BinaryVec<NE>(op); }
PrimExpr VisitExpr_(const LTNode *op) final { return BinaryVec<LT>(op); }
PrimExpr VisitExpr_(const LENode *op) final { return BinaryVec<LE>(op); }
PrimExpr VisitExpr_(const GTNode *op) final { return BinaryVec<GT>(op); }
PrimExpr VisitExpr_(const GENode *op) final { return BinaryVec<GE>(op); }
PrimExpr VisitExpr_(const AndNode *op) final { return BinaryVec<And>(op); }
PrimExpr VisitExpr_(const OrNode *op) final { return BinaryVec<Or>(op); }
PrimExpr VisitExpr_(const NotNode *op) final {
PrimExpr a = this->VisitExpr(op->a);
if (a.same_as(op->a)) {
return GetRef<PrimExpr>(op);
} else {
return !(a);
}
}
PrimExpr VisitExpr_(const RampNode *op) final {
PrimExpr base = this->VisitExpr(op->base);
PrimExpr stride = this->VisitExpr(op->stride);
ICHECK(!base.dtype().is_scalable_vector())
<< "Creating scalable vectors from existing vectors is not supported.";
ICHECK(!stride.dtype().is_scalable_vector())
<< "Ramp stride with scalable dtype is not supported";
if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) {
ICHECK(op->lanes->IsInstance<IntImmNode>())
<< "Vectorizing over existing scalable vectors is not supported.";
const RampNode *base_ramp = base.as<RampNode>();
int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
int base_ramp_lanes =
static_cast<int>(Downcast<IntImm>(base_ramp->lanes)->value);
if (analyzer_.CanProve(base_ramp->stride ==
stride *
make_const(stride.dtype(), base_ramp_lanes))) {
return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes);
}
}
int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
base = BroadcastTo(base, lanes, false);
stride = BroadcastTo(stride, lanes, false);
Array<PrimExpr> elems;
for (int i = 0; i < lanes; ++i) {
elems.push_back(Ramp(Shuffle::ExtractElement(base, i),
Shuffle::ExtractElement(stride, i), op->lanes));
}
return Shuffle::Concat(elems);
}
PrimExpr VisitExpr_(const BroadcastNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
if (value.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
}
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
} else {
return Broadcast(value, op->lanes);
}
}
PrimExpr VisitExpr_(const SelectNode *op) final {
PrimExpr cond = this->VisitExpr(op->condition);
PrimExpr t = this->VisitExpr(op->true_value);
PrimExpr f = this->VisitExpr(op->false_value);
if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return GetRef<PrimExpr>(op);
} else {
int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
int t_lanes = t.dtype().get_lanes_or_vscale_factor();
int f_lanes = f.dtype().get_lanes_or_vscale_factor();
int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes);
bool is_scalable = cond.dtype().is_scalable_vector() ||
t.dtype().is_scalable_vector() ||
f.dtype().is_scalable_vector();
return Select(BroadcastTo(cond, lanes, is_scalable),
BroadcastTo(t, lanes, is_scalable),
BroadcastTo(f, lanes, is_scalable));
}
}
PrimExpr VisitExpr_(const CastNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
} else {
if (value.dtype().is_scalable_vector()) {
return Cast(op->dtype.with_scalable_vscale_factor(
value.dtype().vscale_factor()),
value);
} else {
return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
}
}
}
PrimExpr VisitExpr_(const FloatImmNode *op) final {
return GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const IntImmNode *op) final {
return GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const StringImmNode *op) final {
return GetRef<PrimExpr>(op);
}
// Variable
PrimExpr VisitExpr_(const VarNode *op) final {
Var var = GetRef<Var>(op);
if (var.same_as(var_)) {
return ramp_;
}
auto it = let_binding_.find(var);
if (it != let_binding_.end()) {
return it->second;
} else {
return std::move(var);
}
}
// IfThenElse expr
PrimExpr MutateIfThenElseExpr_(const CallNode *op) {
PrimExpr cond = this->VisitExpr(op->args[0]);
if (cond.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
}
PrimExpr t = this->VisitExpr(op->args[1]);
PrimExpr f = this->VisitExpr(op->args[2]);
if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
f.same_as(op->args[2])) {
return GetRef<PrimExpr>(op);
} else {
int t_lanes = t.dtype().get_lanes_or_vscale_factor();
int f_lanes = f.dtype().get_lanes_or_vscale_factor();
int lanes = std::max(t_lanes, f_lanes);
bool is_scalable =
t.dtype().is_scalable_vector() || f.dtype().is_scalable_vector();
t = BroadcastTo(t, lanes, is_scalable);
f = BroadcastTo(f, lanes, is_scalable);
if (is_scalable) {
return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{cond, t, f});
} else {
return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
}
}
}
// Reinterpret expr
PrimExpr MutateReinterpretExpr_(const CallNode *op) {
ICHECK(op->op.same_as(builtin::reinterpret()));
PrimExpr value = this->VisitExpr(op->args[0]);
if (value.same_as(op->args[0])) {
return GetRef<PrimExpr>(op);
} else {
int lanes = value.dtype().get_lanes_or_vscale_factor();
if (value.dtype().is_scalable_vector()) {
return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{value});
} else {
return Call(op->dtype.with_lanes(lanes), op->op, {value});
}
}
}
// Call
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::if_then_else())) {
return MutateIfThenElseExpr_(op);
} else if (op->op.same_as(builtin::texture2d_load())) {
int lane = 0;
Array<PrimExpr> fcd = MutateArray({op->args.back()}, &lane);
auto new_args = op->args;
new_args.pop_back();
new_args.push_back(fcd[0]);
return Call(op->dtype.with_lanes(4), op->op, new_args);
} else if (op->op.same_as(builtin::texture2d_store())) {
int lane = 0;
// Vectorize the value to store
Array<PrimExpr> value{op->args.back()};
Array<PrimExpr> mutated_value = MutateArray(value, &lane);
Array<PrimExpr> new_args{op->args[0], op->args[1], op->args[2],
mutated_value[0]};
return Call(op->dtype.with_lanes(lane), op->op, new_args);
} else if (op->op.same_as(builtin::reinterpret())) {
return MutateReinterpretExpr_(op);
}
auto optional_op = op->op.as<Op>();
bool vectorizable = optional_op &&
op_vectorizable_.get(optional_op.value(), false) &&
!op->dtype.is_scalable_vector();
if (!vectorizable) {
// Cannot vectorize this op
Array<PrimExpr> new_args;
for (auto arg : op->args) {
auto new_arg = this->VisitExpr(arg);
if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
}
new_args.push_back(new_arg);
}
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
} else {
return Call(op->dtype, op->op, new_args);
}
} else {
int lane = 0;
Array<PrimExpr> new_args = MutateArray(op->args, &lane);
// normal code path.
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
} else {
return Call(op->dtype.with_lanes(lane), op->op, new_args);
}
}
}
// BufferLoad
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = GetRef<BufferLoad>(op);
auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index);
};
Array<PrimExpr> indices = op->indices.Map(fmutate);
if (!indices.same_as(op->indices)) {
BufferLoadNode *writer = load.CopyOnWrite();
writer->indices = indices;
// Use the local pointer-based equivalent of BufferLoadNode::LegalizeDType.
LegalizeBufferLoadDType(writer);
}
return std::move(load);
}
// Let
PrimExpr VisitExpr_(const LetNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
// Weaker SSA condition:
// a single var can be bound in multiple lets,
// but they have to bind to the same value.
// This is used to allow cases when we reuse a single let
// expression to construct a nested expr.
// (let x = 1 in x + 1) * (let x = 1 in x + 1)
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
ICHECK(deep_equal_(it->second, value))
<< "Let cannot bind the same var to two different values";
}
if (value.dtype().get_lanes_or_vscale_factor() !=
op->value.dtype().get_lanes_or_vscale_factor()) {
Var new_var(op->var->name_hint, value.dtype());
let_binding_[op->var] = new_var;
return Let(new_var, value, this->VisitExpr(op->body));
} else {
let_binding_[op->var] = op->var;
PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op);
} else {
return Let(op->var, value, body);
}
}
}
// BufferStore
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = GetRef<BufferStore>(op);
auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index);
};
Array<PrimExpr> indices = op->indices.Map(fmutate);
PrimExpr value = this->VisitExpr(op->value);
if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
ICHECK(!op->buffer->dtype.is_scalable_vector())
<< "Vectorizing over scalable buffer elements is not supported in "
"vectorizer.";
// How many lanes of indexing are present in the index and
// buffer element type, excluding the last index.
int other_index_lanes = op->buffer->dtype.lanes();
for (size_t i = 0; i < indices.size() - 1; i++) {
other_index_lanes *= indices[i].dtype().lanes();
// Only allow the last index to be scalable
ICHECK(!indices[i].dtype().is_scalable_vector())
<< "Only the last index can be scalable.";
}
// The total number of lanes of indexing, including the last index.
auto last_index_dtype = indices[indices.size() - 1].dtype();
int lanes_in_last_index = last_index_dtype.get_lanes_or_vscale_factor();
int index_lanes = other_index_lanes * lanes_in_last_index;
// The total number of lanes in this store operation. Either
// the index or the value will be broadcast out to this number
// of lanes, depending on which has more lanes.
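// Example: A[i, Ramp(0, 1, 4)] = Broadcast(v, 4) with scalar i gives
// other_index_lanes = 1, index_lanes = 4, and total_lanes = 4.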
int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor();
bool is_last_index_scalable = last_index_dtype.is_scalable_vector();
int total_lanes = std::max(index_lanes, value_dtype_lanes);
ICHECK_EQ(total_lanes % other_index_lanes, 0)
<< "When storing to buffer " << op->buffer->name
<< ", cannot produce " << total_lanes
<< " lanes of storage location by changing the last index.";
int last_index_lanes = total_lanes / other_index_lanes;
// Broadcast the last index such that the total number of index
// lanes matches the desired number.
indices.Set(indices.size() - 1,
BroadcastTo(indices[indices.size() - 1], last_index_lanes,
is_last_index_scalable));
auto writer = store.CopyOnWrite();
writer->indices = indices;
writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable);
}
return std::move(store);
}
// For
Stmt VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kVectorized) {
LOG(WARNING) << "Detected a vectorized loop inside another vectorized loop, ignoring...";
}
ICHECK(is_zero(op->min));
ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
PrimExpr extent = this->VisitExpr(op->extent);
if (extent.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op));
}
Stmt body = this->VisitStmt(op->body);
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
return For(op->loop_var, op->min, extent, op->kind, body,
op->thread_binding, op->annotations);
}
}
// IfThenElse
Stmt VisitStmt_(const IfThenElseNode *op) final {
ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op));
}
Stmt then_case = this->VisitStmt(op->then_case);
Optional<Stmt> else_case = NullOpt;
if (op->else_case) {
else_case = this->VisitStmt(op->else_case.value());
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
} else {
return IfThenElse(condition, then_case, else_case);
}
}
// While
Stmt VisitStmt_(const WhileNode *op) final {
LOG(FATAL) << "A while loop inside a vectorized loop not supported.";
}
// LetStmt
Stmt VisitStmt_(const LetStmtNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
ICHECK(!let_binding_.count(op->var))
<< "SSA violation, a single var is bound twice";
let_binding_[op->var] = value;
if (value.dtype().get_lanes_or_vscale_factor() !=
op->value.dtype().get_lanes_or_vscale_factor()) {
Var new_var(op->var->name_hint, value.dtype());
let_binding_[op->var] = new_var;
return LetStmt(new_var, value, this->VisitStmt(op->body));
} else {
let_binding_[op->var] = op->var;
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
return LetStmt(op->var, value, body);
}
}
}
// Allocate
Stmt VisitStmt_(const AllocateNode *op) final {
// Mutate the condition
PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
}
for (const auto &extent : op->extents) {
PrimExpr new_ext = this->VisitExpr(extent);
if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
}
}
return GetRef<Stmt>(op);
}
// Scalarize the statement.
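// The vectorized var_ is substituted with a fresh serial loop var, e.g.
//   stmt(var_)  ->  for (i.s = 0; i.s < var_lanes_; ++i.s) stmt(i.s)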
Stmt Scalarize(Stmt stmt) {
Var idx(var_->name_hint + ".s", var_->dtype);
stmt = Substitute(stmt, {{var_, idx}});
return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, stmt);
}
// ProducerStore
Stmt VisitStmt_(const ProducerStoreNode *op) final {
LOG(FATAL) << "ProducerProvide cannot appear in a TIR PrimFunc";
}
private:
// analyzer
arith::Analyzer analyzer_;
// deep equal
ExprDeepEqual deep_equal_;
// variable to be replaced
Var var_;
// the lanes.
PrimExpr var_lanes_;
// ramp representing the var.
PrimExpr ramp_;
// flag to mark requirement of scalarization.
bool need_scalarize_{false};
// Let binding
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
// vectorizable property
OpAttrMap<TVectorizable> op_vectorizable_ =
Op::GetAttrMap<TVectorizable>("TVectorizable");
// Mutate an array of expressions, broadcasting them to a common lane count.
// When finished, *p_lanes holds the resulting lane requirement.
Array<PrimExpr> MutateArray(Array<PrimExpr> arr, int *p_lanes) {
if (arr.size() == 0)
return arr;
int &lanes = *p_lanes;
bool changed = false;
std::vector<PrimExpr> new_arr(arr.size());
for (size_t i = 0; i < arr.size(); i++) {
PrimExpr old_elem = arr[i];
PrimExpr new_elem = this->VisitExpr(old_elem);
if (!new_elem.same_as(old_elem))
changed = true;
new_arr[i] = new_elem;
lanes = std::max(lanes, new_elem.dtype().lanes());
}
for (size_t i = 0; i < arr.size(); ++i) {
if (new_arr[i].dtype().lanes() != lanes) {
new_arr[i] = BroadcastTo(new_arr[i], lanes, false);
changed = true;
}
}
if (!changed)
return arr;
return Array<PrimExpr>(new_arr);
}
template <typename TOp, typename T> PrimExpr BinaryVec(const T *op) {
static_assert(std::is_same<typename TOp::ContainerType, T>::value,
"constraint");
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
int lanes = std::max(a_lanes, b_lanes);
bool is_scalable =
a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
return TOp(BroadcastTo(a, lanes, is_scalable),
BroadcastTo(b, lanes, is_scalable));
}
}
template <typename T, typename FCompute>
PrimExpr AddSubVec(const T *op, FCompute fcompute) {
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
int lanes = std::max(a_lanes, b_lanes);
if (lanes != 1) {
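// A scalar combined with a ramp keeps the ramp form; for example, for Add:
//   a + Ramp(base, stride, lanes)  =>  Ramp(a + base, stride, lanes)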
const RampNode *b_ramp = b.as<RampNode>();
const RampNode *a_ramp = a.as<RampNode>();
if (a.dtype().is_scalar() && b_ramp) {
return Ramp(
fcompute(a, b_ramp->base),
fcompute(make_zero(b_ramp->stride.dtype()), b_ramp->stride),
b_ramp->lanes);
}
if (b.dtype().is_scalar() && a_ramp) {
return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
}
}
bool is_scalable =
a.dtype().is_scalable_vector() || b.dtype().is_scalable_vector();
return fcompute(BroadcastTo(a, lanes, is_scalable),
BroadcastTo(b, lanes, is_scalable));
}
}
};
class LoopVectorizer : public StmtMutator {
public:
Stmt VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kVectorized) {
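// Only constant extents can be vectorized directly; a symbolic extent is
// accepted only if it is a vscale expression on a target with SVE.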
auto *extent_as_int = op->extent.as<IntImmNode>();
if (!extent_as_int || extent_as_int->value < 1) {
bool is_scalable_expr =
CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
ICHECK(is_scalable_expr && arith::TargetHasSVE())
<< "Failed to vectorize loop with extent " << op->extent
<< " for target " << Target::Current();
}
ICHECK(is_zero(op->min));
return TLVectorizer(op->loop_var, op->extent)(op->body);
} else {
return StmtMutator::VisitStmt_(op);
}
}
};
class VectorizeSkipper : public StmtMutator {
public:
Stmt VisitStmt_(const ForNode *op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
if (op->kind == ForKind::kVectorized) {
return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body);
} else {
return stmt;
}
}
};
Stmt SkipVectorize(Stmt stmt) { return VectorizeSkipper()(std::move(stmt)); }
tvm::transform::Pass VectorizeLoop(bool enable_vectorize = true) {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto *n = f.CopyOnWrite();
if (enable_vectorize) {
n->body = tvm::tl::LoopVectorizer()(std::move(n->body));
} else {
n->body = tvm::tl::VectorizeSkipper()(std::move(n->body));
}
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.VectorizeLoop", {});
}
TVM_REGISTER_GLOBAL("tl.transform.VectorizeLoop").set_body_typed(VectorizeLoop);
} // namespace tl
} // namespace tvm
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa
import tilelang
from tilelang import tvm as tvm
import tilelang.testing
from tvm import te
from tvm.script import ir as I
from tilelang import language as T
import pytest
simple_target = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu")
sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve")
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_loop(extent, target):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((16,), "float32")):
for j in T.vectorized(0, extent):
A[j] = 1
@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((16,), "float32")):
A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent)
with tvm.target.Target(target):
mod = tilelang.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
def test_vectorize_vector():
n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32x4", name="A")
with ib.for_range(0, n) as i:
with ib.for_range(0, 4, kind="vectorize") as j:
A[j] = tvm.tir.const(1, A.dtype)
stmt = ib.get()
assert isinstance(stmt.body, tvm.tir.For)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tilelang.transform.VectorizeLoop()(mod)["main"].body
assert isinstance(stmt, tvm.tir.For)
assert not isinstance(stmt.body, tvm.tir.For)
assert len(stmt.body.indices) == 1
assert isinstance(stmt.body.indices[0], tvm.tir.Ramp)
assert isinstance(stmt.body.value, tvm.tir.Broadcast)
def test_vectorize_vector_scalable_error():
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((25,), "float32")):
for j in T.vectorized(T.vscale() * 4):
A[j * 4:j * 4 + 4] = T.Broadcast(T.float32(1), 4)
error_msg = f"Creating scalable vectors from existing vectors is not supported."
with tvm.target.Target(sve_target):
with pytest.raises(tvm.error.InternalError, match=error_msg):
tilelang.transform.VectorizeLoop()(Module)
def test_vectorize_vector_scalable_error2():
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((25,), "float32xvscalex4")):
for j in T.vectorized(4):
A[j] = T.Broadcast(T.float32(1), T.vscale() * 4)
error_msg = f"Vectorizing over scalable buffer elements is not supported in vectorizer."
with pytest.raises(tvm.error.InternalError, match=error_msg):
tilelang.transform.VectorizeLoop()(Module)
def test_vectorize_vector_scalable_error3():
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((25,), "float32")):
for j in T.vectorized(4):
A[j * T.vscale() * 4:j * T.vscale() * 4 + T.vscale() * 4] = T.Broadcast(
T.float32(1),
T.vscale() * 4)
error_msg = f"Vectorizing over existing scalable vectors is not supported."
with pytest.raises(tvm.error.InternalError, match=error_msg):
with tvm.target.Target(sve_target):
tilelang.transform.VectorizeLoop()(Module)
def test_vectorize_vector_scalable_error4():
@I.ir_module
class Module:
@T.prim_func(private=True)
def main(A: T.Buffer((25,), "float32")):
for j in T.vectorized(T.vscale() * 4):
A[j * T.vscale() * 4:j * T.vscale() * 4 + T.vscale() * 4] = T.Broadcast(
T.float32(1),
T.vscale() * 4)
error_msg = f"Creating scalable vectors from existing vectors is not supported."
with pytest.raises(tvm.error.InternalError, match=error_msg):
with tvm.target.Target(sve_target):
tilelang.transform.VectorizeLoop()(Module)
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_with_if(extent, target):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32):
for i in T.vectorized(extent):
if x < n:
A[i] = A[i] + T.float32(1)
else:
if i < n:
A[i] = T.float32(2)
@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32):
if x < n:
A[T.Ramp(0, 1,
extent)] = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent)
else:
for i_s in range(extent):
if i_s < n:
A[i_s] = T.float32(2)
with tvm.target.Target(target):
mod = tilelang.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
def test_vectorize_with_if_cond_int64():
m = te.size_var("m", dtype="int64")
A = te.placeholder((m,), name="A", dtype="float32")
B = te.compute((m,), lambda i: te.if_then_else(i < 2, A[i], A[i] * 2), name="B")
s = te.create_schedule(B.op)
x, y = s[B].split(B.op.axis[0], factor=4)
s[B].vectorize(y)
f = tvm.build(s, [A, B], "llvm")
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_let(extent, target):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32")):
for i in T.vectorized(extent):
v = A[i] + T.float32(1)
A[i] = v + T.float32(2)
@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32")):
v = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent)
A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent)
with tvm.target.Target(target):
mod = tilelang.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)])
def test_vectorize_with_le_cond(extent, target):
n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, extent, kind="vectorize") as i:
with ib.if_scope(i <= n):
A[i] = A[i] + 1
stmt = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
with tvm.target.Target(target):
stmt = tilelang.transform.VectorizeLoop()(mod)["main"].body
# Check that the loop wasn't vectorised
assert isinstance(stmt, tvm.tir.For)
@pytest.mark.parametrize("extent, target", [(4, simple_target), (tvm.tir.vscale() * 4, sve_target)])
def test_vectorize_with_ge_cond(extent, target):
n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, extent, kind="vectorize") as i:
with ib.if_scope(i >= n):
A[i] = A[i] + 1
stmt = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
with tvm.target.Target(target):
stmt = tilelang.transform.VectorizeLoop()(mod)["main"].body
# Check that the loop wasn't vectorised
assert isinstance(stmt, tvm.tir.For)
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_if_then_else_scalarize(extent, target):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32")):
for i in T.vectorized(extent):
A[i] = T.if_then_else(i > 0, A[i] + T.float32(1), A[i])
@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32")):
for i_s in range(extent):
A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s])
with tvm.target.Target(target):
mod = tilelang.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_if_then_else_vector(extent, target):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), n: T.int32):
for i in range(n):
for j in T.vectorized(extent):
A[i * extent + j] = T.if_then_else(i > 0, A[i * extent + j], 0)
@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), n: T.int32):
for i in range(n):
A[T.Ramp(i * extent, 1, extent)] = T.if_then_else(i > 0,
A[T.Ramp(i * extent, 1, extent)],
T.Broadcast(0, extent))
with tvm.target.Target(target):
mod = tilelang.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
def test_vectorize_while_fail():
"""A while loop inside a vectorized loop should fail."""
n = 64
num_iter = 10
def test_ir(A, B, C):
ib = tvm.tir.ir_builder.create()
n = C.shape[0]
A = ib.buffer_ptr(A)
B = ib.buffer_ptr(B)
C = ib.buffer_ptr(C)
i = ib.allocate("int32", (1,), name="i", scope="local")
i[0] = 0
with ib.for_range(0, n) as j:
C[j] = 0.0
with ib.for_range(0, n, kind="vectorize") as j:
with ib.while_loop(i[0] < num_iter):
C[j] += A[j] + B[j]
i[0] += 1
return ib.get()
dtype = "float32"
A = te.placeholder((n,), name="A", dtype=dtype)
B = te.placeholder((n,), name="B", dtype=dtype)
C = te.extern(
(n,),
[A, B],
lambda ins, outs: test_ir(ins[0], ins[1], outs[0]),
name="while_vectorize",
dtype=dtype,
)
s = te.create_schedule(C.op)
try:
tvm.lower(s, [A, B, C], "llvm")
assert False
except tvm.error.TVMError as e:
error_msg = str(e).split("\n")[-1]
expected = "A while loop inside a vectorized loop not supported"
assert expected in error_msg
def test_vectorize_dtype_mismatch():
n = tvm.tir.IntImm("int64", 4)
A = te.compute((n,), lambda i: tvm.tir.IntImm("int64", 2**31 - 1) + i, name="A")
s = te.create_schedule(A.op)
s[A].vectorize(A.op.axis[0])
tvm.lower(s, [A], "llvm", simple_mode=True)
@pytest.mark.parametrize(
"extent, vec_str, target",
[(16, "float32x16", simple_target), (T.vscale() * 8, "float32xvscalex8", sve_target)],
)
def test_vectorize_with_reinterpret(extent, vec_str, target):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
for i in T.vectorized(0, extent):
B[i] = T.reinterpret("float32", A[i])
@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1, extent)])
with tvm.target.Target(target):
mod = tilelang.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
@pytest.mark.parametrize(
"op",
(
T.Mul,
T.Add,
T.Sub,
T.Div,
T.Mod,
T.FloorDiv,
T.FloorMod,
T.Min,
T.Max,
T.EQ,
T.LT,
T.LE,
T.GE,
T.GT,
T.NE,
),
)
def test_vectorize_binary(op, extent, target):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
for j in T.vectorized(extent):
A[j] = op(T.float32(3), B[j])
@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent), B[T.Ramp(0, 1, extent)])
with tvm.target.Target(target):
mod = tilelang.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
@pytest.mark.parametrize("op", (T.And, T.Or))
def test_vectorize_logical(op, extent, target):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")):
for j in T.vectorized(extent):
A[j] = op(T.bool(1), B[j])
@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")):
A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent), B[T.Ramp(0, 1, extent)])
with tvm.target.Target(target):
mod = tilelang.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
@pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)])
def test_vectorize_select(extent, target):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
for j in T.vectorized(extent):
A[j] = T.Select(T.bool(True), A[j], B[j])
@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
A[T.Ramp(0, 1, extent)] = T.Select(
T.Broadcast(T.bool(True), extent),
A[T.Ramp(0, 1, extent)],
B[T.Ramp(0, 1, extent)],
)
with tvm.target.Target(target):
mod = tilelang.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
@pytest.mark.parametrize(
"extent, vec_str, target",
[(4, "int32x4", simple_target), (T.vscale() * 4, "int32xvscalex4", sve_target)],
)
def test_vectorize_cast(extent, vec_str, target):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
for j in T.vectorized(extent):
A[j] = T.Cast("int32", B[j])
@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)])
with tvm.target.Target(target):
mod = tilelang.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
def test_illegal_extent():
@I.ir_module(check_well_formed=False)
class Mod:
@T.prim_func
def main(A: T.Buffer((25,), "int32")):
n = T.Var("n", dtype="int32")
for j in T.vectorized(n):
A[j] = 3
error_msg = f"Failed to vectorize loop with extent n for target \\(nullptr\\)"
with pytest.raises(tvm.error.InternalError, match=error_msg):
tilelang.transform.VectorizeLoop()(Mod)
def test_illegal_vscale_in_non_sve_compilation():
@I.ir_module
class Mod:
@T.prim_func
def main(A: T.Buffer((16,), "float32")):
for j in T.vectorized(0, 4 * T.vscale()):
A[j] = 13
msg = (f"Failed to vectorize loop with extent T.vscale\\(\\) \\* 4 for target "
f"llvm -keys=cpu -mtriple=x86_64-linux-gnu")
with tvm.target.Target(simple_target):
with pytest.raises(tvm.error.InternalError, match=msg):
tilelang.transform.VectorizeLoop()(Mod)
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-from .lower import lower  # noqa: F401
+from .lower import lower, is_device_call  # noqa: F401
@@ -144,7 +144,6 @@ def lower(
     mod = tl.transform.LegalizeSafeMemoryAccess()(mod)
     # Inject Simplify to remove the duplicated conditions
     mod = tir.transform.Simplify()(mod)
-    mod = tir.transform.VectorizeLoop()(mod)
     # which may be introduced by the LegalizeSafeMemoryAccess
     if target.arch == "sm_90":
@@ -163,6 +162,7 @@ def lower(
     mod = tir.transform.FlattenBuffer()(mod)
     mod = tir.transform.NarrowDataType(32)(mod)
     mod = tir.transform.Simplify()(mod)
+    mod = tl.transform.VectorizeLoop()(mod)
     mod = tir.transform.StorageRewrite()(mod)
     mod = tir.transform.UnrollLoop()(mod)
     mod = tir.transform.RenormalizeSplitPattern()(mod)
@@ -3,8 +3,10 @@
 """The language interface for tl programs."""
 from typing import Optional
-from .parser import *
-# from tvm.script.parser.tir import *
+# from .parser import *
+# now is fully compatible with the upstream
+# tir script
+from tvm.script.parser.tir import *
 from tilelang.layout import Layout, Fragment  # noqa: F401
 from .parallel import Parallel  # noqa: F401
 from .pipeline import Pipelined  # noqa: F401
@@ -28,7 +28,12 @@
 from tvm.ir import GlobalVar, PrimType
 from tvm.tir import Buffer, IterVar, PrimExpr, Var
 from tvm.script.ir_builder import ir as I
-from .. import ast as T
+from tvm.script.ir_builder import tir as T
+# May rewrite some register functions
+# if we use our own registration
+# from .. import ast as T
 from tvm.script.ir_builder.base import IRBuilder
 from tvm.script.ir_builder.base import IRBuilderFrame as Frame
 from tvm.script.parser._core import Parser, dispatch, doc
@@ -186,3 +186,14 @@ def AnnotateDeviceRegions():
         The result pass
     """
     return _ffi_api.AnnotateDeviceRegions()  # type: ignore
+
+
+def VectorizeLoop(enable_vectorize: bool = True):
+    """VectorizeLoop
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.VectorizeLoop(enable_vectorize)  # type: ignore