Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
dd7fdb8e
Unverified
Commit
dd7fdb8e
authored
Nov 20, 2025
by
Kuris
Committed by
GitHub
Nov 20, 2025
Browse files
[Feat] add support for passing reference in T.Var annotation (#1291)
parent
bccb6485
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
67 additions
and
30 deletions
+67
-30
testing/python/language/test_tilelang_language_frontend_v2.py
...ing/python/language/test_tilelang_language_frontend_v2.py
+34
-0
tilelang/language/v2/builder.py
tilelang/language/v2/builder.py
+33
-30
No files found.
testing/python/language/test_tilelang_language_frontend_v2.py
View file @
dd7fdb8e
...
@@ -361,5 +361,39 @@ def test_while_loop():
...
@@ -361,5 +361,39 @@ def test_while_loop():
assert
A
[
0
].
item
()
==
sum
(
range
(
10
)),
f
"Expected
{
sum
(
range
(
10
))
}
, but got
{
A
[
0
].
item
()
}
"
assert
A
[
0
].
item
()
==
sum
(
range
(
10
)),
f
"Expected
{
sum
(
range
(
10
))
}
, but got
{
A
[
0
].
item
()
}
"
def
test_var_macro
():
try
:
@
T
.
macro
def
macro_with_var
(
x
:
T
.
Var
):
x
=
1
# noqa: F841
@
T
.
prim_func
def
prim_call_macro
():
with
T
.
Kernel
(
1
):
x
=
T
.
alloc_var
(
T
.
int32
)
macro_with_var
(
x
)
assert
'x[0] = 1'
in
prim_call_macro
.
script
()
finally
:
pass
try
:
@
T
.
macro
def
macro_with_var
(
x
:
T
.
Var
):
x
=
1
# noqa: F841
@
T
.
prim_func
def
prim_call_macro
():
with
T
.
Kernel
(
1
):
x
=
1
macro_with_var
(
x
)
raise
RuntimeError
(
"Expect to report an error, x should not be passed as T.Var"
)
except
ValueError
:
pass
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tilelang
.
testing
.
main
()
tilelang
.
testing
.
main
()
tilelang/language/v2/builder.py
View file @
dd7fdb8e
...
@@ -140,6 +140,7 @@ class Builder(BaseBuilder):
...
@@ -140,6 +140,7 @@ class Builder(BaseBuilder):
self
.
frames
:
list
[
AnyFrame
]
=
[]
self
.
frames
:
list
[
AnyFrame
]
=
[]
self
.
ir_builder
=
IRBuilder
()
self
.
ir_builder
=
IRBuilder
()
self
.
name_inside_frame
:
dict
[
str
,
AnyFrame
]
=
{}
self
.
name_inside_frame
:
dict
[
str
,
AnyFrame
]
=
{}
self
.
arg_annotations
=
{}
@
classmethod
@
classmethod
def
current
(
cls
)
->
Self
:
def
current
(
cls
)
->
Self
:
...
@@ -155,16 +156,17 @@ class Builder(BaseBuilder):
...
@@ -155,16 +156,17 @@ class Builder(BaseBuilder):
yield
yield
@
contextmanager
@
contextmanager
def
macro
(
self
,
name
=
None
):
def
macro
(
self
,
name
=
None
,
annotations
=
None
):
if
self
.
find_frame_idx
(
BoolOpFrame
)
is
not
None
:
if
self
.
find_frame_idx
(
BoolOpFrame
)
is
not
None
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Macro `
{
name
}
` is used inside boolean expressions, "
f
"Macro `
{
name
}
` is used inside boolean expressions, "
"please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs"
)
"please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs"
)
save
=
self
.
name_inside_frame
save
=
self
.
name_inside_frame
,
self
.
arg_annotations
self
.
name_inside_frame
=
{}
self
.
name_inside_frame
=
{}
self
.
arg_annotations
=
annotations
or
{}
with
self
.
with_frame
(
MacroFrame
()):
with
self
.
with_frame
(
MacroFrame
()):
yield
yield
self
.
name_inside_frame
=
save
self
.
name_inside_frame
,
self
.
arg_annotations
=
save
def
get
(
self
):
def
get
(
self
):
return
self
.
ir_builder
.
get
()
return
self
.
ir_builder
.
get
()
...
@@ -313,32 +315,18 @@ class Builder(BaseBuilder):
...
@@ -313,32 +315,18 @@ class Builder(BaseBuilder):
self
.
check_continue_break
()
self
.
check_continue_break
()
locals
=
self
.
get_parent_locals
()
locals
=
self
.
get_parent_locals
()
orig_value
=
locals
.
get
(
name
,
None
)
orig_value
=
locals
.
get
(
name
,
None
)
# annotation like tl.float32
# temporarily disable annotation based var declaration, for better pull request separation
# if callable(annot):
# annot_val = annot()
# if isinstance(annot_val, tir.Var):
# orig_value = tir.alloc_buffer((1,), dtype=annot_val.dtype, scope='local.var')
# IRBuilder.name(name, orig_value)
# if isinstance(value, EllipsisType) or value is self.empty:
# return orig_value
# elif isinstance(value, (int, float, IntImm, FloatImm)):
# tir.block_attr(
# {'tl.local_var_init': {
# orig_value.data: tvm.runtime.convert(value)
# }})
# return orig_value
# if orig_value is a local.var, we use buffer_store to modify it immutably
# if orig_value is a local.var, we use buffer_store to modify it immutably
# however, if rvalue is
also a local.var, this is a new binding
,
# however, if rvalue is
not a PrimExpr, such as buffer
,
# we should not use buffer_store, and bind it instead
# we should not use buffer_store, and bind it instead
# ```py
# ```py
# a = tl.alloc_var('float32') # bind var `a`
# a = tl.alloc_var('float32') # bind var `a`
# a = tl.alloc_var('float32') # bind a new var `a_1`
# a = tl.alloc_var('float32') # bind a new var `a_1`
# a = tl.alloc_shared((1,), T.float32) # bind a to new buffer
# b = a # get value of var `b = a_1[0]``
# b = a # get value of var `b = a_1[0]``
# c = tl.alloc_var('float32') # bind var `c`
# c = tl.alloc_var('float32') # bind var `c`
# c = a # get and assign `c[0] = a_1[0]`
# c = a # get and assign `c[0] = a_1[0]`
# ```
# ```
if
is_var
(
orig_value
)
and
not
is_var
(
value
):
if
is_var
(
orig_value
)
and
isinstance
(
value
,
(
int
,
float
,
PrimExpr
)
):
tir
.
buffer_store
(
orig_value
,
value
,
0
)
tir
.
buffer_store
(
orig_value
,
value
,
0
)
return
orig_value
return
orig_value
res
=
self
.
bind_immutable
(
name
,
value
)
res
=
self
.
bind_immutable
(
name
,
value
)
...
@@ -486,22 +474,34 @@ class Builder(BaseBuilder):
...
@@ -486,22 +474,34 @@ class Builder(BaseBuilder):
)
)
return
self
.
unwrap_value
(
value
)
return
self
.
unwrap_value
(
value
)
def
arg
(
self
,
name
,
value
):
def
macro_arg
(
self
,
name
,
value
):
if
self
.
find_frame_idx
(
MacroFrame
)
is
not
None
:
if
self
.
arg_annotations
.
get
(
name
,
None
)
is
Var
:
if
isinstance
(
value
,
(
PrimExpr
,
int
,
float
)):
is_var
=
isinstance
(
value
,
tvm
.
tir
.
BufferLoad
)
and
value
.
buffer
.
scope
()
==
'local.var'
if
not
is_var
:
raise
ValueError
(
f
'Argument `
{
name
}
` is expected to be a variable allocated by `T.alloc_var`, but got
{
value
}
(
{
type
(
value
)
}
)'
)
return
value
.
buffer
elif
isinstance
(
value
,
(
PrimExpr
,
int
,
float
)):
return
self
.
bind
(
name
,
value
)
return
self
.
bind
(
name
,
value
)
else
:
else
:
return
value
return
value
def
prim_func_arg
(
self
,
name
,
value
):
if
isinstance
(
value
,
(
Buffer
,
Var
)):
if
isinstance
(
value
,
(
Buffer
,
Var
)):
return
tir
.
arg
(
name
,
value
)
return
tir
.
arg
(
name
,
value
)
elif
value
is
self
.
empty
:
elif
value
is
self
.
empty
:
raise
ValueError
(
f
'Argument `
{
name
}
` is not annotated'
)
raise
ValueError
(
f
'Argument `
{
name
}
` is not annotated'
)
# elif isinstance(value, Hashable):
# return value
else
:
else
:
raise
TypeError
(
raise
TypeError
(
f
"Unsupported argument type:
{
value
}
(
{
type
(
value
)
}
) for argument `
{
name
}
`."
)
f
"Unsupported argument type:
{
value
}
(
{
type
(
value
)
}
) for argument `
{
name
}
`."
)
def
arg
(
self
,
name
,
value
):
if
self
.
find_frame_idx
(
MacroFrame
)
is
not
None
:
return
self
.
macro_arg
(
name
,
value
)
else
:
return
self
.
prim_func_arg
(
name
,
value
)
def
override
(
self
,
name
:
str
):
def
override
(
self
,
name
:
str
):
from
tilelang.language
import
serial
from
tilelang.language
import
serial
if
name
==
'range'
:
if
name
==
'range'
:
...
@@ -533,6 +533,7 @@ class Macro(Generic[_P, _T]):
...
@@ -533,6 +533,7 @@ class Macro(Generic[_P, _T]):
name
:
str
name
:
str
orig_func
:
Callable
[
_P
,
_T
]
orig_func
:
Callable
[
_P
,
_T
]
ir_gen
:
IRGenerator
[
_P
,
_T
]
ir_gen
:
IRGenerator
[
_P
,
_T
]
annotations
:
dict
[
str
,
Any
]
@
property
@
property
def
source
(
self
)
->
str
:
def
source
(
self
)
->
str
:
...
@@ -540,7 +541,7 @@ class Macro(Generic[_P, _T]):
...
@@ -540,7 +541,7 @@ class Macro(Generic[_P, _T]):
def
__call__
(
self
,
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
_T
:
def
__call__
(
self
,
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
_T
:
builder
=
Builder
.
current
()
builder
=
Builder
.
current
()
with
builder
.
macro
(
self
.
name
):
with
builder
.
macro
(
self
.
name
,
self
.
annotations
):
res
=
self
.
ir_gen
.
gen
(
builder
)(
*
args
,
**
kwargs
)
res
=
self
.
ir_gen
.
gen
(
builder
)(
*
args
,
**
kwargs
)
return
res
return
res
...
@@ -578,7 +579,9 @@ def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]:
...
@@ -578,7 +579,9 @@ def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]:
"""
"""
def
impl
(
func
:
Callable
[
_P
,
_T
])
->
Macro
[
_P
,
_T
]:
def
impl
(
func
:
Callable
[
_P
,
_T
])
->
Macro
[
_P
,
_T
]:
return
Macro
(
name
=
func
.
__name__
,
orig_func
=
func
,
ir_gen
=
mutate
(
func
))
annotations
=
get_type_hints
(
func
)
return
Macro
(
name
=
func
.
__name__
,
orig_func
=
func
,
ir_gen
=
mutate
(
func
),
annotations
=
annotations
)
return
impl
(
func
)
if
func
is
not
None
else
impl
return
impl
(
func
)
if
func
is
not
None
else
impl
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment