Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
399af087
Unverified
Commit
399af087
authored
Oct 28, 2025
by
Kurisu
Committed by
GitHub
Oct 28, 2025
Browse files
[BugFix] alloc_var init failed to handle complex expression (#1144)
* [Fix] init var with complex expression * fix lint error
parent
60567ba3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
2 deletions
+58
-2
testing/python/language/test_tilelang_language_var_init.py
testing/python/language/test_tilelang_language_var_init.py
+32
-0
tilelang/language/allocate.py
tilelang/language/allocate.py
+26
-2
No files found.
testing/python/language/test_tilelang_language_var_init.py
0 → 100644
View file @
399af087
import
tilelang
import
tilelang.language
as
T
import
tilelang.testing
def
test_var_assign
()
->
None
:
@
tilelang
.
jit
(
out_idx
=-
1
)
def
jit_kernel
():
@
T
.
prim_func
def
test_var_assign
(
A
:
T
.
Tensor
((
2
,),
'int32'
)):
with
T
.
Kernel
(
1
)
as
_
:
a
=
T
.
alloc_var
(
'int32'
,
init
=
1
)
b
=
T
.
alloc_var
(
'int32'
,
init
=
a
)
# b gets value of a
a
=
2
d
=
T
.
alloc_var
(
'int32'
,
init
=
a
)
# c gets new value of a
A
[
0
]
=
b
A
[
1
]
=
d
print
(
test_var_assign
)
return
test_var_assign
kernel
=
jit_kernel
()
print
(
kernel
.
get_kernel_source
())
res
=
kernel
()
assert
res
[
0
]
==
1
assert
res
[
1
]
==
2
if
__name__
==
'__main__'
:
tilelang
.
testing
.
main
()
tilelang/language/allocate.py
View file @
399af087
...
...
@@ -15,10 +15,13 @@ with the appropriate memory scope.
"""
from
__future__
import
annotations
from
typing
import
overload
from
tilelang
import
tvm
as
tvm
from
tvm.script
import
tir
as
T
from
tvm.tir
import
PrimExpr
from
tvm.script.parser.tir
import
block_attr
from
tvm.tir.buffer
import
Buffer
from
tvm.tir.expr
import
FloatImm
,
IntImm
def
alloc_shared
(
shape
,
dtype
,
scope
=
"shared.dyn"
):
...
...
@@ -67,6 +70,19 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
return
T
.
alloc_buffer
(
shape
,
dtype
,
scope
=
scope
)
@
overload
def
alloc_var
(
dtype
:
str
,
init
:
PrimExpr
|
int
|
float
,
scope
:
str
=
'local.var'
)
->
Buffer
:
...
@
overload
def
alloc_var
(
dtype
:
str
,
scope
:
str
=
'local.var'
,
*
,
init
:
PrimExpr
|
int
|
float
|
None
=
None
)
->
Buffer
:
...
def
alloc_var
(
dtype
,
*
args
,
scope
=
"local.var"
,
init
:
PrimExpr
|
None
=
None
):
"""Allocate a single-element variable buffer.
...
...
@@ -82,7 +98,12 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
init (PrimExpr, optional): The optional initializer value. When provided,
the generated code will initialize the variable with this value instead
of defaulting to zero.
Examples:
a = T.alloc_var('int32', 1) # var with init 1
a = T.alloc_var('int32', 'local.var') # var with local.var scope
a = T.alloc_var('int32', 1, 'local.var') # var with init 1 and local.var scope
a = T.alloc_var('int32', 'local.var', init=1) # var with init 1 and local.var scope
a = T.alloc_var('int32', init=1) # var with init 1 and local.var scope
Returns:
T.Buffer: A TVM buffer object allocated as a single-element variable
"""
...
...
@@ -113,7 +134,10 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
buffer
=
T
.
alloc_buffer
([
1
],
dtype
,
scope
=
parsed_scope
)
if
parsed_init
is
not
None
:
block_attr
({
"tl.local_var_init"
:
{
buffer
.
data
:
parsed_init
}})
if
isinstance
(
parsed_init
,
(
int
,
float
,
IntImm
,
FloatImm
)):
block_attr
({
"tl.local_var_init"
:
{
buffer
.
data
:
parsed_init
}})
else
:
T
.
buffer_store
(
buffer
,
parsed_init
,
0
)
return
buffer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment