"ts/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "428ade24db13c19e80ae2d443ed4a9545c683907"
Commit 05fc9cd5 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Update `pythonic_expr` to format type casts and improve tensor validation in Cython wrapper (#581)

- Enhanced `pythonic_expr` to render type casts as `(type)value`, making casts explicit in the rendered expression string.
- Modified tensor validation in `CythonKernelWrapper` to check tensor contiguity only when the new `skip_tensor_validation` parameter is not set.
- Improved type mapping in `map_torch_type` with version checks for the new float8 types, so unsupported dtypes fail with a clear upgrade message on older PyTorch versions.
parent 916ee60e
@@ -140,7 +140,7 @@ cdef class CythonKernelWrapper:
         for i in range(len(tensor_list)):
             tensor = tensor_list[i]
             if isinstance(tensor, torch.Tensor):
-                if not tensor.is_contiguous():
+                if not skip_tensor_validation and not tensor.is_contiguous():
                     raise ValueError(f"Input tensor at index {i} must be contiguous")
                 call_args.append(ctypes.c_void_p(tensor.data_ptr()))
             elif isinstance(tensor, int):
...
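For illustration, here is a minimal standalone sketch of the new validation rule. The helper name `validate_tensors` and its signature are assumptions for demonstration; only the flag's behavior comes from the diff above.

```python
import torch

def validate_tensors(tensor_list, skip_tensor_validation=False):
    # Mirrors the diff: contiguity is only enforced when validation is not skipped.
    for i, tensor in enumerate(tensor_list):
        if isinstance(tensor, torch.Tensor):
            if not skip_tensor_validation and not tensor.is_contiguous():
                raise ValueError(f"Input tensor at index {i} must be contiguous")

x = torch.randn(4, 4).t()                           # transposed view: not contiguous
validate_tensors([x], skip_tensor_validation=True)  # accepted, check skipped
# validate_tensors([x])                             # would raise ValueError
```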
@@ -114,8 +114,11 @@ def pythonic_expr(expr: tvm.tir.Expr) -> str:
             # Integer constant: use value directly (ignore type)
             s = str(node.value)
         elif isinstance(node, tvm.tir.Cast):
-            # Type cast: skip Cast and use inner value directly
-            s = node_to_str_map.get(node.value, str(node.value))
+            # Type cast: represent as (type)value
+            dtype_map = {"int64": "int64_t", "int32": "int32_t", "int8": "int8_t"}
+            dtype = dtype_map.get(str(node.dtype), str(node.dtype))
+            value_str = node_to_str_map.get(node.value, str(node.value))
+            s = f"({dtype}){value_str}"
         elif isinstance(node, tvm.tir.Mul):
             # Multiplication: format as 'left * right'
             a_str = node_to_str_map.get(node.a, str(node.a))
...
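As a quick sanity check, the cast-formatting rule can be exercised in isolation. The `format_cast` helper below is hypothetical (not part of the commit); the dtype mapping is taken from the diff.

```python
dtype_map = {"int64": "int64_t", "int32": "int32_t", "int8": "int8_t"}

def format_cast(dtype: str, value_str: str) -> str:
    # Unmapped dtypes (e.g. "float16") pass through unchanged.
    return f"({dtype_map.get(dtype, dtype)}){value_str}"

print(format_cast("int64", "x"))    # -> (int64_t)x
print(format_cast("float16", "y"))  # -> (float16)y
```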
@@ -19,13 +19,21 @@ class TensorSupplyType(Enum):
 def map_torch_type(intype: str) -> torch.dtype:
-    typemap = {
-        'e4m3_float8': torch.float8_e4m3fn,
-        'e5m2_float8': torch.float8_e5m2,
-        'e4m3fnuz_float8': torch.float8_e4m3fnuz,
-    }
-    if intype in typemap:
-        return typemap[intype]
+    if intype == "e4m3_float8":
+        assert hasattr(torch, "float8_e4m3fn"), \
+            "torch.float8_e4m3fn is not supported in this version of torch. " \
+            "Please upgrade torch >= 2.1.0"
+        return torch.float8_e4m3fn
+    elif intype == "e5m2_float8":
+        assert hasattr(torch, "float8_e5m2"), \
+            "torch.float8_e5m2 is not supported in this version of torch. " \
+            "Please upgrade torch >= 2.1.0"
+        return torch.float8_e5m2
+    elif intype == "e4m3fnuz_float8":
+        assert hasattr(torch, "float8_e4m3fnuz"), \
+            "torch.float8_e4m3fnuz is not supported in this version of torch. " \
+            "Please upgrade torch >= 2.2.0"
+        return torch.float8_e4m3fnuz
     else:
         return getattr(torch, intype)
...
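The same guard can also be written table-driven. Below is a condensed, behavior-equivalent sketch; the dict-based structure is my own for illustration, not the commit's code.

```python
import torch

def map_torch_type(intype: str) -> torch.dtype:
    # Condensed version of the patched function, for demonstration only.
    float8_names = {
        "e4m3_float8": ("float8_e4m3fn", "2.1.0"),
        "e5m2_float8": ("float8_e5m2", "2.1.0"),
        "e4m3fnuz_float8": ("float8_e4m3fnuz", "2.2.0"),
    }
    if intype in float8_names:
        attr, min_ver = float8_names[intype]
        assert hasattr(torch, attr), (
            f"torch.{attr} is not supported in this version of torch. "
            f"Please upgrade torch >= {min_ver}")
        return getattr(torch, attr)
    # Names without a special case fall through to getattr(torch, ...).
    return getattr(torch, intype)

print(map_torch_type("float32"))      # torch.float32
print(map_torch_type("e4m3_float8"))  # torch.float8_e4m3fn (torch >= 2.1.0)
```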