Commit e8c2e794 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Remove redundant recursive rewrite rule for FloorDiv in RewriteSimplifier (#408)

* Update TVM submodule and enhance vectorization logic in loop_vectorize.cc

- Updated the TVM submodule to the latest commit.
- Simplified the vectorization process by ensuring that the vectorized expression is simplified after vectorization, improving expression handling.
- Added checks in loop_fusion_utils.h to prevent fusion of loops with non-power-of-2 extents, enhancing robustness in loop transformations.

* lint fix
parent 860d1e59
Subproject commit 742ed56bc08503c86b75bbd2a80e04db40e8600a Subproject commit b56a1df6f3e41c6b6da019786b88d73ee9a0d378
...@@ -112,7 +112,6 @@ private: ...@@ -112,7 +112,6 @@ private:
if (detector.HasFragmentAccess()) { if (detector.HasFragmentAccess()) {
return IRMutatorWithAnalyzer::VisitStmt_(op); return IRMutatorWithAnalyzer::VisitStmt_(op);
} }
while (true) { while (true) {
if (current->kind != ForKind::kParallel) if (current->kind != ForKind::kParallel)
break; break;
...@@ -132,6 +131,21 @@ private: ...@@ -132,6 +131,21 @@ private:
return IRMutatorWithAnalyzer::VisitStmt_(op); return IRMutatorWithAnalyzer::VisitStmt_(op);
} }
// If one of the loop has extent which is not 2^n, we do not fuse
for (auto l : loop_chain) {
PrimExpr extent = l->extent;
// If extent is not a constant integer, we cannot determine if it's power
// of 2
if (!extent.as<IntImmNode>()) {
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
int64_t value = extent.as<IntImmNode>()->value;
// Check if value is power of 2: value > 0 and only has one bit set
if (value <= 0 || (value & (value - 1)) != 0) {
return IRMutatorWithAnalyzer::VisitStmt_(op);
}
}
// At this point we have multiple nested parallel loops starting at zero // At this point we have multiple nested parallel loops starting at zero
// We will fuse them all. // We will fuse them all.
PrimExpr fused_extent = make_const(DataType::Int(32), 1); PrimExpr fused_extent = make_const(DataType::Int(32), 1);
......
...@@ -244,11 +244,11 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, ...@@ -244,11 +244,11 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size,
PrimExpr expr_simplified = analyzer->Simplify(expr_transformed); PrimExpr expr_simplified = analyzer->Simplify(expr_transformed);
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size)); Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed); PrimExpr expr_vectorized =
analyzer->Simplify(vectorizer.VisitExpr(expr_transformed));
auto ramp_node = expr_vectorized.as<RampNode>(); auto ramp_node = expr_vectorized.as<RampNode>();
if (!ramp_node) { if (!ramp_node) {
expr_vectorized = analyzer->Simplify(expr_vectorized);
// Broadcast value // Broadcast value
if (expr_vectorized.dtype().lanes() == 1) if (expr_vectorized.dtype().lanes() == 1)
return true; return true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment