Unverified Commit 6021f863 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Language V2] Minor fix for complex annotations (#1381)

parent 8f50c122
...@@ -102,9 +102,6 @@ class Value(Annot): ...@@ -102,9 +102,6 @@ class Value(Annot):
return Value(kind='static', name=prefer_name, dtype=dt.int32, value=value) return Value(kind='static', name=prefer_name, dtype=dt.int32, value=value)
elif isinstance(value, float): elif isinstance(value, float):
return Value(kind='static', name=prefer_name, dtype=dt.float32, value=value) return Value(kind='static', name=prefer_name, dtype=dt.float32, value=value)
elif isinstance(value, tir.Var):
# handle A: T.Tensor[[M, N, K], ...]
return Value(kind='dynamic', name=value.name, dtype=value.dtype, value=value)
elif isinstance(value, dt.dtype): elif isinstance(value, dt.dtype):
# handle A: T.float32 # handle A: T.float32
return Value(kind='dynamic', name=prefer_name, dtype=value, value=None) return Value(kind='dynamic', name=prefer_name, dtype=value, value=None)
...@@ -113,6 +110,11 @@ class Value(Annot): ...@@ -113,6 +110,11 @@ class Value(Annot):
return value return value
elif isinstance(value, TypeVar): elif isinstance(value, TypeVar):
return Value(kind='static', name=value.__name__, value=None) return Value(kind='static', name=value.__name__, value=None)
elif isinstance(value, (tir.Var, PrimExpr)):
# handle A: T.Tensor[[M, N, K], ...]
# or primexpr annotation like A: T.Tensor[[M, N * 4 +1]]
name = value.name if isinstance(value, tir.Var) else prefer_name
return Value(kind='dynamic', name=name, dtype=value.dtype, value=value)
elif value is Any or value is None or value is dt.dtype or isinstance( elif value is Any or value is None or value is dt.dtype or isinstance(
value, (type, _GenericAlias)): value, (type, _GenericAlias)):
# A # no annotation # A # no annotation
...@@ -122,7 +124,7 @@ class Value(Annot): ...@@ -122,7 +124,7 @@ class Value(Annot):
# A: tuple[...] # A: tuple[...]
return Value(kind='static', name=prefer_name, value=None) return Value(kind='static', name=prefer_name, value=None)
else: else:
raise TypeError(f"Unsupported Value annotation: {value!r}") raise TypeError(f"Unsupported Value annotation: {value!r}, type: {type(value)}")
def with_name(self, name: str) -> Value: def with_name(self, name: str) -> Value:
return Value(kind=self.kind, name=self.name or name, dtype=self.dtype, value=self.value) return Value(kind=self.kind, name=self.name or name, dtype=self.dtype, value=self.value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment