Commit 310fea95 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Revert] Revert modifications for pass FlattenBuffer (#385)

* fix

* Update submodule TVM to latest commit and enhance FlattenBuffer pass in TileLang engine. Added boolean handling in buffer loading and improved address_of detection in flattening logic.

* lint fix
parent abaacde5
Subproject commit fbd82e919c01238a5ac78a4d5c66b3da80161255
Subproject commit 6df0b88a90cd7a931ba171b57f0cf41d4cbfa2fe
......@@ -210,8 +210,30 @@ private:
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
bool load_returns_bool = (op->dtype == DataType::Bool());
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return VisitBufferAccess(load);
load = VisitBufferAccess(load);
// Handle casts from dtype of the backing array to value's dtype.
// TODO(Lunderberg): Move the handling of boolean into a
// dedicated pass.
if (load_returns_bool && !under_address_of) {
ICHECK_EQ(load->buffer->dtype, DataType::Int(8))
<< "Expected int8 backing array for boolean tensor";
load.CopyOnWrite()->dtype = DataType::Int(8);
return tvm::cast(DataType::Bool(), load);
} else {
return std::move(load);
}
}
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::address_of())) {
under_address_of = true;
auto result = StmtExprMutator::VisitExpr_(op);
under_address_of = false;
return result;
}
return StmtExprMutator::VisitExpr_(op);
}
Array<PrimExpr> GetSimplifiedElemOffset(const Buffer &buffer,
......@@ -260,6 +282,8 @@ private:
return BufferRegion(flattened_buf, flattened_ranges);
}
/*! \brief Whether the current buffer is under address_of */
bool under_address_of = false;
/*! \brief Map of buffers being remapped. */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>
buffer_remap_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment