Commit 137dab67 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Typo] Remove debug print (#373)

* [Enhancement] Add variable check in GlobalMemChecker for safe memory access validation

- Introduced a check in the GlobalMemChecker to determine if the index used in memory access has any variable components, enhancing the safety of memory access validation.
- Updated the condition handling in store operations to ensure that only boolean conditions are processed, improving type safety and error handling in memory operations.

* [Refactor] Rename VecAllocAccess to TLVecAllocAccess and enhance buffer access handling

- Renamed the `VecAllocAccess` class to `TLVecAllocAccess` for clarity in its purpose.
- Improved the handling of buffer access by mutating extents and rewriting access in the body, ensuring compatibility with vectorized operations.
- Added a TODO comment to suggest moving this pass to occur before StorageFlatten/FlattenBuffer for better optimization.
- Introduced a print statement in the phase optimization process for debugging purposes.

* lint fix
parent b1e6b27f
......@@ -123,9 +123,9 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
// The same principle applies when using one thread to simulate multiple
// context.
//
class VecAllocAccess : public StmtExprMutator {
class TLVecAllocAccess : public StmtExprMutator {
public:
VecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
TLVecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
......@@ -182,16 +182,6 @@ private:
buffer_map_[buf.get()] = buf;
}
// Extend the last index by the number of lanes in the vectorized
// variable.
Array<PrimExpr> indices = node->indices;
indices.Set(
indices.size() - 1,
analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_));
auto writer = node.CopyOnWrite();
writer->buffer = buf;
writer->indices = indices;
return node;
}
......@@ -688,6 +678,8 @@ public:
return Scalarize(GetRef<Stmt>(op));
}
// Mutate the extents
Array<PrimExpr> extents;
for (const auto &extent : op->extents) {
PrimExpr new_ext = this->VisitExpr(extent);
if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
......@@ -695,8 +687,23 @@ public:
<< op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
}
}
return GetRef<Stmt>(op);
extents.push_back(new_ext);
}
// TODO(Lunderberg): Move this pass to be prior to
// StorageFlatten/FlattenBuffer. That will allow this pass to be
// implemented as adding a new buffer dimension, which is later
// flattened.
// Extend the least significant dimension by a factor of
// var_lanes_. Typically, this will be a 1-d index into a flat
// memory space.
extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_);
// Rewrite access to the buffer in the body.
Stmt body =
TLVecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
body = this->VisitStmt(body);
return Allocate(op->buffer_var, op->dtype, extents, condition, body);
}
// scalarize the statement
......
......@@ -87,7 +87,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod)
print(mod)
mod = tilelang.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment