"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "b53c91d00b32b4d79a565f4f6e9c0ef39bcfa3d1"
Unverified Commit 918a21bd authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Improve handling of negative indices for ramp and broadcast node (#1207)

* [Enhancement] Improve handling of negative indices in legalize_negative_index pass

* Added logic to handle scalar and vector indices separately, enhancing the ability to determine non-negativity and negativity of indices.
* Introduced detailed logging for cases where non-negativity cannot be proven, improving debugging capabilities.
* Refactored index state determination for vector types, including support for Ramp and Broadcast nodes.

* Fix incorrect lane handling in legalize_negative_index pass by dereferencing lanes to obtain the correct integer value.

* Enhance legalize_negative_index pass by including necessary header for TIR operations. This addition supports improved functionality and maintainability of the transformation logic.
parent 4818d209
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h> #include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
...@@ -37,12 +38,84 @@ public: ...@@ -37,12 +38,84 @@ public:
for (size_t i = 0; i < op->indices.size(); ++i) { for (size_t i = 0; i < op->indices.size(); ++i) {
PrimExpr simplified = analyzer_.Simplify(op->indices[i]); PrimExpr simplified = analyzer_.Simplify(op->indices[i]);
if (analyzer_.CanProve(simplified >= 0)) {
states.push_back(IndexSignState::kNonNegative); // Handle scalar indices with the standard analyzer
if (simplified.dtype().lanes() == 1) {
if (analyzer_.CanProve(simplified >= 0)) {
states.push_back(IndexSignState::kNonNegative);
continue;
}
if (analyzer_.CanProve(simplified < 0)) {
states.push_back(IndexSignState::kNegative);
needs_record = true;
continue;
}
states.push_back(IndexSignState::kUnknown);
needs_record = true;
LOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << load->buffer->name << " (axis "
<< i << ").";
continue; continue;
} }
if (analyzer_.CanProve(simplified < 0)) { // Vector indices: try to reason about non-negativity/negativity
// Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
// lanes).
IndexSignState vec_state = IndexSignState::kUnknown;
if (const auto *ramp = simplified.as<RampNode>()) {
// Compute a safe lower/upper bound for the vector lanes
// lower_bound = base_min + min(0, stride_min) * (lanes - 1)
// upper_bound = base_max + max(0, stride_max) * (lanes - 1)
auto base_bound = analyzer_.const_int_bound(ramp->base);
auto stride_bound = analyzer_.const_int_bound(ramp->stride);
int lanes = *as_const_int(ramp->lanes);
int64_t base_min = base_bound->min_value;
int64_t base_max = base_bound->max_value;
int64_t s_min = stride_bound->min_value;
int64_t s_max = stride_bound->max_value;
// Guard against overflow is not strictly necessary here because
// bounds may be +/-inf represented by sentinel values.
int64_t lower = base_min;
if (s_min < 0)
lower += s_min * (lanes - 1);
int64_t upper = base_max;
if (s_max > 0)
upper += s_max * (lanes - 1);
if (lower >= 0) {
vec_state = IndexSignState::kNonNegative;
} else if (upper < 0) {
vec_state = IndexSignState::kNegative;
} else {
vec_state = IndexSignState::kUnknown;
}
} else if (const auto *bc = simplified.as<BroadcastNode>()) {
auto v = analyzer_.Simplify(bc->value);
if (analyzer_.CanProve(v >= 0)) {
vec_state = IndexSignState::kNonNegative;
} else if (analyzer_.CanProve(v < 0)) {
vec_state = IndexSignState::kNegative;
} else {
// Try const bound if proof unavailable
auto vb = analyzer_.const_int_bound(v);
if (vb->min_value >= 0) {
vec_state = IndexSignState::kNonNegative;
} else if (vb->max_value < 0) {
vec_state = IndexSignState::kNegative;
} else {
vec_state = IndexSignState::kUnknown;
}
}
}
if (vec_state == IndexSignState::kNonNegative) {
states.push_back(IndexSignState::kNonNegative);
continue;
}
if (vec_state == IndexSignState::kNegative) {
states.push_back(IndexSignState::kNegative); states.push_back(IndexSignState::kNegative);
needs_record = true; needs_record = true;
continue; continue;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment