Unverified Commit 399af087 authored by Kurisu's avatar Kurisu Committed by GitHub
Browse files

[BugFix] alloc_var init failed to handle complex expression (#1144)

* [Fix] init var with complex expression

* fix lint error
parent 60567ba3
import tilelang
import tilelang.language as T
import tilelang.testing
def test_var_assign() -> None:
@tilelang.jit(out_idx=-1)
def jit_kernel():
@T.prim_func
def test_var_assign(A: T.Tensor((2,), 'int32')):
with T.Kernel(1) as _:
a = T.alloc_var('int32', init=1)
b = T.alloc_var('int32', init=a) # b gets value of a
a = 2
d = T.alloc_var('int32', init=a) # c gets new value of a
A[0] = b
A[1] = d
print(test_var_assign)
return test_var_assign
kernel = jit_kernel()
print(kernel.get_kernel_source())
res = kernel()
assert res[0] == 1
assert res[1] == 2
if __name__ == '__main__':
tilelang.testing.main()
...@@ -15,10 +15,13 @@ with the appropriate memory scope. ...@@ -15,10 +15,13 @@ with the appropriate memory scope.
""" """
from __future__ import annotations from __future__ import annotations
from typing import overload
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.script import tir as T from tvm.script import tir as T
from tvm.tir import PrimExpr from tvm.tir import PrimExpr
from tvm.script.parser.tir import block_attr from tvm.script.parser.tir import block_attr
from tvm.tir.buffer import Buffer
from tvm.tir.expr import FloatImm, IntImm
def alloc_shared(shape, dtype, scope="shared.dyn"): def alloc_shared(shape, dtype, scope="shared.dyn"):
...@@ -67,6 +70,19 @@ def alloc_fragment(shape, dtype, scope="local.fragment"): ...@@ -67,6 +70,19 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
return T.alloc_buffer(shape, dtype, scope=scope) return T.alloc_buffer(shape, dtype, scope=scope)
@overload
def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = 'local.var') -> Buffer:
...
@overload
def alloc_var(dtype: str,
scope: str = 'local.var',
*,
init: PrimExpr | int | float | None = None) -> Buffer:
...
def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None): def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
"""Allocate a single-element variable buffer. """Allocate a single-element variable buffer.
...@@ -82,7 +98,12 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None): ...@@ -82,7 +98,12 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
init (PrimExpr, optional): The optional initializer value. When provided, init (PrimExpr, optional): The optional initializer value. When provided,
the generated code will initialize the variable with this value instead the generated code will initialize the variable with this value instead
of defaulting to zero. of defaulting to zero.
Examples:
a = T.alloc_var('int32', 1) # var with init 1
a = T.alloc_var('int32', 'local.var') # var with local.var scope
a = T.alloc_var('int32', 1, 'local.var') # var with init 1 and local.var scope
a = T.alloc_var('int32', 'local.var', init=1) # var with init 1 and local.var scope
a = T.alloc_var('int32', init=1) # var with init 1 and local.var scope
Returns: Returns:
T.Buffer: A TVM buffer object allocated as a single-element variable T.Buffer: A TVM buffer object allocated as a single-element variable
""" """
...@@ -113,7 +134,10 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None): ...@@ -113,7 +134,10 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
buffer = T.alloc_buffer([1], dtype, scope=parsed_scope) buffer = T.alloc_buffer([1], dtype, scope=parsed_scope)
if parsed_init is not None: if parsed_init is not None:
block_attr({"tl.local_var_init": {buffer.data: parsed_init}}) if isinstance(parsed_init, (int, float, IntImm, FloatImm)):
block_attr({"tl.local_var_init": {buffer.data: parsed_init}})
else:
T.buffer_store(buffer, parsed_init, 0)
return buffer return buffer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment