Commit 310fea95 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Revert] Revert modifications for pass FlattenBuffer (#385)

* fix

* Update submodule TVM to latest commit and enhance FlattenBuffer pass in TileLang engine. Added boolean handling in buffer loading and improved address_of detection in flattening logic.

* lint fix
parent abaacde5
Subproject commit fbd82e919c01238a5ac78a4d5c66b3da80161255 Subproject commit 6df0b88a90cd7a931ba171b57f0cf41d4cbfa2fe
...@@ -210,8 +210,30 @@ private: ...@@ -210,8 +210,30 @@ private:
} }
PrimExpr VisitExpr_(const BufferLoadNode *op) final { PrimExpr VisitExpr_(const BufferLoadNode *op) final {
bool load_returns_bool = (op->dtype == DataType::Bool());
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return VisitBufferAccess(load); load = VisitBufferAccess(load);
// Handle casts from dtype of the backing array to value's dtype.
// TODO(Lunderberg): Move the handling of boolean into a
// dedicated pass.
if (load_returns_bool && !under_address_of) {
ICHECK_EQ(load->buffer->dtype, DataType::Int(8))
<< "Expected int8 backing array for boolean tensor";
load.CopyOnWrite()->dtype = DataType::Int(8);
return tvm::cast(DataType::Bool(), load);
} else {
return std::move(load);
}
}
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::address_of())) {
under_address_of = true;
auto result = StmtExprMutator::VisitExpr_(op);
under_address_of = false;
return result;
}
return StmtExprMutator::VisitExpr_(op);
} }
Array<PrimExpr> GetSimplifiedElemOffset(const Buffer &buffer, Array<PrimExpr> GetSimplifiedElemOffset(const Buffer &buffer,
...@@ -260,6 +282,8 @@ private: ...@@ -260,6 +282,8 @@ private:
return BufferRegion(flattened_buf, flattened_ranges); return BufferRegion(flattened_buf, flattened_ranges);
} }
/*! \brief Whether the current buffer is under address_of */
bool under_address_of = false;
/*! \brief Map of buffers being remapped. */ /*! \brief Map of buffers being remapped. */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual>
buffer_remap_; buffer_remap_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment