Unverified Commit 041d4a06 authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Refactor] add support for numpy dtype conversion (#1255)

* add typing stub for tir.ir

* remove idents

* minor update

* [Refactor] add numpy conversion for dtype

* fix lint error

* remove unused np.float_ in dtype conversion

* fix type in np.int_

* fix typo

* minor fix

* remove debug files
parent 716dbef5
......@@ -145,62 +145,63 @@ def test_dtype_str_repr():
buf_24 = T.alloc_buffer((1,), dtype=T.float64, scope='shared') # noqa F841
def test_torch_eq():
dtypes = [
T.bool,
T.short,
T.int,
T.long,
T.half,
T.float,
T.long,
T.int8,
T.int16,
T.int32,
T.int64,
T.uint8,
T.uint16,
T.uint32,
T.uint64,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
T.float8_e8m0fnu,
T.float16,
T.bfloat16,
T.float32,
T.float64,
]
torch_dtypes = [
torch.bool,
torch.short,
torch.int,
torch.long,
torch.half,
torch.float,
torch.long,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
]
for a, b in zip(dtypes, torch_dtypes):
assert a == b, f"{a} and {b} are not equal"
assert T.dtype(b) == a, "dtype conversion error"
# not supported now
# def test_torch_eq():
# dtypes = [
# T.bool,
# T.short,
# T.int,
# T.long,
# T.half,
# T.float,
# T.long,
# T.int8,
# T.int16,
# T.int32,
# T.int64,
# T.uint8,
# T.uint16,
# T.uint32,
# T.uint64,
# T.float8_e4m3fn,
# T.float8_e4m3fnuz,
# T.float8_e5m2,
# T.float8_e5m2fnuz,
# T.float8_e8m0fnu,
# T.float16,
# T.bfloat16,
# T.float32,
# T.float64,
# ]
# torch_dtypes = [
# torch.bool,
# torch.short,
# torch.int,
# torch.long,
# torch.half,
# torch.float,
# torch.long,
# torch.int8,
# torch.int16,
# torch.int32,
# torch.int64,
# torch.uint8,
# torch.uint16,
# torch.uint32,
# torch.uint64,
# torch.float8_e4m3fn,
# torch.float8_e4m3fnuz,
# torch.float8_e5m2,
# torch.float8_e5m2fnuz,
# torch.float8_e8m0fnu,
# torch.float16,
# torch.bfloat16,
# torch.float32,
# torch.float64,
# ]
# for a, b in zip(dtypes, torch_dtypes):
# assert a == b, f"{a} and {b} are not equal"
# assert T.dtype(b) == a, "dtype conversion error"
def test_var_assign():
......
from tilelang import tvm
from tvm import ir
import torch
import ctypes
from typing import TYPE_CHECKING, Union
from tvm import tir
import tvm.script.ir_builder.tir._ffi_api as tb_ffi
import numpy as np
dtype = tvm.DataType
# Python 3.9 compatibility: avoid PEP 604 unions at runtime
AnyDType = Union[ir.Type, str, type, torch.dtype, dtype]
# Base dtype conversion list
_dtype_cvt_base = [
(None, 'handle', ctypes.c_long, 'long', None), # use long to repr void*
(bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
(int, 'int32', ctypes.c_int32, 'int', 'Int32'),
(float, 'float32', ctypes.c_float, 'float', 'Float32'),
(torch.short, 'int16', ctypes.c_int16, 'short', 'Int16'),
(torch.int, 'int32', ctypes.c_int32, 'int', 'Int32'),
(torch.long, 'int64', ctypes.c_int64, 'long long', 'Int64'),
(torch.half, 'float16', None, None, 'Float16'),
(torch.float, 'float32', ctypes.c_float, 'float', 'Float32'),
(torch.double, 'float64', ctypes.c_double, 'double', 'Float64'),
# (pytype, 'tvm dtype str', 'ctypes dtype', 'cffi dtype')
(torch.bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
(torch.int8, 'int8', ctypes.c_int8, 'char', 'Int8'),
(torch.int16, 'int16', ctypes.c_int16, 'short', 'Int16'),
(torch.int32, 'int32', ctypes.c_int32, 'int', 'Int32'),
(torch.int64, 'int64', ctypes.c_int64, 'long long', 'Int64'),
(torch.uint8, 'uint8', ctypes.c_uint8, 'unsigned char', 'UInt8'),
(torch.uint16, 'uint16', ctypes.c_uint16, 'unsigned short', 'UInt16'),
(torch.uint32, 'uint32', ctypes.c_uint32, 'unsigned int', 'UInt32'),
(torch.uint64, 'uint64', ctypes.c_uint64, 'unsigned long long', 'UInt64'),
(torch.float16, 'float16', None, None, 'Float16'),
(torch.float32, 'float32', ctypes.c_float, 'float', 'Float32'),
(torch.float64, 'float64', ctypes.c_double, 'double', 'Float64'),
(None, 'float8_e4m3', None, None, 'Float8E4M3'),
(torch.bfloat16, 'bfloat16', None, None, 'BFloat16'),
]
# Dynamically add fp8-related types if they exist in torch
_fp8_dtype_mappings = [
('float8_e4m3fn', 'Float8E4M3FN'),
('float8_e4m3fnuz', 'Float8E4M3FNUZ'),
('float8_e5m2', 'Float8E5M2'),
('float8_e5m2fnuz', 'Float8E5M2FNUZ'),
('float8_e8m0fnu', 'Float8E8M0FNU'),
]
_dtype_cvt = list(_dtype_cvt_base)
for torch_attr_name, tvm_name in _fp8_dtype_mappings:
if hasattr(torch, torch_attr_name):
torch_dtype = getattr(torch, torch_attr_name)
_dtype_cvt.append((torch_dtype, torch_attr_name, None, None, tvm_name))
_PYTHON_DTYPE_TO_STR = {
bool: 'bool',
int: 'int32',
float: 'float32',
}
def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x):
return {
smapper(item[sidx]): dmapper(item[didx])
for item in _dtype_cvt
if item[didx] is not None and item[sidx] is not None
}
_NUMPY_DTYPE_TO_STR = {
np.bool_: 'bool',
np.short: 'int16',
np.int_: 'int64',
np.longlong: 'int64',
np.half: 'float16',
np.double: 'float64',
np.int8: 'int8',
np.int16: 'int16',
np.int32: 'int32',
np.int64: 'int64',
np.uint8: 'uint8',
np.uint16: 'uint16',
np.uint32: 'uint32',
np.uint64: 'uint64',
np.float16: 'float16',
np.float32: 'float32',
np.float64: 'float64',
}
_NUMPY_DTYPE_TO_STR.update({np.dtype(k): v for k, v in _NUMPY_DTYPE_TO_STR.items()})
_dtype_py2tvmstr = _create_type_mapper(0, 1)
_dtype_tvmstr2fficall = _create_type_mapper(1, 4, dmapper=lambda x: getattr(tb_ffi, x))
_dtype_tvm2py = _create_type_mapper(1, 0, lambda x: dtype(x))
_dtype_tvm2ctype = _create_type_mapper(1, 2, lambda x: dtype(x))
_dtype_tvm2cffi = _create_type_mapper(1, 3, lambda x: dtype(x))
_TORCH_DTYPE_TO_STR = {
torch.bool: 'bool',
torch.short: 'int16',
torch.int: 'int32',
torch.long: 'int64',
torch.half: 'float16',
torch.float: 'float32',
torch.double: 'float64',
torch.int8: 'int8',
torch.int16: 'int16',
torch.int32: 'int32',
torch.int64: 'int64',
torch.uint8: 'uint8',
torch.uint16: 'uint16',
torch.uint32: 'uint32',
torch.uint64: 'uint64',
torch.float16: 'float16',
torch.float32: 'float32',
torch.float64: 'float64',
torch.bfloat16: 'bfloat16',
}
# _STR_TO_TORCH_DTYPE = {v: k for k, v in _TORCH_DTYPE_TO_STR.items()}
def __dtype_eq__(self: dtype, other: AnyDType):
if isinstance(other, str):
return str.__eq__(self, other)
if other in _dtype_py2tvmstr:
return str.__eq__(self, _dtype_py2tvmstr[other])
return NotImplemented
# _STR_TO_NUMPY_DTYPE = {v: k for k, v in _NUMPY_DTYPE_TO_STR.items()}
_DTYPE_TO_STR = {**_PYTHON_DTYPE_TO_STR, **_NUMPY_DTYPE_TO_STR, **_TORCH_DTYPE_TO_STR}
def __dtype_ne__(self: dtype, other: AnyDType):
if isinstance(other, str):
return str.__ne__(self, other)
if other in _dtype_py2tvmstr:
return str.__ne__(self, _dtype_py2tvmstr[other])
return NotImplemented
_STR_TO_TVM_DTYPE_CALL = {
'bool': 'Boolean',
'int8': 'Int8',
'int32': 'Int32',
'int64': 'Int64',
'uint8': 'UInt8',
'uint16': 'UInt16',
'uint32': 'UInt32',
'uint64': 'UInt64',
'float16': 'Float16',
'float32': 'Float32',
'float64': 'Float64',
'bfloat16': 'BFloat16',
'float8_e4m3': 'Float8E4M3',
'float8_e4m3fn': 'Float8E4M3FN',
'float8_e4m3fnuz': 'Float8E4M3FNUZ',
'float8_e5m2': 'Float8E5M2',
'float8_e5m2fnuz': 'Float8E5M2FNUZ',
'float8_e8m0fnu': 'Float8E8M0FNU'
}
def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var:
if self in _dtype_tvmstr2fficall:
return _dtype_tvmstr2fficall[self](expr, is_size_var)
if self in _STR_TO_TVM_DTYPE_CALL:
attr = _STR_TO_TVM_DTYPE_CALL[self]
call = getattr(tb_ffi, attr, None)
return call(expr, is_size_var)
# try to construct the ffi call
if self.startswith('uint'):
val = 'UInt' + self[4:]
......@@ -117,17 +120,13 @@ __orig_dtype_new = dtype.__new__
def __dtype_new__(cls, value: AnyDType) -> dtype:
if isinstance(value, str):
return __orig_dtype_new(cls, value)
elif value in _dtype_py2tvmstr:
return __orig_dtype_new(cls, _dtype_py2tvmstr[value])
elif value in _DTYPE_TO_STR:
return __orig_dtype_new(cls, _DTYPE_TO_STR[value])
else:
expected = set(list(_dtype_py2tvmstr.keys()) + list(_dtype_tvmstr2fficall.values()))
expected = set(list(_DTYPE_TO_STR.keys()) + list(_DTYPE_TO_STR.values()))
raise TypeError(f"Invalid DataType {value}({type(value)}), expect one of {expected}")
dtype.__eq__ = __dtype_eq__
dtype.__req__ = __dtype_eq__
dtype.__ne__ = __dtype_ne__
dtype.__rne__ = __dtype_ne__
dtype.__call__ = __dtype_call__
dtype.__new__ = __dtype_new__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment