Unverified Commit 1768cbef authored by Kurisu's avatar Kurisu Committed by GitHub
Browse files

[Fix] Remove unsupported type params (#1186)

* [Fix] Remove type params

* fix lint error

* [Fix] fix dtype new error
parent 778b97dc
...@@ -56,10 +56,10 @@ jobs: ...@@ -56,10 +56,10 @@ jobs:
run: | run: |
"${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang "${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang
- name: Setup Python 3.12 - name: Setup Python 3.9
uses: actions/setup-python@v6 uses: actions/setup-python@v6
with: with:
python-version: "3.12" python-version: "3.9"
update-environment: true update-environment: true
cache: pip cache: pip
cache-dependency-path: | cache-dependency-path: |
......
...@@ -536,7 +536,8 @@ def get_type_hints(func): ...@@ -536,7 +536,8 @@ def get_type_hints(func):
if annot is None: if annot is None:
raise TypeError(f'Failed to get function type hints, {func} is not a function') raise TypeError(f'Failed to get function type hints, {func} is not a function')
hints = {} hints = {}
type_params = getattr(func, "__type_params__", ()) # type params are not used currently, it is support since python 3.12.4
# type_params = getattr(func, "__type_params__", ())
globalns = getattr(func, '__globals__', {}) globalns = getattr(func, '__globals__', {})
localns = globalns localns = globalns
for name, value in annot.items(): for name, value in annot.items():
...@@ -559,7 +560,8 @@ def get_type_hints(func): ...@@ -559,7 +560,8 @@ def get_type_hints(func):
except Exception: except Exception:
pass pass
value = ForwardRef(value, is_argument=True, is_class=False) value = ForwardRef(value, is_argument=True, is_class=False)
hints[name] = _eval_type(value, globalns=globalns, localns=localns, type_params=type_params) hints[name] = _eval_type(
value, globalns=globalns, localns=localns) #, type_params=type_params)
return hints return hints
......
from tilelang import tvm from tilelang import tvm
from tvm import ir from tvm import ir
import tvm_ffi
import torch import torch
import ctypes import ctypes
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
...@@ -100,16 +99,17 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var ...@@ -100,16 +99,17 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var
return call(expr, is_size_var) return call(expr, is_size_var)
__orig_dtype_new = dtype.__new__
def __dtype_new__(cls, value: AnyDType) -> dtype: def __dtype_new__(cls, value: AnyDType) -> dtype:
if isinstance(value, str): if isinstance(value, str):
val = str.__new__(cls, value) return __orig_dtype_new(cls, value)
elif value in _dtype_py2tvmstr: elif value in _dtype_py2tvmstr:
val = str.__new__(cls, _dtype_py2tvmstr[value]) return __orig_dtype_new(cls, _dtype_py2tvmstr[value])
else: else:
expected = set(list(_dtype_py2tvmstr.keys()) + list(_dtype_tvmstr2fficall.values())) expected = set(list(_dtype_py2tvmstr.keys()) + list(_dtype_tvmstr2fficall.values()))
raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}") raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}")
val.__tvm_ffi_dtype__ = tvm_ffi.core.DataType(val)
return val
dtype.__eq__ = __dtype_eq__ dtype.__eq__ = __dtype_eq__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment