Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
0592834f
Unverified
Commit
0592834f
authored
Nov 06, 2025
by
Kurisu
Committed by
GitHub
Nov 06, 2025
Browse files
[Feat] Add A Pass to Handle Negative Index (#1192)
parent
777881e1
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
233 additions
and
0 deletions
+233
-0
src/transform/legalize_negative_index.cc
src/transform/legalize_negative_index.cc
+160
-0
testing/python/language/test_tilelang_language_negative_index.py
.../python/language/test_tilelang_language_negative_index.py
+60
-0
tilelang/engine/phase.py
tilelang/engine/phase.py
+2
-0
tilelang/transform/__init__.py
tilelang/transform/__init__.py
+11
-0
No files found.
src/transform/legalize_negative_index.cc
0 → 100644
View file @
0592834f
/*!
* \file legalize_negative_index.cc
* \brief Legalize negative indices in buffer load expressions.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <vector>
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
namespace
tvm
{
namespace
tl
{
using
namespace
tir
;
using
arith
::
IRVisitorWithAnalyzer
;
enum
class
IndexSignState
{
kNonNegative
,
kNegative
,
kUnknown
};
class
NegativeIndexAnalyzer
:
public
IRVisitorWithAnalyzer
{
public:
explicit
NegativeIndexAnalyzer
(
std
::
unordered_map
<
const
BufferLoadNode
*
,
std
::
vector
<
IndexSignState
>>
*
result
)
:
result_
(
result
)
{}
void
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
auto
load
=
tvm
::
ffi
::
GetRef
<
BufferLoad
>
(
op
);
std
::
vector
<
IndexSignState
>
states
;
states
.
reserve
(
op
->
indices
.
size
());
bool
needs_record
=
false
;
for
(
size_t
i
=
0
;
i
<
op
->
indices
.
size
();
++
i
)
{
PrimExpr
simplified
=
analyzer_
.
Simplify
(
op
->
indices
[
i
]);
if
(
analyzer_
.
CanProve
(
simplified
>=
0
))
{
states
.
push_back
(
IndexSignState
::
kNonNegative
);
continue
;
}
if
(
analyzer_
.
CanProve
(
simplified
<
0
))
{
states
.
push_back
(
IndexSignState
::
kNegative
);
needs_record
=
true
;
continue
;
}
states
.
push_back
(
IndexSignState
::
kUnknown
);
needs_record
=
true
;
LOG
(
WARNING
)
<<
"LegalizeNegativeIndex: cannot prove non-negative index "
<<
simplified
<<
" for buffer "
<<
load
->
buffer
->
name
<<
" (axis "
<<
i
<<
")."
;
}
if
(
needs_record
)
{
(
*
result_
)[
op
]
=
std
::
move
(
states
);
}
IRVisitorWithAnalyzer
::
VisitExpr_
(
op
);
}
private:
std
::
unordered_map
<
const
BufferLoadNode
*
,
std
::
vector
<
IndexSignState
>>
*
result_
;
};
class
NegativeIndexRewriter
:
public
arith
::
IRMutatorWithAnalyzer
{
public:
static
PrimFunc
Apply
(
PrimFunc
func
,
const
std
::
unordered_map
<
const
BufferLoadNode
*
,
std
::
vector
<
IndexSignState
>>
&
states
)
{
arith
::
Analyzer
analyzer
;
NegativeIndexRewriter
rewriter
(
&
analyzer
,
states
);
if
(
!
func
->
body
.
defined
())
{
return
func
;
}
PrimFuncNode
*
func_node
=
func
.
CopyOnWrite
();
func_node
->
body
=
rewriter
.
VisitStmt
(
func_node
->
body
);
return
func
;
}
private:
NegativeIndexRewriter
(
arith
::
Analyzer
*
analyzer
,
const
std
::
unordered_map
<
const
BufferLoadNode
*
,
std
::
vector
<
IndexSignState
>>
&
states
)
:
arith
::
IRMutatorWithAnalyzer
(
analyzer
),
states_
(
states
)
{}
PrimExpr
VisitExpr_
(
const
BufferLoadNode
*
op
)
final
{
BufferLoad
load
=
Downcast
<
BufferLoad
>
(
arith
::
IRMutatorWithAnalyzer
::
VisitExpr_
(
op
));
auto
it
=
states_
.
find
(
op
);
if
(
it
==
states_
.
end
())
{
return
load
;
}
auto
indices
=
load
->
indices
;
bool
changed
=
false
;
const
auto
&
state_vector
=
it
->
second
;
ICHECK_EQ
(
state_vector
.
size
(),
indices
.
size
())
<<
"State vector size mismatch for buffer load "
<<
load
->
buffer
->
name
;
for
(
size_t
i
=
0
;
i
<
indices
.
size
();
++
i
)
{
if
(
state_vector
[
i
]
!=
IndexSignState
::
kNegative
)
{
continue
;
}
PrimExpr
extent
=
load
->
buffer
->
shape
[
i
];
indices
.
Set
(
i
,
analyzer_
->
Simplify
(
extent
+
indices
[
i
]));
changed
=
true
;
}
if
(
!
changed
)
{
return
load
;
}
return
BufferLoad
(
load
->
buffer
,
indices
);
}
const
std
::
unordered_map
<
const
BufferLoadNode
*
,
std
::
vector
<
IndexSignState
>>
&
states_
;
};
PrimFunc
LegalizeNegativeIndex
(
PrimFunc
func
)
{
if
(
!
func
->
body
.
defined
())
{
return
func
;
}
std
::
unordered_map
<
const
BufferLoadNode
*
,
std
::
vector
<
IndexSignState
>>
states
;
NegativeIndexAnalyzer
analyzer
(
&
states
);
analyzer
(
func
->
body
);
if
(
states
.
empty
())
{
return
func
;
}
return
NegativeIndexRewriter
::
Apply
(
std
::
move
(
func
),
states
);
}
tvm
::
transform
::
Pass
LegalizeNegativeIndexPass
()
{
using
namespace
tir
::
transform
;
auto
pass_func
=
[](
PrimFunc
f
,
const
IRModule
&
,
PassContext
)
{
return
LegalizeNegativeIndex
(
std
::
move
(
f
));
};
return
CreatePrimFuncPass
(
pass_func
,
0
,
"tl.LegalizeNegativeIndex"
,
{});
}
TVM_FFI_STATIC_INIT_BLOCK
()
{
namespace
refl
=
tvm
::
ffi
::
reflection
;
refl
::
GlobalDef
().
def
(
"tl.transform.LegalizeNegativeIndex"
,
LegalizeNegativeIndexPass
);
}
}
// namespace tl
}
// namespace tvm
testing/python/language/test_tilelang_language_negative_index.py
0 → 100644
View file @
0592834f
from
tilelang
import
tvm
import
tilelang
as
tl
import
tilelang.testing
from
tvm.script
import
tir
as
T
@
T
.
prim_func
def
negative_index_before
(
A
:
T
.
Buffer
((
16
,),
"float32"
),
B
:
T
.
Buffer
((
16
,),
"float32"
)):
T
.
func_attr
({
"tir.noalias"
:
True
})
B
[
0
]
=
A
[
T
.
int32
(
-
1
)]
@
T
.
prim_func
def
negative_index_expected
(
A
:
T
.
Buffer
((
16
,),
"float32"
),
B
:
T
.
Buffer
((
16
,),
"float32"
)):
T
.
func_attr
({
"tir.noalias"
:
True
})
B
[
0
]
=
A
[
T
.
int32
(
15
)]
@
T
.
prim_func
def
negative_index_loop_before
(
A
:
T
.
Buffer
((
16
,),
"float32"
),
B
:
T
.
Buffer
((
4
,),
"float32"
)):
T
.
func_attr
({
"tir.noalias"
:
True
})
for
i
in
T
.
serial
(
4
):
B
[
i
]
=
A
[
-
i
-
1
]
@
T
.
prim_func
def
negative_index_loop_expected
(
A
:
T
.
Buffer
((
16
,),
"float32"
),
B
:
T
.
Buffer
((
4
,),
"float32"
)):
T
.
func_attr
({
"tir.noalias"
:
True
})
for
i
in
T
.
serial
(
4
):
B
[
i
]
=
A
[
15
-
i
]
@
T
.
prim_func
def
negative_index_symbolic_before
(
shift
:
T
.
int32
,
A
:
T
.
Buffer
((
16
,),
"float32"
),
B
:
T
.
Buffer
((
16
,),
"float32"
)):
T
.
func_attr
({
"tir.noalias"
:
True
})
for
i
in
T
.
serial
(
16
):
B
[
i
]
=
A
[
shift
+
i
]
def
test_legalize_negative_index_scalar
():
mod
=
tvm
.
IRModule
({
"main"
:
negative_index_before
})
transformed
=
tl
.
transform
.
LegalizeNegativeIndex
()(
mod
)
tvm
.
ir
.
assert_structural_equal
(
transformed
[
"main"
].
body
,
negative_index_expected
.
body
)
def
test_legalize_negative_index_affine_expr
():
mod
=
tvm
.
IRModule
({
"main"
:
negative_index_loop_before
})
transformed
=
tl
.
transform
.
LegalizeNegativeIndex
()(
mod
)
tvm
.
ir
.
assert_structural_equal
(
transformed
[
"main"
].
body
,
negative_index_loop_expected
.
body
)
def
test_legalize_negative_index_symbolic_passthrough
():
mod
=
tvm
.
IRModule
({
"main"
:
negative_index_symbolic_before
})
transformed
=
tl
.
transform
.
LegalizeNegativeIndex
()(
mod
)
tvm
.
ir
.
assert_structural_equal
(
transformed
[
"main"
].
body
,
negative_index_symbolic_before
.
body
)
if
__name__
==
"__main__"
:
tilelang
.
testing
.
main
()
tilelang/engine/phase.py
View file @
0592834f
...
...
@@ -96,6 +96,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod
=
tilelang
.
transform
.
LetInline
()(
mod
)
# Add wrapper for single buf store
mod
=
tilelang
.
transform
.
AddWrapperForSingleBufStore
()(
mod
)
# Normalize negative indices to canonical non-negative form
mod
=
tilelang
.
transform
.
LegalizeNegativeIndex
()(
mod
)
# Inject assumes to speedup tvm prover
mod
=
tilelang
.
transform
.
InjectAssumes
()(
mod
)
# Simplify the IR expressions
...
...
tilelang/transform/__init__.py
View file @
0592834f
...
...
@@ -80,6 +80,17 @@ def FrontendLegalize():
return
_ffi_api
.
FrontendLegalize
()
# type: ignore
def
LegalizeNegativeIndex
():
"""Legalize negative indices in buffer loads.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return
_ffi_api
.
LegalizeNegativeIndex
()
# type: ignore
def
InjectAssumes
():
"""Inject Assumes
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment