"examples/vscode:/vscode.git/clone" did not exist on "226aaef9f43c24df758df9791568d9c49f4c5a6e"
Commit 137dab67 authored by Lei Wang, committed by LeiWang1999
Browse files

[Typo] Remove debug print (#373)

* [Enhancement] Add variable check in GlobalMemChecker for safe memory access validation

- Introduced a check in the GlobalMemChecker to determine if the index used in memory access has any variable components, enhancing the safety of memory access validation.
- Updated the condition handling in store operations to ensure that only boolean conditions are processed, improving type safety and error handling in memory operations.

* [Refactor] Rename VecAllocAccess to TLVecAllocAccess and enhance buffer access handling

- Renamed the `VecAllocAccess` class to `TLVecAllocAccess` for clarity in its purpose.
- Improved the handling of buffer access by mutating extents and rewriting access in the body, ensuring compatibility with vectorized operations.
- Added a TODO comment to suggest moving this pass to occur before StorageFlatten/FlattenBuffer for better optimization.
- Introduced a print statement in the phase optimization process for debugging purposes.

* lint fix
parent b1e6b27f
...@@ -123,9 +123,9 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { ...@@ -123,9 +123,9 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
// The same principle applies when using one thread to simulate multiple // The same principle applies when using one thread to simulate multiple
// context. // context.
// //
class VecAllocAccess : public StmtExprMutator { class TLVecAllocAccess : public StmtExprMutator {
public: public:
VecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes) TLVecAllocAccess(const VarNode *buf, Var var, PrimExpr var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {} : buf_(buf), var_(var), var_lanes_(var_lanes) {}
PrimExpr VisitExpr_(const BufferLoadNode *op) final { PrimExpr VisitExpr_(const BufferLoadNode *op) final {
...@@ -182,16 +182,6 @@ private: ...@@ -182,16 +182,6 @@ private:
buffer_map_[buf.get()] = buf; buffer_map_[buf.get()] = buf;
} }
// Extend the last index by the number of lanes in the vectorized
// variable.
Array<PrimExpr> indices = node->indices;
indices.Set(
indices.size() - 1,
analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_));
auto writer = node.CopyOnWrite();
writer->buffer = buf;
writer->indices = indices;
return node; return node;
} }
...@@ -688,6 +678,8 @@ public: ...@@ -688,6 +678,8 @@ public:
return Scalarize(GetRef<Stmt>(op)); return Scalarize(GetRef<Stmt>(op));
} }
// Mutate the extents
Array<PrimExpr> extents;
for (const auto &extent : op->extents) { for (const auto &extent : op->extents) {
PrimExpr new_ext = this->VisitExpr(extent); PrimExpr new_ext = this->VisitExpr(extent);
if (new_ext.dtype().is_scalable_or_fixed_length_vector()) { if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
...@@ -695,8 +687,23 @@ public: ...@@ -695,8 +687,23 @@ public:
<< op->buffer_var->name_hint; << op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op)); return Scalarize(GetRef<Stmt>(op));
} }
extents.push_back(new_ext);
} }
return GetRef<Stmt>(op);
// TODO(Lunderberg): Move this pass to be prior to
// StorageFlatten/FlattenBuffer. That will allow this pass to be
// implemented as adding a new buffer dimension, which is later
// flattened.
// Extend the least significant dimension by a factor of
// var_lanes_. Typically, this will be a 1-d index into a flat
// memory space.
extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_);
// Rewrite access to the buffer in the body.
Stmt body =
TLVecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
body = this->VisitStmt(body);
return Allocate(op->buffer_var, op->dtype, extents, condition, body);
} }
// scalarize the statement // scalarize the statement
......
...@@ -87,7 +87,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -87,7 +87,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.AnnotateDeviceRegions()(mod) mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod) mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod) mod = tir.transform.MergeSharedMemoryAllocations()(mod)
print(mod)
mod = tilelang.transform.MakePackedAPI()(mod) mod = tilelang.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod) mod = tir.transform.LowerDeviceKernelLaunch()(mod)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment