Commit 05fc9cd5 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Update `pythonic_expr` to format type casts and improve tensor...

[Enhancement] Update `pythonic_expr` to format type casts and improve tensor validation in Cython wrapper (#581)

- Enhanced `pythonic_expr` to represent type casts as `(type)value` for better clarity in expression representation.
- Modified tensor validation in `CythonKernelWrapper` to conditionally check for tensor contiguity based on a new `skip_tensor_validation` parameter.
- Improved type mapping in `map_torch_type` to include version checks for new float8 types, ensuring compatibility with specific PyTorch versions.
parent 916ee60e
......@@ -140,7 +140,7 @@ cdef class CythonKernelWrapper:
for i in range(len(tensor_list)):
tensor = tensor_list[i]
if isinstance(tensor, torch.Tensor):
if not tensor.is_contiguous():
if not skip_tensor_validation and not tensor.is_contiguous():
raise ValueError(f"Input tensor at index {i} must be contiguous")
call_args.append(ctypes.c_void_p(tensor.data_ptr()))
elif isinstance(tensor, int):
......
......@@ -114,8 +114,11 @@ def pythonic_expr(expr: tvm.tir.Expr) -> str:
# Integer constant: use value directly (ignore type)
s = str(node.value)
elif isinstance(node, tvm.tir.Cast):
# Type cast: skip Cast and use inner value directly
s = node_to_str_map.get(node.value, str(node.value))
# Type cast: represent as (type)value
dtype_map = {"int64": "int64_t", "int32": "int32_t", "int8": "int8_t"}
dtype = dtype_map.get(str(node.dtype), str(node.dtype))
value_str = node_to_str_map.get(node.value, str(node.value))
s = f"({dtype}){value_str}"
elif isinstance(node, tvm.tir.Mul):
# Multiplication: format as 'left * right'
a_str = node_to_str_map.get(node.a, str(node.a))
......
......@@ -19,13 +19,21 @@ class TensorSupplyType(Enum):
def map_torch_type(intype: str) -> torch.dtype:
    """Map a dtype name string to the corresponding ``torch.dtype``.

    Float8 aliases are gated behind explicit availability checks because
    they only exist in sufficiently new PyTorch releases (``float8_e4m3fn``
    and ``float8_e5m2`` since 2.1.0, ``float8_e4m3fnuz`` since 2.2.0).

    Args:
        intype: dtype name, e.g. ``"float32"``, ``"int8"``, or a float8
            alias such as ``"e4m3_float8"``.

    Returns:
        The matching ``torch.dtype`` object.

    Raises:
        AssertionError: if a float8 alias is requested but the installed
            torch version does not provide that dtype.
        AttributeError: if ``intype`` does not name any torch dtype.
    """
    # Alias -> (torch attribute name, minimum torch version providing it).
    # Table-driven so all float8 variants share one version-checked path.
    float8_aliases = {
        "e4m3_float8": ("float8_e4m3fn", "2.1.0"),
        "e5m2_float8": ("float8_e5m2", "2.1.0"),
        "e4m3fnuz_float8": ("float8_e4m3fnuz", "2.2.0"),
    }
    if intype in float8_aliases:
        attr, min_version = float8_aliases[intype]
        # Check availability instead of parsing torch.__version__: the
        # attribute's presence is the authoritative signal.
        assert hasattr(torch, attr), (
            f"torch.{attr} is not supported in this version of torch. "
            f"Please upgrade torch >= {min_version}")
        return getattr(torch, attr)
    # Plain dtypes resolve directly as torch attributes (float32, bool, ...).
    return getattr(torch, intype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment