Unverified Commit 399af087 authored by Kurisu's avatar Kurisu Committed by GitHub
Browse files

[BugFix] alloc_var init failed to handle complex expression (#1144)

* [Fix] init var with complex expression

* fix lint error
parent 60567ba3
import tilelang
import tilelang.language as T
import tilelang.testing
def test_var_assign() -> None:
@tilelang.jit(out_idx=-1)
def jit_kernel():
@T.prim_func
def test_var_assign(A: T.Tensor((2,), 'int32')):
with T.Kernel(1) as _:
a = T.alloc_var('int32', init=1)
b = T.alloc_var('int32', init=a) # b gets value of a
a = 2
d = T.alloc_var('int32', init=a) # c gets new value of a
A[0] = b
A[1] = d
print(test_var_assign)
return test_var_assign
kernel = jit_kernel()
print(kernel.get_kernel_source())
res = kernel()
assert res[0] == 1
assert res[1] == 2
if __name__ == '__main__':
tilelang.testing.main()
......@@ -15,10 +15,13 @@ with the appropriate memory scope.
"""
from __future__ import annotations
from typing import overload
from tilelang import tvm as tvm
from tvm.script import tir as T
from tvm.tir import PrimExpr
from tvm.script.parser.tir import block_attr
from tvm.tir.buffer import Buffer
from tvm.tir.expr import FloatImm, IntImm
def alloc_shared(shape, dtype, scope="shared.dyn"):
......@@ -67,6 +70,19 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
return T.alloc_buffer(shape, dtype, scope=scope)
@overload
def alloc_var(dtype: str, init: PrimExpr | int | float, scope: str = 'local.var') -> Buffer:
...
@overload
def alloc_var(dtype: str,
scope: str = 'local.var',
*,
init: PrimExpr | int | float | None = None) -> Buffer:
...
def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
"""Allocate a single-element variable buffer.
......@@ -82,7 +98,12 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
init (PrimExpr, optional): The optional initializer value. When provided,
the generated code will initialize the variable with this value instead
of defaulting to zero.
Examples:
a = T.alloc_var('int32', 1) # var with init 1
a = T.alloc_var('int32', 'local.var') # var with local.var scope
a = T.alloc_var('int32', 1, 'local.var') # var with init 1 and local.var scope
a = T.alloc_var('int32', 'local.var', init=1) # var with init 1 and local.var scope
a = T.alloc_var('int32', init=1) # var with init 1 and local.var scope
Returns:
T.Buffer: A TVM buffer object allocated as a single-element variable
"""
......@@ -113,7 +134,10 @@ def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
buffer = T.alloc_buffer([1], dtype, scope=parsed_scope)
if parsed_init is not None:
block_attr({"tl.local_var_init": {buffer.data: parsed_init}})
if isinstance(parsed_init, (int, float, IntImm, FloatImm)):
block_attr({"tl.local_var_init": {buffer.data: parsed_init}})
else:
T.buffer_store(buffer, parsed_init, 0)
return buffer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment