Unverified Commit 305c854b authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Typing] Enhance compatibility for advanced typing features in Python (#1382)

- Updated `allocate.py` and `annot.py` to improve compatibility with Python 3.9 and later by conditionally importing advanced typing features such as `TypeVarTuple`, `Unpack`, and `ParamSpec`.
- Added fallback imports from `typing_extensions` for environments using earlier Python versions.
- Improved handling of generic alias detection to ensure consistent behavior across different Python versions.
parent ce16e479
...@@ -14,7 +14,12 @@ Each function takes shape and dtype parameters and returns a TVM buffer object ...@@ -14,7 +14,12 @@ Each function takes shape and dtype parameters and returns a TVM buffer object
with the appropriate memory scope. with the appropriate memory scope.
""" """
from __future__ import annotations from __future__ import annotations
from typing import TypeVarTuple, TypeVar, overload, Literal, Unpack, Callable from typing import TypeVar, overload, Literal, Callable
# Python 3.9 compatibility for advanced typing features (PEP 646)
try:
from typing import TypeVarTuple, Unpack # type: ignore[attr-defined]
except Exception:
from typing_extensions import TypeVarTuple, Unpack # type: ignore
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.script import tir as T from tvm.script import tir as T
from tvm.tir import PrimExpr from tvm.tir import PrimExpr
......
...@@ -4,7 +4,31 @@ from abc import ABC, abstractmethod ...@@ -4,7 +4,31 @@ from abc import ABC, abstractmethod
from tvm import tir from tvm import tir
from tvm.ir.expr import PrimExpr from tvm.ir.expr import PrimExpr
from tvm.script.ir_builder.tir import buffer from tvm.script.ir_builder.tir import buffer
from typing import Any, Callable, Literal, TypeVar, ParamSpec, Generic, TypeVarTuple, Unpack, TYPE_CHECKING, _GenericAlias, Self from typing import Any, Callable, Literal, TypeVar, Generic, TYPE_CHECKING
# Python 3.9 compatibility for advanced typing features
try:
from typing import ParamSpec, TypeVarTuple, Unpack, Self # type: ignore[attr-defined]
except Exception: # Python < 3.10 for ParamSpec, < 3.11 for Unpack/TypeVarTuple/Self
from typing_extensions import ParamSpec, TypeVarTuple, Unpack, Self # type: ignore
# Compatibility for generic alias detection across Python versions
try:
from typing import _GenericAlias as _TypingGenericAlias # type: ignore[attr-defined]
except Exception:
_TypingGenericAlias = None # type: ignore
try:
# Builtin generic alias type for e.g. tuple[int]
from types import GenericAlias as _TypesGenericAlias # type: ignore[attr-defined]
except Exception:
_TypesGenericAlias = None # type: ignore
_GenericAliasTypes = tuple(t for t in (_TypingGenericAlias, _TypesGenericAlias) if t is not None)
if not _GenericAliasTypes:
class _DummyGenericAlias: # type: ignore
pass
_GenericAliasTypes = (_DummyGenericAlias,) # type: ignore
from collections.abc import Sequence from collections.abc import Sequence
from .dtypes import AnyDType from .dtypes import AnyDType
from . import dtypes as dt from . import dtypes as dt
...@@ -116,7 +140,7 @@ class Value(Annot): ...@@ -116,7 +140,7 @@ class Value(Annot):
name = value.name if isinstance(value, tir.Var) else prefer_name name = value.name if isinstance(value, tir.Var) else prefer_name
return Value(kind='dynamic', name=name, dtype=value.dtype, value=value) return Value(kind='dynamic', name=name, dtype=value.dtype, value=value)
elif value is Any or value is None or value is dt.dtype or isinstance( elif value is Any or value is None or value is dt.dtype or isinstance(
value, (type, _GenericAlias)): value, (type,) + _GenericAliasTypes):
# A # no annotation # A # no annotation
# A: Any # A: Any
# A: _T # A: _T
...@@ -358,17 +382,6 @@ class BufferAnnot(Annot): ...@@ -358,17 +382,6 @@ class BufferAnnot(Annot):
buf = buffer(shape, self.dtype, strides=strides, scope=self.scope) buf = buffer(shape, self.dtype, strides=strides, scope=self.scope)
return TIRAnnot(data=buf) return TIRAnnot(data=buf)
# def __repr__(self):
# items = []
# if self.shape is not None:
# items.append(f'shape=[{', '.join(map(repr, self.shape))}]')
# if self.strides is not None:
# items.append(f'strides=[{', '.join(map(repr, self.strides))}]')
# if self.dtype is not None:
# items.append(f'dtype={self.dtype}')
# items.append(f'scope={repr(self.scope)}')
# return 'Buffer(' + ', '.join(items) + ')'
class TensorAnnot(BufferAnnot): class TensorAnnot(BufferAnnot):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment