Unverified Commit 36a2b2f3 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Simplify index sign state handling in LegalizeNegativeIndex (#1354)

This commit refines the logic for determining the sign state of indices in the LegalizeNegativeIndex transformation. It prioritizes vector patterns, specifically Ramp and Broadcast nodes, to avoid compile-time lane queries. The handling of scalar indices is also streamlined, ensuring clearer diagnostics when non-negativity cannot be proven. These changes enhance the robustness and clarity of index handling in the transformation pass.
parent 1e92d11c
...@@ -44,52 +44,28 @@ private: ...@@ -44,52 +44,28 @@ private:
PrimExpr simplified = analyzer_.Simplify(indices[i]); PrimExpr simplified = analyzer_.Simplify(indices[i]);
IndexSignState state = IndexSignState::kUnknown; IndexSignState state = IndexSignState::kUnknown;
// Handle scalar indices with the standard analyzer // Handle vector patterns first to avoid querying lanes() on
if (simplified.dtype().lanes() == 1) { // scalable vectors (which is not allowed at compile-time).
if (analyzer_.CanProve(simplified >= 0)) if (const auto *ramp = simplified.as<RampNode>()) {
// For scalable vectors, we cannot rely on a constant lane count.
// Use sufficient (but not necessary) conditions:
// - If base >= 0 and stride >= 0, all lanes are non-negative.
// - If base < 0 and stride <= 0, all lanes are negative.
bool base_nonneg = analyzer_.CanProve(ramp->base >= 0);
bool base_neg = analyzer_.CanProve(ramp->base < 0);
bool stride_nonneg = analyzer_.CanProve(ramp->stride >= 0);
bool stride_nonpos = analyzer_.CanProve(ramp->stride <= 0);
if (base_nonneg && stride_nonneg) {
state = IndexSignState::kNonNegative; state = IndexSignState::kNonNegative;
else if (analyzer_.CanProve(simplified < 0)) } else if (base_neg && stride_nonpos) {
state = IndexSignState::kNegative; state = IndexSignState::kNegative;
else } else {
DLOG(WARNING) DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index " << "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i << simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ")."; << ", index " + indices[i]->Script() + ").";
} }
// Vector indices: try to reason about non-negativity/negativity
// Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
// lanes).
else if (const auto *ramp = simplified.as<RampNode>()) {
// Compute a safe lower/upper bound for the vector lanes
// lower_bound = base_min + min(0, stride_min) * (lanes - 1)
// upper_bound = base_max + max(0, stride_max) * (lanes - 1)
auto base_bound = analyzer_.const_int_bound(ramp->base);
auto stride_bound = analyzer_.const_int_bound(ramp->stride);
int lanes = *as_const_int(ramp->lanes);
int64_t base_min = base_bound->min_value;
int64_t base_max = base_bound->max_value;
int64_t s_min = stride_bound->min_value;
int64_t s_max = stride_bound->max_value;
// Guard against overflow is not strictly necessary here because
// bounds may be +/-inf represented by sentinel values.
int64_t lower = base_min;
if (s_min < 0)
lower += s_min * (lanes - 1);
int64_t upper = base_max;
if (s_max > 0)
upper += s_max * (lanes - 1);
if (lower >= 0)
state = IndexSignState::kNonNegative;
else if (upper < 0)
state = IndexSignState::kNegative;
else
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
} else if (const auto *broadcast = simplified.as<BroadcastNode>()) { } else if (const auto *broadcast = simplified.as<BroadcastNode>()) {
auto v = analyzer_.Simplify(broadcast->value); auto v = analyzer_.Simplify(broadcast->value);
if (analyzer_.CanProve(v >= 0)) if (analyzer_.CanProve(v >= 0))
...@@ -109,6 +85,20 @@ private: ...@@ -109,6 +85,20 @@ private:
<< simplified << " for buffer " << buffer_name << " (axis " << i << simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ")."; << ", index " + indices[i]->Script() + ").";
} }
} else {
// Assume scalar (or non-Ramp/Broadcast) index; avoid querying lanes().
// Fall back to scalar reasoning. If this expression is actually a
// vector-but-not-Ramp/Broadcast, treat as unknown to be safe.
// Try to prove scalar first; if proof fails, leave as unknown.
if (analyzer_.CanProve(simplified >= 0))
state = IndexSignState::kNonNegative;
else if (analyzer_.CanProve(simplified < 0))
state = IndexSignState::kNegative;
else
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
} }
states.push_back(state); states.push_back(state);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment