Unverified Commit dd7fdb8e authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Feat] add support for passing reference in T.Var annotation (#1291)

parent bccb6485
...@@ -361,5 +361,39 @@ def test_while_loop(): ...@@ -361,5 +361,39 @@ def test_while_loop():
assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}" assert A[0].item() == sum(range(10)), f"Expected {sum(range(10))}, but got {A[0].item()}"
def test_var_macro():
try:
@T.macro
def macro_with_var(x: T.Var):
x = 1 # noqa: F841
@T.prim_func
def prim_call_macro():
with T.Kernel(1):
x = T.alloc_var(T.int32)
macro_with_var(x)
assert 'x[0] = 1' in prim_call_macro.script()
finally:
pass
try:
@T.macro
def macro_with_var(x: T.Var):
x = 1 # noqa: F841
@T.prim_func
def prim_call_macro():
with T.Kernel(1):
x = 1
macro_with_var(x)
raise RuntimeError("Expect to report an error, x should not be passed as T.Var")
except ValueError:
pass
if __name__ == '__main__': if __name__ == '__main__':
tilelang.testing.main() tilelang.testing.main()
...@@ -140,6 +140,7 @@ class Builder(BaseBuilder): ...@@ -140,6 +140,7 @@ class Builder(BaseBuilder):
self.frames: list[AnyFrame] = [] self.frames: list[AnyFrame] = []
self.ir_builder = IRBuilder() self.ir_builder = IRBuilder()
self.name_inside_frame: dict[str, AnyFrame] = {} self.name_inside_frame: dict[str, AnyFrame] = {}
self.arg_annotations = {}
@classmethod @classmethod
def current(cls) -> Self: def current(cls) -> Self:
...@@ -155,16 +156,17 @@ class Builder(BaseBuilder): ...@@ -155,16 +156,17 @@ class Builder(BaseBuilder):
yield yield
@contextmanager @contextmanager
def macro(self, name=None): def macro(self, name=None, annotations=None):
if self.find_frame_idx(BoolOpFrame) is not None: if self.find_frame_idx(BoolOpFrame) is not None:
raise RuntimeError( raise RuntimeError(
f"Macro `{name}` is used inside boolean expressions, " f"Macro `{name}` is used inside boolean expressions, "
"please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs") "please use `if` to replace `M and M`, `M or M`, `M if xxx else M` constructs")
save = self.name_inside_frame save = self.name_inside_frame, self.arg_annotations
self.name_inside_frame = {} self.name_inside_frame = {}
self.arg_annotations = annotations or {}
with self.with_frame(MacroFrame()): with self.with_frame(MacroFrame()):
yield yield
self.name_inside_frame = save self.name_inside_frame, self.arg_annotations = save
def get(self): def get(self):
return self.ir_builder.get() return self.ir_builder.get()
...@@ -313,32 +315,18 @@ class Builder(BaseBuilder): ...@@ -313,32 +315,18 @@ class Builder(BaseBuilder):
self.check_continue_break() self.check_continue_break()
locals = self.get_parent_locals() locals = self.get_parent_locals()
orig_value = locals.get(name, None) orig_value = locals.get(name, None)
# annotation like tl.float32
# temporarily disable annotation based var declaration, for better pull request separation
# if callable(annot):
# annot_val = annot()
# if isinstance(annot_val, tir.Var):
# orig_value = tir.alloc_buffer((1,), dtype=annot_val.dtype, scope='local.var')
# IRBuilder.name(name, orig_value)
# if isinstance(value, EllipsisType) or value is self.empty:
# return orig_value
# elif isinstance(value, (int, float, IntImm, FloatImm)):
# tir.block_attr(
# {'tl.local_var_init': {
# orig_value.data: tvm.runtime.convert(value)
# }})
# return orig_value
# if orig_value is a local.var, we use buffer_store to modify it immutably # if orig_value is a local.var, we use buffer_store to modify it immutably
# however, if rvalue is also a local.var, this is a new binding, # however, if rvalue is not a PrimExpr, such as buffer,
# we should not use buffer_store, and bind it instead # we should not use buffer_store, and bind it instead
# ```py # ```py
# a = tl.alloc_var('float32') # bind var `a` # a = tl.alloc_var('float32') # bind var `a`
# a = tl.alloc_var('float32') # bind a new var `a_1` # a = tl.alloc_var('float32') # bind a new var `a_1`
# a = tl.alloc_shared((1,), T.float32) # bind a to new buffer
# b = a # get value of var `b = a_1[0]`` # b = a # get value of var `b = a_1[0]``
# c = tl.alloc_var('float32') # bind var `c` # c = tl.alloc_var('float32') # bind var `c`
# c = a # get and assign `c[0] = a_1[0]` # c = a # get and assign `c[0] = a_1[0]`
# ``` # ```
if is_var(orig_value) and not is_var(value): if is_var(orig_value) and isinstance(value, (int, float, PrimExpr)):
tir.buffer_store(orig_value, value, 0) tir.buffer_store(orig_value, value, 0)
return orig_value return orig_value
res = self.bind_immutable(name, value) res = self.bind_immutable(name, value)
...@@ -486,22 +474,34 @@ class Builder(BaseBuilder): ...@@ -486,22 +474,34 @@ class Builder(BaseBuilder):
) )
return self.unwrap_value(value) return self.unwrap_value(value)
def arg(self, name, value): def macro_arg(self, name, value):
if self.find_frame_idx(MacroFrame) is not None: if self.arg_annotations.get(name, None) is Var:
if isinstance(value, (PrimExpr, int, float)): is_var = isinstance(value, tvm.tir.BufferLoad) and value.buffer.scope() == 'local.var'
return self.bind(name, value) if not is_var:
else: raise ValueError(
return value f'Argument `{name}` is expected to be a variable allocated by `T.alloc_var`, but got {value}({type(value)})'
)
return value.buffer
elif isinstance(value, (PrimExpr, int, float)):
return self.bind(name, value)
else:
return value
def prim_func_arg(self, name, value):
if isinstance(value, (Buffer, Var)): if isinstance(value, (Buffer, Var)):
return tir.arg(name, value) return tir.arg(name, value)
elif value is self.empty: elif value is self.empty:
raise ValueError(f'Argument `{name}` is not annotated') raise ValueError(f'Argument `{name}` is not annotated')
# elif isinstance(value, Hashable):
# return value
else: else:
raise TypeError( raise TypeError(
f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.") f"Unsupported argument type: {value}({type(value)}) for argument `{name}`.")
def arg(self, name, value):
if self.find_frame_idx(MacroFrame) is not None:
return self.macro_arg(name, value)
else:
return self.prim_func_arg(name, value)
def override(self, name: str): def override(self, name: str):
from tilelang.language import serial from tilelang.language import serial
if name == 'range': if name == 'range':
...@@ -533,6 +533,7 @@ class Macro(Generic[_P, _T]): ...@@ -533,6 +533,7 @@ class Macro(Generic[_P, _T]):
name: str name: str
orig_func: Callable[_P, _T] orig_func: Callable[_P, _T]
ir_gen: IRGenerator[_P, _T] ir_gen: IRGenerator[_P, _T]
annotations: dict[str, Any]
@property @property
def source(self) -> str: def source(self) -> str:
...@@ -540,7 +541,7 @@ class Macro(Generic[_P, _T]): ...@@ -540,7 +541,7 @@ class Macro(Generic[_P, _T]):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
builder = Builder.current() builder = Builder.current()
with builder.macro(self.name): with builder.macro(self.name, self.annotations):
res = self.ir_gen.gen(builder)(*args, **kwargs) res = self.ir_gen.gen(builder)(*args, **kwargs)
return res return res
...@@ -578,7 +579,9 @@ def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]: ...@@ -578,7 +579,9 @@ def macro(func: Callable[_P, _T] = None) -> Macro[_P, _T]:
""" """
def impl(func: Callable[_P, _T]) -> Macro[_P, _T]: def impl(func: Callable[_P, _T]) -> Macro[_P, _T]:
return Macro(name=func.__name__, orig_func=func, ir_gen=mutate(func)) annotations = get_type_hints(func)
return Macro(
name=func.__name__, orig_func=func, ir_gen=mutate(func), annotations=annotations)
return impl(func) if func is not None else impl return impl(func) if func is not None else impl
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment