Unverified Commit dd7fdb8e authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Feat] add support for passing reference in T.Var annotation (#1291)

parent bccb6485
......@@ -361,5 +361,39 @@ def test_while_loop():
assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}"
def test_var_macro():
try:
@T.macro
def macro_with_var(x: T.Var):
x = 1 # noqa: F841
@T.prim_func
def prim_call_macro():
with T.Kernel(1):
x = T.alloc_var(T.int32)
macro_with_var(x)
assert 'x[0] = 1' in prim_call_macro.script()
finally:
pass
try:
@T.macro
def macro_with_var(x: T.Var):
x = 1 # noqa: F841
@T.prim_func
def prim_call_macro():
with T.Kernel(1):
x = 1
macro_with_var(x)
raise RuntimeError("Expect to report an error, x should not be passed as T.Var")
except ValueError:
pass
if __name__ == '__main__':
tilelang.testing.main()
......@@ -140,6 +140,7 @@ class Builder(BaseBuilder):
self.frames: list[AnyFrame] = []
self.ir_builder = IRBuilder()
self.name_inside_frame: dict[str, AnyFrame] = {}
self.arg_annotations = {}
@classmethod
def current(cls) -> Self:
......@@ -155,16 +156,17 @@ class Builder(BaseBuilder):
yield
@contextmanager
def macro(self, name=None):
def macro(self, name=None, annotations=None):
if self.find_frame_idx(BoolOpFrame) is not None:
raise RuntimeError(
f"Macro `{name}` is used inside boolean expressions, "
"please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs")
save = self.name_inside_frame
save = self.name_inside_frame, self.arg_annotations
self.name_inside_frame = {}
self.arg_annotations = annotations or {}
with self.with_frame(MacroFrame()):
yield
self.name_inside_frame = save
self.name_inside_frame, self.arg_annotations = save
def get(self):
return self.ir_builder.get()
......@@ -313,32 +315,18 @@ class Builder(BaseBuilder):
self.check_continue_break()
locals = self.get_parent_locals()
orig_value = locals.get(name, None)
# annotation like tl.float32
# temporarily disable annotation based var declaration, for better pull request separation
# if callable(annot):
# annot_val = annot()
# if isinstance(annot_val, tir.Var):
# orig_value = tir.alloc_buffer((1,), dtype=annot_val.dtype, scope='local.var')
# IRBuilder.name(name, orig_value)
# if isinstance(value, EllipsisType) or value is self.empty:
# return orig_value
# elif isinstance(value, (int, float, IntImm, FloatImm)):
# tir.block_attr(
# {'tl.local_var_init': {
# orig_value.data: tvm.runtime.convert(value)
# }})
# return orig_value
# if orig_value is a local.var, we use buffer_store to modify it immutably
# however, if rvalue is also a local.var, this is a new binding,
# however, if rvalue is not a PrimExpr, such as buffer,
# we should not use buffer_store, and bind it instead
# ```py
# a = tl.alloc_var('float32') # bind var `a`
# a = tl.alloc_var('float32') # bind a new var `a_1`
# a = tl.alloc_shared((1,), T.float32) # bind a to new buffer
# b = a # get value of var `b = a_1[0]``
# c = tl.alloc_var('float32') # bind var `c`
# c = a # get and assign `c[0] = a_1[0]`
# ```
if is_var(orig_value) and not is_var(value):
if is_var(orig_value) and isinstance(value, (int, float, PrimExpr)):
tir.buffer_store(orig_value, value, 0)
return orig_value
res = self.bind_immutable(name, value)
......@@ -486,22 +474,34 @@ class Builder(BaseBuilder):
)
return self.unwrap_value(value)
def arg(self, name, value):
if self.find_frame_idx(MacroFrame) is not None:
if isinstance(value, (PrimExpr, int, float)):
def macro_arg(self, name, value):
if self.arg_annotations.get(name, None) is Var:
is_var = isinstance(value, tvm.tir.BufferLoad) and value.buffer.scope() == 'local.var'
if not is_var:
raise ValueError(
f'Argument `{name}` is expected to be a variable allocated by `T.alloc_var`, but got {value}({type(value)})'
)
return value.buffer
elif isinstance(value, (PrimExpr, int, float)):
return self.bind(name, value)
else:
return value
def prim_func_arg(self, name, value):
if isinstance(value, (Buffer, Var)):
return tir.arg(name, value)
elif value is self.empty:
raise ValueError(f'Argument `{name}` is not annotated')
# elif isinstance(value, Hashable):
# return value
else:
raise TypeError(
f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.")
def arg(self, name, value):
if self.find_frame_idx(MacroFrame) is not None:
return self.macro_arg(name, value)
else:
return self.prim_func_arg(name, value)
def override(self, name: str):
from tilelang.language import serial
if name == 'range':
......@@ -533,6 +533,7 @@ class Macro(Generic[_P, _T]):
name: str
orig_func: Callable[_P, _T]
ir_gen: IRGenerator[_P, _T]
annotations: dict[str, Any]
@property
def source(self) -> str:
......@@ -540,7 +541,7 @@ class Macro(Generic[_P, _T]):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
builder = Builder.current()
with builder.macro(self.name):
with builder.macro(self.name, self.annotations):
res = self.ir_gen.gen(builder)(*args, **kwargs)
return res
......@@ -578,7 +579,9 @@ def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]:
"""
def impl(func: Callable[_P, _T]) -> Macro[_P, _T]:
return Macro(name=func.__name__, orig_func=func, ir_gen=mutate(func))
annotations = get_type_hints(func)
return Macro(
name=func.__name__, orig_func=func, ir_gen=mutate(func), annotations=annotations)
return impl(func) if func is not None else impl
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment