Unverified Commit bf90a5f5 authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Fix] Fix frame scope error in T.macro (#1308)



* [Fix] Fix #1307 by adding macro inside function

* fix lint error

* add comments and fix lint error

* Remove debug print from enter_frame method

Removed debug print statement from enter_frame method.

---------
Co-authored-by: default avatarLei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent 17bbc0ca
......@@ -427,5 +427,31 @@ def test_var_macro():
pass
def frame_inside_macro():
@tilelang.jit
def get_sample_kernel():
@T.macro
def transform(x):
return x + 1
@T.prim_func
def sample_kernel(
num_blocks: T.int32,
idx_out: T.Tensor[(32,), T.int32],
):
with T.Kernel(num_blocks, threads=32) as block_idx: # noqa: F841
fragment = T.alloc_fragment(32, 'int32')
T.copy(idx_out, fragment)
for i in T.Parallel(32):
idx_out[i] = transform(fragment[i])
return sample_kernel
kernel = get_sample_kernel() # noqa: F841
if __name__ == '__main__':
tilelang.testing.main()
......@@ -80,6 +80,10 @@ class MacroFrame(Frame):
...
class ExitedMacroFrame(Frame):
...
class BoolOpFrame(Frame):
...
......@@ -164,8 +168,22 @@ class Builder(BaseBuilder):
save = self.name_inside_frame, self.arg_annotations
self.name_inside_frame = {}
self.arg_annotations = annotations or {}
with self.with_frame(MacroFrame()):
pos = len(self.frames)
# here we add a ExitedMacroFrame to preserve the frame stack inside macro
# because macro may bind some variable, and return it
#
# ```py
# @T.macro
# def foo(x):
# y = x + 1
# return y
# @T.prim_func
# def bar():
# c = foo(1) # macro generates let y = x + 1
# d = c # d = c should lay inside frame of `let y = x + 1`
self.frames.append(MacroFrame())
yield
self.frames[pos] = ExitedMacroFrame()
self.name_inside_frame, self.arg_annotations = save
def get(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment