Unverified Commit e59e7f9a authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Language] Support Consequential assignments like 'a = b = c = 1' (#992)



* chained assignments

* test update

* [Lint]: [pre-commit.ci] auto fixes [...]

---------
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 0f515b86
import tilelang
import tilelang.testing
import tilelang.language as T
import torch
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},)
def chain_equal(N, block_size, dtype="float32"):
@T.prim_func
def main(
A: T.Tensor((N,), dtype),
B: T.Tensor((N,), dtype),
C: T.Tensor((N,), dtype),
):
with T.Kernel(T.ceildiv(N, block_size), threads=block_size) as bx:
for lane in T.Parallel(block_size):
idx = bx * block_size + lane
A[idx] = B[idx] = C[idx] = 1
return main
def run_chain_equal(N=128, block_size=64, dtype="float32"):
kernel = chain_equal(N, block_size, dtype)
A = torch.zeros((N,), dtype=torch.float32, device="cuda")
B = torch.zeros((N,), dtype=torch.float32, device="cuda")
C = torch.zeros((N,), dtype=torch.float32, device="cuda")
kernel(A, B, C)
ref = torch.ones_like(A)
torch.testing.assert_close(A, ref)
torch.testing.assert_close(B, ref)
torch.testing.assert_close(C, ref)
@tilelang.testing.requires_cuda
def test_chain_equal():
run_chain_equal()
if __name__ == "__main__":
tilelang.testing.main()
...@@ -15,6 +15,59 @@ def _get_node_span(node: doc.AST) -> Tuple[int, int, int, int]: ...@@ -15,6 +15,59 @@ def _get_node_span(node: doc.AST) -> Tuple[int, int, int, int]:
return (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset) return (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset)
# Original implementation located at
# 3rdparty/tvm/python/tvm/script/parser/tir/parser.py (visit_assign).
@dispatch.register(token="tir", type_name="Assign")
def tilelang_visit_assign(self, node: doc.Assign) -> None: # pylint: disable=unused-argument
"""Override `Assign` to support chained writes and `local.var` buffers."""
if not node.targets:
self.report_error(node, "Assignment must have at least one target.")
if isinstance(node.value, doc.Subscript):
check_slices = []
if isinstance(node.value.slice, doc.Slice):
check_slices = [node.value.slice]
elif isinstance(node.value.slice, doc.Tuple):
for part in node.value.slice.elts:
if isinstance(part, doc.Slice):
check_slices.append(part)
for slice_node in check_slices:
if not slice_node.step and slice_node.upper and slice_node.lower:
slice_node.step = doc.Constant(
1,
None,
1,
1,
slice_node.upper.lineno,
slice_node.upper.end_col_offset + 1,
slice_node.upper.lineno,
slice_node.upper.end_col_offset + 2,
)
rhs = self.eval_expr(node.value)
for lhs in node.targets:
if isinstance(lhs, doc.Subscript):
if isinstance(lhs.slice, doc.Tuple):
indices = [self.eval_expr(index) for index in lhs.slice.elts]
else:
indices = self.eval_expr(lhs.slice)
T.buffer_store(self.eval_expr(lhs.value), rhs, indices)
continue
if isinstance(lhs, doc.Name) and lhs.id in self.var_table.get():
load_ctx = doc.Load()
store_ctx = doc.Store()
lhs.ctx = load_ctx
lhs_value = self.eval_expr(lhs)
lhs.ctx = store_ctx
if (isinstance(lhs_value, BufferLoad) and lhs_value.buffer.scope() == "local.var" and
len(lhs_value.indices) == 1 and lhs_value.indices[0] == 0):
T.buffer_store(lhs_value.buffer, rhs, indices=[0])
continue
self.eval_assign(target=lhs, source=rhs, bind_value=tvm_tir_parser.bind_assign_value)
# Original implementation located at # Original implementation located at
# 3rdparty/tvm/python/tvm/script/parser/tir/parser.py (visit_aug_assign). # 3rdparty/tvm/python/tvm/script/parser/tir/parser.py (visit_aug_assign).
@dispatch.register(token="tir", type_name="AugAssign") @dispatch.register(token="tir", type_name="AugAssign")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment