Unverified Commit 2426090f authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Fix] Remove unused let_bindings_ in CodeGenC to fix #1300 (#1305)

* [Feat] add missing support of uint32x2

* [Feat] Add `T.Ref` annotation and tests

* fix lint error

* minor update for error message on twice decl

* Remove unused let_bindings_ in CodeGenC to fix #1300
parent d4b6d094
Subproject commit 713e6ade56eaa72cc85d58d9228dd9f34cc2d03e
Subproject commit bc31e7ad9f9fafd7659dfabafe359fd55a0ffc1e
import tilelang
import tilelang.testing
import tilelang.language as T
def test_tilelang_intimm():
T.int32(0x7fffffff)
T.int32(-0x7fffffff - 1)
T.uint32(0xffffffff)
T.int64(0x7fffffffffffffff)
T.int64(-0x7fffffffffffffff - 1)
T.uint64(0xffffffffffffffff)
a = T.int32()
a & 0x7fffffff
a = T.uint32()
a & 0xffffffff
a = T.int64()
a & 0x7fffffffffffffff
a = T.uint64()
a & T.uint64(0xffffffffffffffff)
if __name__ == '__main__':
tilelang.testing.main()
......@@ -394,6 +394,38 @@ def test_var_macro():
except ValueError:
pass
try:
@T.macro
def macro_with_var(x: T.Ref):
x = 1 # noqa: F841
@T.prim_func
def prim_call_macro():
with T.Kernel(1):
x = T.alloc_var(T.int32)
macro_with_var(x)
assert 'x[0] = 1' in prim_call_macro.script()
finally:
pass
try:
@T.macro
def macro_with_var(x: T.Ref):
x = 1 # noqa: F841
@T.prim_func
def prim_call_macro():
with T.Kernel(1):
x = 1
macro_with_var(x)
raise RuntimeError("Expect to report an error, x should not be passed as T.Var")
except ValueError:
pass
if __name__ == '__main__':
tilelang.testing.main()
......@@ -22,6 +22,7 @@ from .proxy import (
FragmentBuffer, # noqa: F401
SharedBuffer, # noqa: F401
LocalBuffer, # noqa: F401
Ref, # noqa: F401
)
from .loop import serial, Parallel, Persistent, Pipelined # noqa: F401
from .frame import has_let_value, get_let_value # noqa: F401
......
"""The language interface for tl programs."""
from __future__ import annotations
from typing import Any, SupportsIndex, TYPE_CHECKING
from typing import Any, SupportsIndex, TYPE_CHECKING, Generic, TypeVar
from collections.abc import Sequence
from typing_extensions import Self
......@@ -263,6 +263,11 @@ if TYPE_CHECKING:
class LocalBuffer(BaseTensor):
...
_T = TypeVar('_T')
class Ref(Generic[_T], tir.Var):
...
else:
Tensor = TensorProxy() # pylint: disable=invalid-name
StridedTensor = StridedTensorProxy() # pylint: disable=invalid-name
......@@ -270,6 +275,9 @@ else:
SharedBuffer = SharedBufferProxy() # pylint: disable=invalid-name
LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name
class Ref:
...
def ptr(dtype: str | None = None,
storage_scope: str = "global",
......
......@@ -335,7 +335,7 @@ class Builder(BaseBuilder):
assert frame is not None, f"Variable `{name}` is not defined inside any control flow."
if name in self.name_inside_frame and self.name_inside_frame[name] in self.frames:
logger.warning(
f'Variable `{name}` shadows another declared value, Are you forgetting to allocate it as a var?',
f'Variable `{name}` is declared twice, are you looking for a T.alloc_var?',
stack_info=True,
stacklevel=2,
)
......@@ -475,7 +475,11 @@ class Builder(BaseBuilder):
return self.unwrap_value(value)
def macro_arg(self, name, value):
if self.arg_annotations.get(name, None) is Var:
from tilelang.language.proxy import Ref
annot_value = self.arg_annotations.get(name, None)
if annot_value is Var or annot_value is Ref:
if annot_value is Var:
logger.warning('Use `T.Var` as macro annotations is deprecated, please use `T.Ref`')
is_var = isinstance(value, tvm.tir.BufferLoad) and value.buffer.scope() == 'local.var'
if not is_var:
raise ValueError(
......
......@@ -87,8 +87,12 @@ _STR_TO_TVM_DTYPE_CALL = {
'float8_e8m0fnu': 'Float8E8M0FNU'
}
int_ = int
def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var:
if isinstance(expr, int_):
return tvm.tir.const(expr, dtype=self)
if self in _STR_TO_TVM_DTYPE_CALL:
attr = _STR_TO_TVM_DTYPE_CALL[self]
call = getattr(tb_ffi, attr, None)
......@@ -151,6 +155,10 @@ if TYPE_CHECKING:
class int16(dtype): ...
class int32(dtype): ...
class int64(dtype): ...
class int8x2(dtype): ...
class int16x2(dtype): ...
class int32x2(dtype): ...
class int64x2(dtype): ...
class int8x4(dtype): ...
class int16x4(dtype): ...
class int32x4(dtype): ...
......@@ -175,6 +183,10 @@ if TYPE_CHECKING:
class uint16(dtype): ...
class uint32(dtype): ...
class uint64(dtype): ...
class uint8x2(dtype): ...
class uint16x2(dtype): ...
class uint32x2(dtype): ...
class uint64x2(dtype): ...
class uint8x4(dtype): ...
class uint16x4(dtype): ...
class uint32x4(dtype): ...
......@@ -308,6 +320,10 @@ else:
int16 = dtype('int16')
int32 = dtype('int32')
int64 = dtype('int64')
int8x2 = dtype('int8x2')
int16x2 = dtype('int16x2')
int32x2 = dtype('int32x2')
int64x2 = dtype('int64x2')
int8x4 = dtype('int8x4')
int16x4 = dtype('int16x4')
int32x4 = dtype('int32x4')
......@@ -332,6 +348,10 @@ else:
uint16 = dtype('uint16')
uint32 = dtype('uint32')
uint64 = dtype('uint64')
uint8x2 = dtype('uint8x2')
uint16x2 = dtype('uint16x2')
uint32x2 = dtype('uint32x2')
uint64x2 = dtype('uint64x2')
uint8x4 = dtype('uint8x4')
uint16x4 = dtype('uint16x4')
uint32x4 = dtype('uint32x4')
......@@ -464,6 +484,10 @@ _all_dtypes = {
'int16',
'int32',
'int64',
'int8x2',
'int16x2',
'int32x2',
'int64x2',
'int8x4',
'int16x4',
'int32x4',
......@@ -488,6 +512,10 @@ _all_dtypes = {
'uint16',
'uint32',
'uint64',
'uint8x2',
'uint16x2',
'uint32x2',
'uint64x2',
'uint8x4',
'uint16x4',
'uint32x4',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment