"docs/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "9fde0e8e003d894003ebafa5a30160f5fb9dc2c2"
Commit e71c7a17 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Disable legacy vectorization for buffer allocation (#535)

* Refactor OptimizeForTarget function by removing redundant buffer allocation step and cleaning up code

* Removed the PlanAndUpdateBufferAllocationLocation step from the OptimizeForTarget function to streamline the optimization process.
* Cleaned up unnecessary whitespace in the function for improved readability.
* Enhanced the overall clarity and maintainability of the code.

* Refactor AllocateNode handling in vectorize_loop.cc

* Simplified the VisitStmt_ method for AllocateNode by removing the complex extent mutation logic.
* Streamlined the allocation process to directly call the base class method, enhancing code clarity and maintainability.
* Improved overall readability by eliminating unnecessary comments and code related to extent handling.

* Remove `tl_kernel.c` file, eliminating the backward kernel implementation and associated error handling functions. This cleanup enhances code maintainability by removing unused components related to the backward kernel processing.

* Add buffer allocation planning step in OptimizeForTarget function

* Introduced the PlanAndUpdateBufferAllocationLocation step to the OptimizeForTarget function, enhancing the optimization process.
* This addition improves the overall efficiency of buffer allocation during the target optimization phase, ensuring better resource management.
parent 8af5eb77
...@@ -668,6 +668,7 @@ public: ...@@ -668,6 +668,7 @@ public:
} }
} }
} }
// Allocate // Allocate
Stmt VisitStmt_(const AllocateNode *op) final { Stmt VisitStmt_(const AllocateNode *op) final {
// Mutate the condition // Mutate the condition
...@@ -678,32 +679,7 @@ public: ...@@ -678,32 +679,7 @@ public:
return Scalarize(GetRef<Stmt>(op)); return Scalarize(GetRef<Stmt>(op));
} }
// Mutate the extents return StmtMutator::VisitStmt_(op);
Array<PrimExpr> extents;
for (const auto &extent : op->extents) {
PrimExpr new_ext = this->VisitExpr(extent);
if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
}
extents.push_back(new_ext);
}
// TODO(Lunderberg): Move this pass to be prior to
// StorageFlatten/FlattenBuffer. That will allow this pass to be
// implemented as adding a new buffer dimension, which is later
// flattened.
// Extend the least significant dimension by a factor of
// var_lanes_. Typically, this will be a 1-d index into a flat
// memory space.
extents.Set(extents.size() - 1, extents[extents.size() - 1] * var_lanes_);
// Rewrite access to the buffer in the body.
Stmt body =
TLVecAllocAccess(op->buffer_var.get(), var_, var_lanes_)(op->body);
body = this->VisitStmt(body);
return Allocate(op->buffer_var, op->dtype, extents, condition, body);
} }
// scalarize the statment // scalarize the statment
......
...@@ -108,7 +108,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -108,7 +108,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tir.transform.Simplify()(mod) mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod) mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
mod = tir.transform.StorageRewrite()(mod) mod = tir.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod) mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod) mod = tir.transform.RenormalizeSplitPattern()(mod)
......
...@@ -102,7 +102,8 @@ class LibraryGenerator(object): ...@@ -102,7 +102,8 @@ class LibraryGenerator(object):
raise RuntimeError(f"Compile kernel failed because of {e}") from e raise RuntimeError(f"Compile kernel failed because of {e}") from e
if ret.returncode != 0: if ret.returncode != 0:
raise RuntimeError(f"Compilation Failed! {command}") raise RuntimeError(f"Compilation Failed! {command}"
f"\n {self.lib_code}")
self.srcpath = src.name self.srcpath = src.name
self.libpath = libpath self.libpath = libpath
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment