Unverified Commit 4f844000 authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Fix] Fix missing `not` rewrite in frontend (#1348)

parent 17718bec
...@@ -466,5 +466,18 @@ def test_buffer_slice_step(): ...@@ -466,5 +466,18 @@ def test_buffer_slice_step():
pass pass
def test_boolop():
a = Var('a', 'int32')
b = Var('b', 'int32')
c = Var('c', 'int32')
d = Var('d', 'int32')
@T.macro
def cond():
return not (a < b and b < c and a * d < b * d) or b * d < c * d
cond()
if __name__ == '__main__': if __name__ == '__main__':
tilelang.testing.main() tilelang.testing.main()
...@@ -78,7 +78,7 @@ def quote_expr(expr: str, **kws) -> ast.expr: ...@@ -78,7 +78,7 @@ def quote_expr(expr: str, **kws) -> ast.expr:
Operator = Literal['Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow', 'LShift', 'RShift', Operator = Literal['Add', 'Sub', 'Mult', 'MatMult', 'Div', 'Mod', 'Pow', 'LShift', 'RShift',
'BitOr', 'BitXor', 'BitAnd', 'FloorDiv'] 'BitOr', 'BitXor', 'BitAnd', 'FloorDiv']
BoolOp = Literal['And', 'Or'] BoolOp = Literal['And', 'Or', 'Not']
def get_operator_name(operator: ast.operator) -> Operator: def get_operator_name(operator: ast.operator) -> Operator:
...@@ -217,11 +217,13 @@ class BaseBuilder: ...@@ -217,11 +217,13 @@ class BaseBuilder:
def aug_assign_slice(self, op: Operator, target: Any, sl: slice, aug_value: Any): def aug_assign_slice(self, op: Operator, target: Any, sl: slice, aug_value: Any):
eval_aug_assign(op, target, sl, aug_value) eval_aug_assign(op, target, sl, aug_value)
def boolop(self, op: BoolOp, left: Any, right: Callable[[], Any]) -> Any: def boolop(self, op: BoolOp, left: Any, right: Callable[[], Any] | None = None) -> Any:
if op == 'And': if op == 'And':
return left and right() return left and right()
if op == 'Or': if op == 'Or':
return left or right() return left or right()
if op == 'Not':
return not left
raise ValueError(f'Unknown boolop: {op}') raise ValueError(f'Unknown boolop: {op}')
def ifexp(self, cond: Any, then: Callable[[], Any], otherwise: Callable[[], Any]) -> Any: def ifexp(self, cond: Any, then: Callable[[], Any], otherwise: Callable[[], Any]) -> Any:
...@@ -517,6 +519,12 @@ class DSLMutator(ast.NodeTransformer): ...@@ -517,6 +519,12 @@ class DSLMutator(ast.NodeTransformer):
) )
return last return last
def visit_UnaryOp(self, node: ast.UnaryOp):
node = self.generic_visit(node)
if isinstance(node.op, ast.Not):
return quote_expr("__tb.boolop('Not', operand)", operand=node.operand, span=node)
return node
def visit_Compare(self, node: ast.Compare) -> ast.expr: def visit_Compare(self, node: ast.Compare) -> ast.expr:
node = self.generic_visit(node) node = self.generic_visit(node)
left = node.left left = node.left
......
...@@ -148,8 +148,7 @@ class Builder(BaseBuilder): ...@@ -148,8 +148,7 @@ class Builder(BaseBuilder):
@classmethod @classmethod
def current(cls) -> Self: def current(cls) -> Self:
builder = thread_local_storage.builder builder = getattr(thread_local_storage, 'builder', None)
assert builder is not None, "No active Builder found in the current thread."
return builder return builder
@contextmanager @contextmanager
...@@ -424,7 +423,7 @@ class Builder(BaseBuilder): ...@@ -424,7 +423,7 @@ class Builder(BaseBuilder):
else: else:
return super().aug_assign_slice(op, target, sl, aug_value) return super().aug_assign_slice(op, target, sl, aug_value)
def boolop(self, op, left, right): def boolop(self, op, left, right=None):
left = unwrap_cond(left) left = unwrap_cond(left)
if isinstance(left, PrimExpr): if isinstance(left, PrimExpr):
with self.with_frame(BoolOpFrame()): with self.with_frame(BoolOpFrame()):
...@@ -432,6 +431,8 @@ class Builder(BaseBuilder): ...@@ -432,6 +431,8 @@ class Builder(BaseBuilder):
return tir.And(left, right()) return tir.And(left, right())
if op == 'Or': if op == 'Or':
return tir.Or(left, right()) return tir.Or(left, right())
if op == 'Not':
return tir.Not(left)
raise RuntimeError(f"Unsupported boolean operator: {op}") raise RuntimeError(f"Unsupported boolean operator: {op}")
else: else:
return super().boolop(op, left, right) return super().boolop(op, left, right)
...@@ -562,7 +563,7 @@ class Macro(Generic[_P, _T]): ...@@ -562,7 +563,7 @@ class Macro(Generic[_P, _T]):
return self.ir_gen.source return self.ir_gen.source
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
builder = Builder.current() builder = Builder.current() or Builder()
with builder.macro(self.name, self.annotations): with builder.macro(self.name, self.annotations):
res = self.ir_gen.gen(builder)(*args, **kwargs) res = self.ir_gen.gen(builder)(*args, **kwargs)
return res return res
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment