Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
399af087
Unverified
Commit
399af087
authored
Oct 28, 2025
by
Kurisu
Committed by
GitHub
Oct 28, 2025
Browse files
[BugFix] alloc_var init failed to handle complex expression (#1144)
* [Fix] init var with complex expression * fix lint error
parent
60567ba3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
58 additions
and
2 deletions
+58
-2
testing/python/language/test_tilelang_language_var_init.py
testing/python/language/test_tilelang_language_var_init.py
+32
-0
tilelang/language/allocate.py
tilelang/language/allocate.py
+26
-2
No files found.
testing/python/language/test_tilelang_language_var_init.py
0 → 100644
View file @
399af087
import
tilelang
import
tilelang.language
as
T
import
tilelang.testing
def
test_var_assign
()
->
None
:
@
tilelang
.
jit
(
out_idx
=-
1
)
def
jit_kernel
():
@
T
.
prim_func
def
test_var_assign
(
A
:
T
.
Tensor
((
2
,),
'int32'
)):
with
T
.
Kernel
(
1
)
as
_
:
a
=
T
.
alloc_var
(
'int32'
,
init
=
1
)
b
=
T
.
alloc_var
(
'int32'
,
init
=
a
)
# b gets value of a
a
=
2
d
=
T
.
alloc_var
(
'int32'
,
init
=
a
)
# c gets new value of a
A
[
0
]
=
b
A
[
1
]
=
d
print
(
test_var_assign
)
return
test_var_assign
kernel
=
jit_kernel
()
print
(
kernel
.
get_kernel_source
())
res
=
kernel
()
assert
res
[
0
]
==
1
assert
res
[
1
]
==
2
if
__name__
==
'__main__'
:
tilelang
.
testing
.
main
()
tilelang/language/allocate.py
View file @
399af087
...
@@ -15,10 +15,13 @@ with the appropriate memory scope.
...
@@ -15,10 +15,13 @@ with the appropriate memory scope.
"""
"""
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
overload
from
tilelang
import
tvm
as
tvm
from
tilelang
import
tvm
as
tvm
from
tvm.script
import
tir
as
T
from
tvm.script
import
tir
as
T
from
tvm.tir
import
PrimExpr
from
tvm.tir
import
PrimExpr
from
tvm.script.parser.tir
import
block_attr
from
tvm.script.parser.tir
import
block_attr
from
tvm.tir.buffer
import
Buffer
from
tvm.tir.expr
import
FloatImm
,
IntImm
def
alloc_shared
(
shape
,
dtype
,
scope
=
"shared.dyn"
):
def
alloc_shared
(
shape
,
dtype
,
scope
=
"shared.dyn"
):
...
@@ -67,6 +70,19 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
...
@@ -67,6 +70,19 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
return
T
.
alloc_buffer
(
shape
,
dtype
,
scope
=
scope
)
return
T
.
alloc_buffer
(
shape
,
dtype
,
scope
=
scope
)
@
overload
def
alloc_var
(
dtype
:
str
,
init
:
PrimExpr
|
int
|
float
,
scope
:
str
=
'local.var'
)
->
Buffer
:
...
@
overload
def
alloc_var
(
dtype
:
str
,
scope
:
str
=
'local.var'
,
*
,
init
:
PrimExpr
|
int
|
float
|
None
=
None
)
->
Buffer
:
...
def
alloc_var
(
dtype
,
*
args
,
scope
=
"local.var"
,
init
:
PrimExpr
|
None
=
None
):
def
alloc_var
(
dtype
,
*
args
,
scope
=
"local.var"
,
init
:
PrimExpr
|
None
=
None
):
"""Allocate a single-element variable buffer.
"""Allocate a single-element variable buffer.
...
@@ -82,7 +98,12 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
...
@@ -82,7 +98,12 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
init (PrimExpr, optional): The optional initializer value. When provided,
init (PrimExpr, optional): The optional initializer value. When provided,
the generated code will initialize the variable with this value instead
the generated code will initialize the variable with this value instead
of defaulting to zero.
of defaulting to zero.
Examples:
a = T.alloc_var('int32', 1) # var with init 1
a = T.alloc_var('int32', 'local.var') # var with local.var scope
a = T.alloc_var('int32', 1, 'local.var') # var with init 1 and local.var scope
a = T.alloc_var('int32', 'local.var', init=1) # var with init 1 and local.var scope
a = T.alloc_var('int32', init=1) # var with init 1 and local.var scope
Returns:
Returns:
T.Buffer: A TVM buffer object allocated as a single-element variable
T.Buffer: A TVM buffer object allocated as a single-element variable
"""
"""
...
@@ -113,7 +134,10 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
...
@@ -113,7 +134,10 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
buffer
=
T
.
alloc_buffer
([
1
],
dtype
,
scope
=
parsed_scope
)
buffer
=
T
.
alloc_buffer
([
1
],
dtype
,
scope
=
parsed_scope
)
if
parsed_init
is
not
None
:
if
parsed_init
is
not
None
:
if
isinstance
(
parsed_init
,
(
int
,
float
,
IntImm
,
FloatImm
)):
block_attr
({
"tl.local_var_init"
:
{
buffer
.
data
:
parsed_init
}})
block_attr
({
"tl.local_var_init"
:
{
buffer
.
data
:
parsed_init
}})
else
:
T
.
buffer_store
(
buffer
,
parsed_init
,
0
)
return
buffer
return
buffer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment