"git@developer.sourcefind.cn:xdb4_94051/vllm.git" did not exist on "46958cf941997264ff36c23c42a71d30674b61f0"
Unverified Commit 055f8500 authored by Kurisu's avatar Kurisu Committed by GitHub
Browse files

[Feat] Add swap like grammar in tuple assignment (#1185)

* [Feat] add 2 phase binding to allow swap two var

* Minor update tvm dtype constructor

* fix lint error
parent 7d961892
......@@ -273,5 +273,35 @@ def test_prim_func_generator():
assert isinstance(foo, T.PrimFunc)
def test_swap_logic():
@tilelang.jit
@T.prim_func
def swap_var(A: T.Tensor[(2,), T.float32]):
with T.Kernel(1, threads=1) as _:
a = T.alloc_var(T.float32, A[0])
b = T.alloc_var(T.float32, A[1])
a, b = b, a
A[0], A[1] = a, b
@tilelang.jit
@T.prim_func
def swap_idx(A: T.Tensor[(2,), T.float32]):
with T.Kernel(1, threads=1) as _:
A[0], A[1] = A[1], A[0]
k_swap_var = swap_var()
data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
k_swap_var(data)
ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
torch.testing.assert_close(data, ref)
k_swap_idx = swap_idx()
data = torch.tensor([1.0, 2.0], dtype=torch.float32).cuda()
k_swap_idx(data)
ref = torch.tensor([2.0, 1.0], dtype=torch.float32).cuda()
torch.testing.assert_close(data, ref)
if __name__ == '__main__':
tilelang.testing.main()
......@@ -353,6 +353,8 @@ class DSLMutator(ast.NodeTransformer):
span=target,
)
else:
# flatten nested tuple into a list of (tmp_name, target)
unpacked = []
def _visit_target(target: ast.expr) -> str:
......@@ -367,6 +369,9 @@ class DSLMutator(ast.NodeTransformer):
res = ast.Tuple(elts=elts, ctx=target.ctx)
ast_set_span(res, ast_get_span(target))
return res
else:
s = ast.unparse(target)
raise NotImplementedError(f'Attribute assignment not supported yet, `{s}`')
unpack_stmt = ast.Assign(
targets=[_visit_target(target)],
......@@ -383,6 +388,26 @@ class DSLMutator(ast.NodeTransformer):
bind_lvals.clear()
bind_rvals.clear()
# the following code generate two phase binding to support swap like semantics
# for example:
# a, b = b, a
# 1 phase:
# _tmp_0, _tmp_1 = b, a
# => _tmp_0: T.int32 = b
# => _tmp_1: T.int32 = a
# 2 phase:
# a, b = _tmp_0, _tmp_1
# => a = _tmp_0 => a[0] = _tmp_0
# => b = _tmp_1 => b[0] = _tmp_1
# 1 phase: _tmp_0, _tmp_1 = __tb.bind('_', a), __tb.bind('_', b)
for tmp, _target in unpacked:
bind_lvals.append(tmp)
bind_rvals.append(f'__tb.bind("_", {tmp})')
flush_binds()
# 2 phase: a, b = __tb.bind('a', _tmp_0), __tb.bind('b', _tmp_1)
for tmp, target in unpacked:
if isinstance(target, ast.Name):
bind_lvals.append(target.id)
......
......@@ -320,6 +320,9 @@ class Builder(BaseBuilder):
return value
def bind_immutable(self, name, value):
if name == '_':
# use _tmp to make the generated tir more readable
name = "_tmp"
if isinstance(value, tir.meta_var):
return value.value
elif isinstance(value, tir.frame.IRBuilderFrame):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment