Unverified Commit 65c4711f authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Resolve mixed stride dtype issue (inconsistent int32/int64 values) (#1119)



* fix int32 dtype issue

* lint fix

* lint

* lint fix

---------
Co-authored-by: default avatarZhiwen Mo <zm125@ic.ac.uk>
parent 50e789dd
...@@ -4,7 +4,7 @@ ExtraArgs: ['-v'] ...@@ -4,7 +4,7 @@ ExtraArgs: ['-v']
FormatStyle: file FormatStyle: file
UseColor: true UseColor: true
WarningsAsErrors: '*' WarningsAsErrors: '*'
ExcludeHeaderFilterRegex: '^(3rdparty|tvm)/.*$' HeaderFilterRegex: '^(?!.*(?:/|^)(3rdparty|tvm)/).*'
# NOTE: there must be no spaces before the '-', so put the comma last. # NOTE: there must be no spaces before the '-', so put the comma last.
Checks: >- Checks: >-
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file arg_binder.cc
* \brief Helper utility to match and bind arguments.
*/
#include "arg_binder.h"
#include <tvm/runtime/device_api.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <sstream>
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond,
const std::string &arg_name, std::vector<Stmt> *asserts) {
PrimExpr scond = ana->Simplify(cond);
if (is_zero(scond)) {
LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", "
<< " on argument " << arg_name;
}
if (!is_one(scond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint: " << cond;
asserts->emplace_back(AssertStmt(scond, StringImm(os.str()), Evaluate(0)));
}
}
bool ArgBinder::Bind_(const PrimExpr &arg, const PrimExpr &value,
const std::string &arg_name, bool with_lets) {
ICHECK_EQ(arg.dtype(), value.dtype()) << "arg " << arg << " value " << value;
if (const VarNode *v = arg.as<VarNode>()) {
auto it = def_map_->find(v);
if (it == def_map_->end()) {
Var v_arg = Downcast<Var>(arg);
defs_.emplace_back(v_arg);
if (with_lets) {
(*def_map_)[v] = arg;
init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0)));
} else {
(*def_map_)[v] = value;
}
return true;
} else {
BinderAddAssert(&analyzer_, it->second == value, arg_name, &asserts_);
}
} else {
BinderAddAssert(&analyzer_, arg == value, arg_name, &asserts_);
}
return false;
}
void ArgBinder::Bind(const PrimExpr &arg, const PrimExpr &value,
const std::string &arg_name, bool with_let) {
Bind_(arg, value, arg_name, with_let);
}
void ArgBinder::BindArray(const Array<PrimExpr> &arg,
const Array<PrimExpr> &value,
const std::string &arg_name) {
ICHECK_EQ(arg.size(), value.size())
<< "Argument " << arg_name << " array size mismatch";
for (size_t i = 0; i < arg.size(); ++i) {
std::ostringstream os;
os << arg_name << "[" << i << "]";
this->Bind(arg[i], value[i], os.str());
}
}
void ArgBinder::BindBuffer(const Buffer &arg, const Buffer &value,
const std::string &arg_name, bool fuzzy_match) {
ICHECK_EQ(arg.scope(), value.scope())
<< "Argument " << arg_name << " Buffer bind scope mismatch";
ICHECK_EQ(arg->dtype, value->dtype)
<< "Argument " << arg_name << " Buffer bind data type mismatch";
if (value->data_alignment % arg->data_alignment != 0) {
LOG(WARNING) << "Trying to bind buffer to another one with lower alignment "
"requirement "
<< " required_alignment=" << arg->data_alignment
<< ", provided_alignment=" << value->data_alignment;
}
if (value->elem_offset.defined()) {
// bind pointer and offset.
if (is_zero(arg->elem_offset)) {
ICHECK(is_zero(value->elem_offset))
<< "Trying to bind a Buffer with offset into one without offset "
<< " required elem_offset=" << arg->elem_offset
<< ", provided elem_offset=" << value->elem_offset;
}
this->Bind(arg->data, value->data, arg_name + ".data");
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset",
false)) {
if (arg->offset_factor > 1) {
PrimExpr offset = value->elem_offset;
PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero,
arg_name + ".elem_offset", &asserts_);
}
}
}
if (arg->shape.size() < value->shape.size()) {
ICHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
size_t diff = value->shape.size() - arg->shape.size();
for (size_t i = 0; i < diff; ++i) {
ICHECK(is_one(analyzer_.Simplify(value->shape[i])))
<< "Argument " << arg_name << " shape mismatch" << arg->shape
<< " vs " << value->shape;
}
for (size_t i = 0; i < arg->shape.size(); ++i) {
std::ostringstream os;
os << arg_name << ".shape[" << i << "]";
this->Bind(arg->shape[i], value->shape[i + diff], os.str());
}
if (!value->strides.empty()) {
ICHECK_EQ(arg->strides.size(), arg->shape.size());
ICHECK_EQ(value->strides.size(), value->shape.size());
for (size_t i = 0; i < arg->strides.size(); ++i) {
std::ostringstream os;
os << arg_name << ".strides[" << i << "]";
this->Bind(arg->strides[i], value->strides[i + diff], os.str());
}
}
} else {
this->BindArray(arg->shape, value->shape, arg_name + ".shape");
this->BindArray(arg->strides, value->strides, arg_name + ".strides");
}
}
inline PrimExpr TVMArrayGet(DataType t, Var arr,
builtin::TVMStructFieldKind kind) {
return TVMStructGet(t, arr, 0, kind);
}
void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
const PrimExpr &device_id, const Var &handle,
const std::string &arg_name) {
const DataType tvm_shape_type = DataType::ShapeIndex();
const DataType tvm_ndim_type = DataType::Int(32);
const Stmt nop = Evaluate(0);
init_nest_.emplace_back(AssertStmt(
!Call(DataType::Bool(), builtin::isnullptr(), {handle}),
StringImm(arg_name + " is expected to have non-NULL DLTensor* pointer"),
nop));
// dimension checks
PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
// Helper functions for shape/stride name formatting
auto shape_handle_name = [&]() { return arg_name + ".shape"; };
auto stride_handle_name = [&]() { return arg_name + ".strides"; };
auto array_element_name = [&](const std::string &arr_name, size_t k) {
std::stringstream ss;
ss << arr_name << '[' << k << ']';
return ss.str();
};
auto shape_element_name = [&](size_t k) {
return array_element_name(shape_handle_name(), k);
};
auto stride_element_name = [&](size_t k) {
return array_element_name(stride_handle_name(), k);
};
PrimExpr a_ndim =
make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
std::ostringstream ndim_err_msg;
ndim_err_msg << arg_name << ".ndim is expected to equal "
<< buffer->shape.size();
auto msg = StringImm(ndim_err_msg.str());
init_nest_.emplace_back(AssertStmt(a_ndim == v_ndim, msg, nop));
// type checks
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype;
PrimExpr cond =
(TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode) ==
IntImm(DataType::UInt(8), buffer->dtype.code()) &&
TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits) ==
IntImm(DataType::UInt(8), buffer->dtype.bits()) &&
TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes) ==
IntImm(DataType::UInt(16), buffer->dtype.lanes()));
if (!(buffer->dtype == DataType::Int(1) ||
buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4))) {
auto type_msg = StringImm(type_err_msg.str());
asserts_.emplace_back(AssertStmt(cond, type_msg, nop));
}
// shape field
Buffer buf_shape =
decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())},
tvm_shape_type, shape_handle_name());
Var v_shape(shape_handle_name(), DataType::Handle());
def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
init_nest_.emplace_back(LetStmt(
buf_shape->data,
TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape), nop));
init_nest_.emplace_back(DeclBuffer(buf_shape, nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
if (buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) ||
buffer->dtype == DataType::Int(1)) {
break;
}
Bind_(buffer->shape[k],
cast(buffer->shape[k].dtype(),
BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)})),
shape_element_name(k), true);
}
// strides field
Buffer buf_strides =
decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())},
tvm_shape_type, arg_name + ".strides");
def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type));
init_nest_.emplace_back(LetStmt(
buf_strides->data,
TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop));
init_nest_.emplace_back(DeclBuffer(buf_strides, nop));
PrimExpr v_strides_is_null =
Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data});
if (buffer->strides.empty()) {
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
PrimExpr expect_stride = make_const(stype, 1);
Array<PrimExpr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
PrimExpr svalue =
cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
conds.push_back(buffer->shape[k] == 1 || expect_stride == svalue);
expect_stride = expect_stride * buffer->shape[k];
}
std::ostringstream stride_err_msg;
stride_err_msg << stride_handle_name() << ": expected to be compact array";
if (!conds.empty()) {
auto stride_msg = StringImm(stride_err_msg.str());
Stmt check =
AssertStmt(foldl([](PrimExpr a, PrimExpr b,
Span span) { return logical_and(a, b, span); },
const_true(1), conds),
stride_msg, Evaluate(0));
check = IfThenElse(Not(v_strides_is_null), check);
asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
PrimExpr stride_from_shape = make_const(buffer->DefaultIndexType(), 1);
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
DataType stride_dtype = buffer->strides[k].dtype();
PrimExpr explicit_stride =
cast(stride_dtype,
BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
PrimExpr stride_from_shape_cast = cast(stride_dtype, stride_from_shape);
PrimExpr value = tvm::if_then_else(
v_strides_is_null, stride_from_shape_cast, explicit_stride);
value = tvm::if_then_else(buffer->shape[k] == 1, make_zero(stride_dtype),
value);
Bind_(buffer->strides[k], value, stride_element_name(k), true);
PrimExpr shape_extent = cast(stride_dtype, buffer->shape[k]);
stride_from_shape =
analyzer_.Simplify(stride_from_shape_cast * shape_extent);
}
} else {
PrimExpr stride_from_shape = make_const(buffer->DefaultIndexType(), 1);
for (int k = buffer->strides.size() - 1; k >= 0; k--) {
DataType stride_dtype = buffer->strides[k].dtype();
PrimExpr explicit_stride =
cast(stride_dtype,
BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
PrimExpr shape_stride = cast(
stride_dtype, BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)}));
PrimExpr stride_from_shape_cast = cast(stride_dtype, stride_from_shape);
Bind_(buffer->strides[k],
tvm::if_then_else(v_strides_is_null, stride_from_shape_cast,
explicit_stride),
stride_element_name(k), true);
stride_from_shape =
analyzer_.Simplify(stride_from_shape_cast * shape_stride);
}
}
// Byte_offset field.
int data_bytes = GetVectorBytes(buffer->dtype);
if (const auto *const_offset = buffer->elem_offset.as<IntImmNode>()) {
Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes),
TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
arg_name + ".byte_offset", true);
} else {
if (Bind_(buffer->elem_offset,
cast(buffer->elem_offset.dtype(),
(TVMArrayGet(DataType::UInt(64), handle,
builtin::kArrByteOffset) /
make_const(DataType::UInt(64), data_bytes))),
arg_name + ".elem_offset", true)) {
if (buffer->offset_factor > 1) {
PrimExpr offset = buffer->elem_offset;
PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero,
arg_name + ".elem_offset", &asserts_);
}
}
}
// device info.
Bind_(device_type,
TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType),
arg_name + ".device_type", true);
Bind_(device_id,
TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId),
arg_name + ".device_id", true);
// Data field. Because the validation of the data field may depend
// on a dynamic size defined by the other DLTensor* parameters, this
// field must be generated last.
if (Bind_(buffer->data,
TVMArrayGet(DataType::Handle(), handle, builtin::kArrData),
arg_name + ".data", true)) {
Var vptr(buffer->data);
// Check if the data pointer is NULL. This check is skipped for
// size-0 arrays, since CUDA provides a NULL pointer for size-zero
// allocations.
auto alloc_size = [&]() -> PrimExpr {
PrimExpr product = IntImm(buffer->DefaultIndexType(), 1);
for (const auto &dim : buffer->shape) {
product *= dim;
}
return product;
}();
asserts_.emplace_back(AssertStmt(
alloc_size == 0 ||
!Call(DataType::Bool(), builtin::isnullptr(), {vptr}),
StringImm(arg_name + " is expected to have non-NULL data pointer"),
nop));
def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
// mark alignment of external bufs
init_nest_.emplace_back(
AttrStmt(vptr, tir::attr::storage_alignment,
IntImm(DataType::Int(32), buffer->data_alignment), nop));
}
}
} // namespace tl
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file arg_binder.h
* \brief Helper utility to match and bind arguments.
*/
#ifndef TVM_TL_TRANSFORM_ARG_BINDER_H_
#define TVM_TL_TRANSFORM_ARG_BINDER_H_
#include <tvm/arith/analyzer.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
namespace tl {
using namespace tir;
/*!
* \brief Helper utility to generate match and bind of arguments.
*
* \note There is many places in TVM IR where we need argument bindings.
*
* Consider a function f(tA(shape=var(n)), tB(shape=3), tC(shape=(n+2)).
* Here n is a undefined variable that is decided by the outside, tB imposes
* a constraint such that it can only take tensor with shape 3, tC imposes
* another constraint that it's shape must equals n + 2.
* So if we call it with f(bufferA, bufferB, bufferC), we need to generate
* the following binding sequence:
* - define n = bufferA.shape[0]
* - assert bufferB.shape[0] == 3
* - assert bufferB.shape[1] == n + 3
*
* In general, this is a constraint solving problem. We have simplified
* assumption over the binding declaration, such that we require the variable
* occurred in constraint must be declared in argument list. So it is illegal to
* have signature f(tA(shape=(n+3))) without any argument variable corresponds
* to n, even though it is already enough to derive n from the input argument.
*/
class ArgBinder {
public:
/*!
* \brief Constructor
* \param def_map A definition map that contains definition of known
* variables. ArgBinder will update this def_map when adding new definitions.
*/
explicit ArgBinder(std::unordered_map<const VarNode *, PrimExpr> *def_map)
: def_map_(def_map) {}
/*!
* \brief Try to bind arg to value, generate constraint if necessary.
* \param arg The argument to be binded.
* \param value The target expression value
* \param arg_name argument name.
* \param with_let Whether add lets during bind
*/
void Bind(const PrimExpr &arg, const PrimExpr &value,
const std::string &arg_name, bool with_let = false);
/*!
* \brief Bind array to array
* \param arg The argument to be binded.
* \param value The target expression value
* \param arg_name argument name.
*/
void BindArray(const Array<PrimExpr> &arg, const Array<PrimExpr> &value,
const std::string &arg_name);
/*!
* \brief Bind symbolic buffer to another symbolic buffer
* \param arg The argument to be binded.
* \param value The target expression value
* \param arg_name argument name.
* \param fuzzy_match If enabled, we allow value's dimension to be smaller
* than arg, as long as arg's higher dimensions are of 1.
*/
void BindBuffer(const Buffer &arg, const Buffer &value,
const std::string &arg_name, bool fuzzy_match);
/*!
* \brief Bind symbolic buffer to a DLTensor handle.
* \param buffer The argument buffer to be binded.
* \param device_type The device id to be binded.
* \param device_id The device id to be binded.
* \param handle The DLTensor handle.
* \param arg_name argument name.
*/
void BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
const PrimExpr &device_id, const Var &handle,
const std::string &arg_name);
/*! \return The defs generated in binding. */
const std::vector<Var> &defs() const { return defs_; }
/*! \return The asserts generated in binding
*
* This contains statements that assert the correct value has been
* bound. For example, `binder.Bind(var, expr_1)` will produce an
* entry mapping `var` to `expr_1` in the `binder.defs()`. If
* `binder.Bind(var, expr_2)` is called later, then this will
* produce an assert statemtn that `expr_1 == expr_2`.
*
* Note: Some assert statements produced by BindDLTensor are located
* in `binder.init_nest()`, not within `binder.asserts()`. This is
* deliberate, as some values may require checks prior to
* initialization. (e.g. Intializing `m = dl_tensor->shape[3]`
* requires first asserting that `3 < dl_tensor->ndim`.)
*/
const std::vector<Stmt> &asserts() const { return asserts_; }
/*!
* \brief Initialization nest generated
*
* This contains both variable bindings and any assert statements
* that are required in order to safely produce those variable
* bindings.
*
* \note Variable bindings may be implemented either as a `LetStmt`
* that defines the variable, or as a variable replacement. Any
* bindings implemented as a `LetStmt` will be in the
* initialization list. Any bindings implemented as a variable
* replacement will be stored in the `var_def` map.
*
* A `tir::LetStmt` is usually generated when binding to a
* `DLTensor`. This requires loading values from memory, which
* should only be performed once. If the binding to a
* `DLTensor` were implemented as a variable replacement, it
* would load values from memory once for each usage of the
* variable.
*
* \return The initialization nest generated during binding.
*/
const std::vector<Stmt> &init_nest() const { return init_nest_; }
/*! \return Handle data type of the data */
const Map<Var, PrimExpr> &def_handle_dtype() const {
return def_handle_dtype_;
}
private:
// Internal bind function
bool Bind_(const PrimExpr &arg, const PrimExpr &value,
const std::string &arg_name, bool with_lets);
/*! \brief The definition map, can be uses to substitute */
std::unordered_map<const VarNode *, PrimExpr> *def_map_;
/*! \brief defs generated in the current binder */
std::vector<Var> defs_;
/*! \brief Initialize nest */
std::vector<Stmt> init_nest_;
/*! \brief handle data type in the defintiions */
Map<Var, PrimExpr> def_handle_dtype_;
/*! \brief asserts generated */
std::vector<Stmt> asserts_;
/*! \brief internal analyzer. */
arith::Analyzer analyzer_;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_TRANSFORM_ARG_BINDER_H_
...@@ -262,24 +262,32 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, ...@@ -262,24 +262,32 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
return true; return true;
// Extent must be divisible // Extent must be divisible
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size), PrimExpr target_size_for_iter =
make_const(iter_var_size.dtype(), target_vectorized_size);
PrimExpr target_size_for_expr =
make_const(expr.dtype(), target_vectorized_size);
PrimExpr target_size_for_var =
make_const(var.dtype(), target_vectorized_size);
PrimExpr zero = make_const(var.dtype(), 0);
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter),
0)) 0))
return false; return false;
// The base offset must be divisible // The base offset must be divisible
if (!analyzer->CanProveEqual( if (!analyzer->CanProveEqual(
FloorMod(Substitute(expr, {{var, 0}}), target_vectorized_size), 0)) { FloorMod(Substitute(expr, {{var, zero}}), target_size_for_expr), 0)) {
return false; return false;
} }
// Bind thread range // Bind thread range
Var v0("v0"), v1("v1"); Var v0("v0", var.dtype()), v1("v1", var.dtype());
analyzer->Bind(v0, Range(0, target_vectorized_size)); analyzer->Bind(v0, Range(zero, target_size_for_var));
analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv( analyzer->Bind(v1, Range(zero, analyzer->Simplify(FloorDiv(
iter_var_size, target_vectorized_size)))); iter_var_size, target_size_for_iter))));
PrimExpr expr_transformed = analyzer->Simplify( PrimExpr expr_transformed = analyzer->Simplify(
Substitute(expr, {{var, v0 + v1 * target_vectorized_size}})); Substitute(expr, {{var, v0 + v1 * target_size_for_var}}));
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size)); Vectorizer vectorizer(v0, target_size_for_var);
PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed); PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
// This simplify is necessary for thread region specified // This simplify is necessary for thread region specified
......
...@@ -36,7 +36,7 @@ ...@@ -36,7 +36,7 @@
#include <vector> #include <vector>
#include "../op/builtin.h" #include "../op/builtin.h"
#include "tir/transforms/arg_binder.h" #include "arg_binder.h"
#include "tir/transforms/ir_utils.h" #include "tir/transforms/ir_utils.h"
namespace tvm { namespace tvm {
...@@ -496,7 +496,6 @@ tvm::transform::Pass MakePackedAPI() { ...@@ -496,7 +496,6 @@ tvm::transform::Pass MakePackedAPI() {
func->body)) { func->body)) {
func.CopyOnWrite()->body = body.value(); func.CopyOnWrite()->body = body.value();
} }
func = MakePackedAPI(std::move(func)); func = MakePackedAPI(std::move(func));
if (!func.same_as(orig_func)) { if (!func.same_as(orig_func)) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment