"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "58a51b5bcf2e06a68c153d3631b14104b7a71130"
Commit 8c5b1341 authored by Lei Wang, committed by LeiWang1999
Browse files

[Bugfix] Support `T.Parallel` with local register assignment (#395)

* make it python 3.8- happy

* [Enhancement] Improve loop partitioning and vectorization logic in layout inference and loop vectorization

- Enhanced the VisitStmt_ method to support local buffer handling in parallel loops, allowing for register usage without explicit thread binding.
- Updated loop vectorization logic to simplify expressions and ensure accurate vector size calculations, improving performance and clarity in the vectorization process.

* lint fix
parent 192a3995
...@@ -529,16 +529,37 @@ private: ...@@ -529,16 +529,37 @@ private:
Stmt VisitStmt_(const ForNode *op) final { Stmt VisitStmt_(const ForNode *op) final {
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op)); For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (result_.for_map.count(GetRef<For>(op))) { if (result_.for_map.count(GetRef<For>(op))) {
auto loop_layout = result_.for_map[GetRef<For>(op)]; auto root = GetRef<For>(op);
if (!skip_thread_partition_) { // This check is a workaround to support T.Parallel for local buffers.
// If none thread bindings are provided, partition the loop // For example:
// for i in T.Parallel(1024):
// A_local[i] = A_global[i]
// Here, A_local is a register-local buffer held independently by each
// thread, so explicit thread binding is not required.
//
// We use PostOrderVisit to detect whether the buffer store targets a
// "local" buffer, which indicates register usage and justifies skipping
// thread binding.
bool is_register_store = false;
PostOrderVisit(root, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
if (store->buffer.scope() == "local") {
is_register_store = true;
}
}
});
bool parallel_loop = !is_register_store && !skip_thread_partition_;
if (parallel_loop) {
auto loop_layout = result_.for_map[root];
for_node = for_node =
PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout); PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
} }
// If none thread bindings are provided, partition the loop
for_node = VectorizeLoop(for_node); for_node = VectorizeLoop(for_node);
if (result_.predicate_map.count(GetRef<For>(op))) { if (result_.predicate_map.count(root) && parallel_loop) {
return IfThenElse(result_.predicate_map[GetRef<For>(op)], for_node); return IfThenElse(result_.predicate_map[root], for_node);
} else { } else {
return for_node; return for_node;
} }
......
...@@ -53,8 +53,6 @@ public: ...@@ -53,8 +53,6 @@ public:
int Plan(const For &node) { int Plan(const For &node) {
this->operator()(node); this->operator()(node);
// Always Enable vectorization
// if (!has_nonlocal_memory_access_) return 1;
return vector_size_; return vector_size_;
} }
...@@ -127,14 +125,12 @@ private: ...@@ -127,14 +125,12 @@ private:
} }
// so we should disable this GCD optimization // so we should disable this GCD optimization
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value); max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
auto last_dim = buffer->shape.back(); auto last_dim = buffer->shape.back();
auto mod_set = analyzer_.modular_set(last_dim); auto mod_set = analyzer_.modular_set(last_dim);
// when dynamic shape like [m, k]: coeff=1, base=0, GCD will block // when dynamic shape like [m, k]: coeff=1, base=0, GCD will block
// conditionally tail vectorize // conditionally tail vectorize
if (buffer->shape.back().as<IntImmNode>()) { if (buffer->shape.back().as<IntImmNode>()) {
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff); max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base); auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
// If gcd_base is equal to the last dimension, // If gcd_base is equal to the last dimension,
// we should analyze the second-to-last dimension // we should analyze the second-to-last dimension
...@@ -142,7 +138,6 @@ private: ...@@ -142,7 +138,6 @@ private:
if (gcd_base < Downcast<IntImm>(last_dim)->value) { if (gcd_base < Downcast<IntImm>(last_dim)->value) {
max_vector_size = gcd_base; max_vector_size = gcd_base;
} }
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_); vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
PrimExpr elem_offset = 0; PrimExpr elem_offset = 0;
...@@ -243,12 +238,13 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, ...@@ -243,12 +238,13 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
return false; return false;
Var v0("v0"), v1("v1"); Var v0("v0"), v1("v1");
analyzer->Bind(v0, Range(0, target_vectorized_size)); analyzer->Bind(v0, Range(0, target_vectorized_size));
analyzer->Bind(v1, Range(0, FloorDiv(iter_var_size, target_vectorized_size))); analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv(
iter_var_size, target_vectorized_size))));
PrimExpr expr_transformed = analyzer->Simplify( PrimExpr expr_transformed = analyzer->Simplify(
Substitute(expr, {{var, v0 + v1 * target_vectorized_size}})); Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size)); Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed); PrimExpr expr_vectorized =
analyzer->Simplify(vectorizer.VisitExpr(expr_transformed));
auto ramp_node = expr_vectorized.as<RampNode>(); auto ramp_node = expr_vectorized.as<RampNode>();
if (!ramp_node) { if (!ramp_node) {
// Broadcast value // Broadcast value
......
...@@ -72,7 +72,10 @@ class KernelParam: ...@@ -72,7 +72,10 @@ class KernelParam:
Returns: Returns:
bool: True if parameter is an unsigned integer type, False otherwise bool: True if parameter is an unsigned integer type, False otherwise
""" """
return str(self.dtype).removeprefix("torch.").startswith("uint") dtype_str = str(self.dtype)
if dtype_str.startswith("torch."):
dtype_str = dtype_str[6:]
return dtype_str.startswith("uint")
def is_float8(self) -> bool:
    """
    Check if the parameter's dtype is a float8 variant.

    Avoids str.removeprefix so the code stays compatible with
    Python < 3.9: the "torch." qualifier, when present, is stripped
    manually before inspecting the dtype name.

    Returns:
        bool: True if parameter is a float8 type, False otherwise
    """
    dtype_str = str(self.dtype)
    prefix = "torch."
    if dtype_str.startswith(prefix):
        # Use len(prefix) rather than a magic slice index so the strip
        # length cannot drift out of sync with the prefix literal.
        dtype_str = dtype_str[len(prefix):]
    return dtype_str.startswith("float8")
def is_boolean(self) -> bool:
    """
    Check if the parameter's dtype is a boolean type.

    Avoids str.removeprefix for Python < 3.9 compatibility, mirroring
    the sibling is_uint / is_float8 helpers.

    Returns:
        bool: True if parameter is a boolean type, False otherwise
    """
    dtype_str = str(self.dtype)
    prefix = "torch."
    if dtype_str.startswith(prefix):
        dtype_str = dtype_str[len(prefix):]
    # Bug fix: the previous version returned the stripped string itself
    # (`dtype_str[6:]`, a truthy str for any "torch."-prefixed dtype)
    # instead of applying the startswith("bool") check, violating the
    # declared -> bool contract and misclassifying e.g. torch.float32.
    return dtype_str.startswith("bool")
@dataclass @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment