"examples/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "54fc6ba099e2c0081ddf76d33bc111c928f45f17"
Unverified commit c67d66a3 authored by Lei Wang, committed by GitHub

[Refactor] Dynamic registration of FP8 data type for compatibility with older PyTorch versions (#1197)
parent a9d823b8
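The change below removes the FP8 entries from the static dtype conversion list and instead registers them only when the running PyTorch build actually exposes them, so importing the module no longer raises AttributeError on older torch versions. A condensed, standalone sketch of that guard pattern (the base entries and FP8 names are abbreviated here for illustration; the full lists are in the diff below):

import ctypes
import torch

# Entries that exist in every supported PyTorch build (abbreviated).
_dtype_cvt_base = [
    (torch.float32, 'float32', ctypes.c_float, 'float', 'Float32'),
    (torch.bfloat16, 'bfloat16', None, None, 'BFloat16'),
]

# FP8 dtypes only exist in newer PyTorch builds, so resolve them by name
# and register only the ones this torch actually provides.
_fp8_dtype_mappings = [
    ('float8_e4m3fn', 'Float8E4M3FN'),
    ('float8_e5m2', 'Float8E5M2'),
]

_dtype_cvt = list(_dtype_cvt_base)
for torch_attr_name, tvm_name in _fp8_dtype_mappings:
    if hasattr(torch, torch_attr_name):
        _dtype_cvt.append(
            (getattr(torch, torch_attr_name), torch_attr_name, None, None, tvm_name))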
@@ -12,7 +12,7 @@ parser.add_argument('--batch', type=int, default=128, help='batch size')
 parser.add_argument('--heads', type=int, default=16, help='heads')
 parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
 parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
-parser.add_argument('--dim', type=int, default=512, help='dim')
+parser.add_argument('--dim', type=int, default=256, help='dim')
 parser.add_argument('--is_causal', action='store_true', help='causal')
 parser.add_argument('--tune', action='store_true', help='tune configs')
 parser.add_argument("--use_v2", action="store_true")
...
@@ -10,7 +10,8 @@ dtype = tvm.DataType
 # Python 3.9 compatibility: avoid PEP 604 unions at runtime
 AnyDType = Union[ir.Type, str, type, torch.dtype, dtype]
 
-_dtype_cvt = [
+# Base dtype conversion list
+_dtype_cvt_base = [
     (None, 'handle', ctypes.c_long, 'long', None),  # use long to repr void*
     (bool, 'bool', ctypes.c_bool, 'bool', 'Boolean'),
     (int, 'int32', ctypes.c_int32, 'int', 'Int32'),
@@ -36,14 +37,24 @@ _dtype_cvt = [
     (torch.float32, 'float32', ctypes.c_float, 'float', 'Float32'),
     (torch.float64, 'float64', ctypes.c_double, 'double', 'Float64'),
     (None, 'float8_e4m3', None, None, 'Float8E4M3'),
-    (torch.float8_e4m3fn, 'float8_e4m3fn', None, None, 'Float8E4M3FN'),
-    (torch.float8_e4m3fnuz, 'float8_e4m3fnuz', None, None, 'Float8E4M3FNUZ'),
-    (torch.float8_e5m2, 'float8_e5m2', None, None, 'Float8E5M2'),
-    (torch.float8_e5m2fnuz, 'float8_e5m2fnuz', None, None, 'Float8E5M2FNUZ'),
-    (torch.float8_e8m0fnu, 'float8_e8m0fnu', None, None, 'Float8E8M0FNU'),
     (torch.bfloat16, 'bfloat16', None, None, 'BFloat16'),
 ]
 
+# Dynamically add fp8-related types if they exist in torch
+_fp8_dtype_mappings = [
+    ('float8_e4m3fn', 'Float8E4M3FN'),
+    ('float8_e4m3fnuz', 'Float8E4M3FNUZ'),
+    ('float8_e5m2', 'Float8E5M2'),
+    ('float8_e5m2fnuz', 'Float8E5M2FNUZ'),
+    ('float8_e8m0fnu', 'Float8E8M0FNU'),
+]
+
+_dtype_cvt = list(_dtype_cvt_base)
+for torch_attr_name, tvm_name in _fp8_dtype_mappings:
+    if hasattr(torch, torch_attr_name):
+        torch_dtype = getattr(torch, torch_attr_name)
+        _dtype_cvt.append((torch_dtype, torch_attr_name, None, None, tvm_name))
+
 
 def _create_type_mapper(sidx, didx, smapper=lambda x: x, dmapper=lambda x: x):
     return {
...
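Calling code can apply the same guard when it wants an FP8 dtype but must still run on older installations. A minimal sketch, not part of the commit; the candidate order and the bfloat16 fallback are assumptions chosen for illustration:

import torch

# Prefer an FP8 dtype when this torch build ships one; otherwise fall back
# to bfloat16 so the script still runs on older installations (assumed fallback).
_fp8_candidates = ('float8_e4m3fn', 'float8_e5m2')
fp8_or_fallback = next(
    (getattr(torch, name) for name in _fp8_candidates if hasattr(torch, name)),
    torch.bfloat16,
)
print(f"using dtype: {fp8_or_fallback}")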