"...api/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "827fad66a02745093de94e8a926f74e896833b2a"
Unverified Commit 041d4a06 authored by Kuris, committed by GitHub

[Refactor] add support for numpy dtype conversion (#1255)

* add typing stub for tir.ir

* remove idents

* minor update

* [Refactor] add numpy conversion for dtype

* fix lint error

* remove unused np.float_ in dtype conversion

* fix type in np.int_

* fix typo

* minor fix

* remove debug files
parent 716dbef5
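
The diff below drops the old tuple-based _dtype_cvt conversion table (with its ctypes/cffi columns) in favor of plain per-framework dicts, and extends tvm.DataType construction to accept numpy scalar types and np.dtype instances. A minimal sketch of the resulting behavior, assuming the usual "import tilelang.language as T" alias (the alias is an assumption; the conversions follow the mapping tables in the diff):

    import numpy as np
    import tilelang.language as T  # assumed alias; T.dtype is the patched tvm.DataType

    # Both numpy scalar types and np.dtype instances map to TVM dtype strings.
    assert T.dtype(np.float32) == 'float32'
    assert T.dtype(np.dtype('int64')) == 'int64'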
@@ -145,62 +145,63 @@ def test_dtype_str_repr():
     buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared')  # noqa F841


-def test_torch_eq():
-    dtypes = [
-        T.bool,
-        T.short,
-        T.int,
-        T.long,
-        T.half,
-        T.float,
-        T.long,
-        T.int8,
-        T.int16,
-        T.int32,
-        T.int64,
-        T.uint8,
-        T.uint16,
-        T.uint32,
-        T.uint64,
-        T.float8_e4m3fn,
-        T.float8_e4m3fnuz,
-        T.float8_e5m2,
-        T.float8_e5m2fnuz,
-        T.float8_e8m0fnu,
-        T.float16,
-        T.bfloat16,
-        T.float32,
-        T.float64,
-    ]
-    torch_dtypes = [
-        torch.bool,
-        torch.short,
-        torch.int,
-        torch.long,
-        torch.half,
-        torch.float,
-        torch.long,
-        torch.int8,
-        torch.int16,
-        torch.int32,
-        torch.int64,
-        torch.uint8,
-        torch.uint16,
-        torch.uint32,
-        torch.uint64,
-        torch.float8_e4m3fn,
-        torch.float8_e4m3fnuz,
-        torch.float8_e5m2,
-        torch.float8_e5m2fnuz,
-        torch.float8_e8m0fnu,
-        torch.float16,
-        torch.bfloat16,
-        torch.float32,
-        torch.float64,
-    ]
-    for a, b in zip(dtypes, torch_dtypes):
-        assert a == b, f"{a} and {b} are not equal"
-        assert T.dtype(b) == a, "dtype conversion error"
+# not supported now
+# def test_torch_eq():
+#     dtypes = [
+#         T.bool,
+#         T.short,
+#         T.int,
+#         T.long,
+#         T.half,
+#         T.float,
+#         T.long,
+#         T.int8,
+#         T.int16,
+#         T.int32,
+#         T.int64,
+#         T.uint8,
+#         T.uint16,
+#         T.uint32,
+#         T.uint64,
+#         T.float8_e4m3fn,
+#         T.float8_e4m3fnuz,
+#         T.float8_e5m2,
+#         T.float8_e5m2fnuz,
+#         T.float8_e8m0fnu,
+#         T.float16,
+#         T.bfloat16,
+#         T.float32,
+#         T.float64,
+#     ]
+#     torch_dtypes = [
+#         torch.bool,
+#         torch.short,
+#         torch.int,
+#         torch.long,
+#         torch.half,
+#         torch.float,
+#         torch.long,
+#         torch.int8,
+#         torch.int16,
+#         torch.int32,
+#         torch.int64,
+#         torch.uint8,
+#         torch.uint16,
+#         torch.uint32,
+#         torch.uint64,
+#         torch.float8_e4m3fn,
+#         torch.float8_e4m3fnuz,
+#         torch.float8_e5m2,
+#         torch.float8_e5m2fnuz,
+#         torch.float8_e8m0fnu,
+#         torch.float16,
+#         torch.bfloat16,
+#         torch.float32,
+#         torch.float64,
+#     ]
+#     for a, b in zip(dtypes, torch_dtypes):
+#         assert a == b, f"{a} and {b} are not equal"
+#         assert T.dtype(b) == a, "dtype conversion error"


 def test_var_assign():
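
The hunk above disables test_torch_eq because the custom __eq__/__ne__ hooks that made comparisons like T.bool == torch.bool succeed are removed at the end of this commit (see the dtype.__eq__/dtype.__ne__ deletions below), hence the "not supported now" note. Explicit conversion still works, sketched here under the same assumed T alias:

    import torch
    import tilelang.language as T  # assumed alias

    # Convert the torch dtype first, then compare dtype strings.
    assert T.dtype(torch.bfloat16) == 'bfloat16'
    assert T.dtype(torch.int64) == T.int64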
...
 from tilelang import tvm
 from tvm import ir
 import torch
-import ctypes
 from typing import TYPE_CHECKING, Union
 from tvm import tir
 import tvm.script.ir_builder.tir._ffi_api as tb_ffi
+import numpy as np

 dtype = tvm.DataType

 # Python 3.9 compatibility: avoid PEP 604 unions at runtime
 AnyDType = Union[ir.Type, str, type, torch.dtype, dtype]

-# Base dtype conversion list
-_dtype_cvt_base = [
-    (None, 'handle', ctypes.c_long, 'long', None),  # use long to repr void*
-    (bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
-    (int, 'int32', ctypes.c_int32, 'int', 'Int32'),
-    (float, 'float32', ctypes.c_float, 'float', 'Float32'),
-    (torch.short, 'int16', ctypes.c_int16, 'short', 'Int16'),
-    (torch.int, 'int32', ctypes.c_int32, 'int', 'Int32'),
-    (torch.long, 'int64', ctypes.c_int64, 'long long', 'Int64'),
-    (torch.half, 'float16', None, None, 'Float16'),
-    (torch.float, 'float32', ctypes.c_float, 'float', 'Float32'),
-    (torch.double, 'float64', ctypes.c_double, 'double', 'Float64'),
-    # (pytype, 'tvm dtype str', 'ctypes dtype', 'cffi dtype')
-    (torch.bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
-    (torch.int8, 'int8', ctypes.c_int8, 'char', 'Int8'),
-    (torch.int16, 'int16', ctypes.c_int16, 'short', 'Int16'),
-    (torch.int32, 'int32', ctypes.c_int32, 'int', 'Int32'),
-    (torch.int64, 'int64', ctypes.c_int64, 'long long', 'Int64'),
-    (torch.uint8, 'uint8', ctypes.c_uint8, 'unsigned char', 'UInt8'),
-    (torch.uint16, 'uint16', ctypes.c_uint16, 'unsigned short', 'UInt16'),
-    (torch.uint32, 'uint32', ctypes.c_uint32, 'unsigned int', 'UInt32'),
-    (torch.uint64, 'uint64', ctypes.c_uint64, 'unsigned long long', 'UInt64'),
-    (torch.float16, 'float16', None, None, 'Float16'),
-    (torch.float32, 'float32', ctypes.c_float, 'float', 'Float32'),
-    (torch.float64, 'float64', ctypes.c_double, 'double', 'Float64'),
-    (None, 'float8_e4m3', None, None, 'Float8E4M3'),
-    (torch.bfloat16, 'bfloat16', None, None, 'BFloat16'),
-]
-
-# Dynamically add fp8-related types if they exist in torch
-_fp8_dtype_mappings = [
-    ('float8_e4m3fn', 'Float8E4M3FN'),
-    ('float8_e4m3fnuz', 'Float8E4M3FNUZ'),
-    ('float8_e5m2', 'Float8E5M2'),
-    ('float8_e5m2fnuz', 'Float8E5M2FNUZ'),
-    ('float8_e8m0fnu', 'Float8E8M0FNU'),
-]
-
-_dtype_cvt = list(_dtype_cvt_base)
-for torch_attr_name, tvm_name in _fp8_dtype_mappings:
-    if hasattr(torch, torch_attr_name):
-        torch_dtype = getattr(torch, torch_attr_name)
-        _dtype_cvt.append((torch_dtype, torch_attr_name, None, None, tvm_name))
-
-
-def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x):
-    return {
-        smapper(item[sidx]): dmapper(item[didx])
-        for item in _dtype_cvt
-        if item[didx] is not None and item[sidx] is not None
-    }
-
-
-_dtype_py2tvmstr = _create_type_mapper(0, 1)
-_dtype_tvmstr2fficall = _create_type_mapper(1, 4, dmapper=lambda x: getattr(tb_ffi, x))
-_dtype_tvm2py = _create_type_mapper(1, 0, lambda x: dtype(x))
-_dtype_tvm2ctype = _create_type_mapper(1, 2, lambda x: dtype(x))
-_dtype_tvm2cffi = _create_type_mapper(1, 3, lambda x: dtype(x))
-
-
-def __dtype_eq__(self: dtype, other: AnyDType):
-    if isinstance(other, str):
-        return str.__eq__(self, other)
-    if other in _dtype_py2tvmstr:
-        return str.__eq__(self, _dtype_py2tvmstr[other])
-    return NotImplemented
-
-
-def __dtype_ne__(self: dtype, other: AnyDType):
-    if isinstance(other, str):
-        return str.__ne__(self, other)
-    if other in _dtype_py2tvmstr:
-        return str.__ne__(self, _dtype_py2tvmstr[other])
-    return NotImplemented
+_PYTHON_DTYPE_TO_STR = {
+    bool: 'bool',
+    int: 'int32',
+    float: 'float32',
+}
+
+_NUMPY_DTYPE_TO_STR = {
+    np.bool_: 'bool',
+    np.short: 'int16',
+    np.int_: 'int64',
+    np.longlong: 'int64',
+    np.half: 'float16',
+    np.double: 'float64',
+    np.int8: 'int8',
+    np.int16: 'int16',
+    np.int32: 'int32',
+    np.int64: 'int64',
+    np.uint8: 'uint8',
+    np.uint16: 'uint16',
+    np.uint32: 'uint32',
+    np.uint64: 'uint64',
+    np.float16: 'float16',
+    np.float32: 'float32',
+    np.float64: 'float64',
+}
+_NUMPY_DTYPE_TO_STR.update({np.dtype(k): v for k, v in _NUMPY_DTYPE_TO_STR.items()})
+
+_TORCH_DTYPE_TO_STR = {
+    torch.bool: 'bool',
+    torch.short: 'int16',
+    torch.int: 'int32',
+    torch.long: 'int64',
+    torch.half: 'float16',
+    torch.float: 'float32',
+    torch.double: 'float64',
+    torch.int8: 'int8',
+    torch.int16: 'int16',
+    torch.int32: 'int32',
+    torch.int64: 'int64',
+    torch.uint8: 'uint8',
+    torch.uint16: 'uint16',
+    torch.uint32: 'uint32',
+    torch.uint64: 'uint64',
+    torch.float16: 'float16',
+    torch.float32: 'float32',
+    torch.float64: 'float64',
+    torch.bfloat16: 'bfloat16',
+}
+# _STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()}
+# _STR_TO_NUMPY_DTYPE = {v: k for k, v in _NUMPY_DTYPE_TO_STR.items()}
+
+_DTYPE_TO_STR = {**_PYTHON_DTYPE_TO_STR, **_NUMPY_DTYPE_TO_STR, **_TORCH_DTYPE_TO_STR}
+
+_STR_TO_TVM_DTYPE_CALL = {
+    'bool': 'Boolean',
+    'int8': 'Int8',
+    'int32': 'Int32',
+    'int64': 'Int64',
+    'uint8': 'UInt8',
+    'uint16': 'UInt16',
+    'uint32': 'UInt32',
+    'uint64': 'UInt64',
+    'float16': 'Float16',
+    'float32': 'Float32',
+    'float64': 'Float64',
+    'bfloat16': 'BFloat16',
+    'float8_e4m3': 'Float8E4M3',
+    'float8_e4m3fn': 'Float8E4M3FN',
+    'float8_e4m3fnuz': 'Float8E4M3FNUZ',
+    'float8_e5m2': 'Float8E5M2',
+    'float8_e5m2fnuz': 'Float8E5M2FNUZ',
+    'float8_e8m0fnu': 'Float8E8M0FNU'
+}


 def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var:
-    if self in _dtype_tvmstr2fficall:
-        return _dtype_tvmstr2fficall[self](expr, is_size_var)
+    if self in _STR_TO_TVM_DTYPE_CALL:
+        attr = _STR_TO_TVM_DTYPE_CALL[self]
+        call = getattr(tb_ffi, attr, None)
+        return call(expr, is_size_var)
     # try to construct the ffi call
     if self.startswith('uint'):
         val = 'UInt' + self[4:]
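
The rewritten __dtype_call__ above no longer caches bound FFI callables; _STR_TO_TVM_DTYPE_CALL stores attribute names and the callable is resolved against tb_ffi with getattr at call time. A self-contained toy of that dispatch pattern (the fake FFI object and names are hypothetical, not the tilelang module):

    # Map dtype strings to attribute names; resolve the callable lazily.
    class FakeFFI:

        @staticmethod
        def Int32(expr, is_size_var):
            return ('int32', expr, is_size_var)

        @staticmethod
        def Float32(expr, is_size_var):
            return ('float32', expr, is_size_var)

    CALLS = {'int32': 'Int32', 'float32': 'Float32'}

    def make_var(name, expr=None, is_size_var=False):
        attr = CALLS.get(name)
        call = getattr(FakeFFI, attr, None) if attr else None
        if call is None:
            raise TypeError(f'unsupported dtype {name}')
        return call(expr, is_size_var)

    print(make_var('float32', 1.0))  # ('float32', 1.0, False)

Resolving through getattr defers the tb_ffi lookup to call time, which presumably lets the table list names (such as the float8 variants) that only newer TVM builds expose.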
@@ -117,17 +120,13 @@ __orig_dtype_new = dtype.__new__
 def __dtype_new__(cls, value: AnyDType) -> dtype:
     if isinstance(value, str):
         return __orig_dtype_new(cls, value)
-    elif value in _dtype_py2tvmstr:
-        return __orig_dtype_new(cls, _dtype_py2tvmstr[value])
+    elif value in _DTYPE_TO_STR:
+        return __orig_dtype_new(cls, _DTYPE_TO_STR[value])
     else:
-        expected = set(list(_dtype_py2tvmstr.keys()) + list(_dtype_tvmstr2fficall.values()))
+        expected = set(list(_DTYPE_TO_STR.keys()) + list(_DTYPE_TO_STR.values()))
         raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}")


-dtype.__eq__ = __dtype_eq__
-dtype.__req__ = __dtype_eq__
-dtype.__ne__ = __dtype_ne__
-dtype.__rne__ = __dtype_ne__
 dtype.__call__ = __dtype_call__
 dtype.__new__ = __dtype_new__
...
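
For reference, a standalone sketch of how the merged _DTYPE_TO_STR table resolves its three key families, including the np.dtype-instance trick from the diff (toy dicts, not the module itself):

    import numpy as np
    import torch

    np_map = {np.int32: 'int32', np.float32: 'float32'}
    # Same trick as the diff: an np.dtype instance generally does not hash
    # like the corresponding scalar type, so instances are added as extra keys.
    np_map.update({np.dtype(k): v for k, v in np_map.items()})

    table = {int: 'int32', float: 'float32', **np_map, torch.int32: 'int32'}

    assert table[np.int32] == 'int32'            # numpy scalar type
    assert table[np.dtype('int32')] == 'int32'   # np.dtype instance
    assert table[torch.int32] == 'int32'         # torch dtype
    assert table[float] == 'float32'             # plain Python type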