Commit 7b474fbe authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix device type validation for input tensors (#586)

* Enhancement: Update `pythonic_expr` to accept `tvm.tir.PrimExpr` and improve type handling

- Modified `pythonic_expr` to check that its argument is a `tvm.tir.PrimExpr` and fall back to `str()` otherwise, so non-expression inputs are handled gracefully.
- Refactored the device and dtype checks in `CythonKernelWrapper` for clearer logic and error messages: device indices are now compared only when both the expected and the actual device specify one (see the sketch below), improving robustness of tensor validation.
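As a rough, self-contained sketch of the relaxed device-matching rule introduced in the wrapper diff below (`devices_match` is a hypothetical helper for illustration, not part of the codebase): device types must agree, while device indices are compared only when both the expected and the actual device specify one.

```python
import torch

def devices_match(expected: torch.device, actual: torch.device) -> bool:
    # Device types must match exactly; indices only matter when both are
    # given, so an expected bare "cuda" accepts a tensor on "cuda:0".
    type_match = expected.type == actual.type
    index_match = (
        expected.index is None
        or actual.index is None
        or expected.index == actual.index
    )
    return type_match and index_match

assert devices_match(torch.device("cuda"), torch.device("cuda:0"))
assert not devices_match(torch.device("cuda:1"), torch.device("cuda:0"))
assert not devices_match(torch.device("cpu"), torch.device("cuda:0"))
```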

* Enhancement: Refine `pythonic_expr` to support additional expression types

- Updated `pythonic_expr` to accept `tvm.tir.PrimExpr` and to render both integer (`IntImm`) and float (`FloatImm`) immediates as bare values, improving expression stringification (see the usage sketch below).
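A hedged usage sketch of the refined behavior, assuming `pythonic_expr` is importable from the module being patched (exact output spacing depends on the visitor's formatting):

```python
import tvm

n = tvm.tir.Var("n", "int32")

# A genuine PrimExpr is rendered by the visitor; integer and float
# immediates both print as bare values, with no dtype suffix.
print(pythonic_expr(n * 2 + 1))                         # e.g. "n*2+1"
print(pythonic_expr(tvm.tir.FloatImm("float32", 0.5)))  # "0.5"

# Anything that is not a tvm.tir.PrimExpr now falls back to str().
print(pythonic_expr("sym_shape"))                       # "sym_shape"
```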
parent fecc8336
@@ -66,27 +66,43 @@ cdef class CythonKernelWrapper:
         return self
 
     cpdef void _check_buffer_device(self, list tensor_list):
-        if isinstance(tensor_list[0], torch.Tensor):
-            tensor_list_device_type = tensor_list[0].device.type
         for param, (buffer_idx, device) in self.buffer_device_map.items():
-            if isinstance(tensor_list[buffer_idx], torch.Tensor):
-                tensor_device = tensor_list[buffer_idx].device
-                if (tensor_list_device_type != device.type or
-                        (tensor_device.index is not None and device.index is not None and tensor_device.index != device.index)):
-                    raise ValueError(f"Buffer device mismatch for parameter {param}: expected {device}, got {tensor_device}")
+            tensor = tensor_list[buffer_idx]
+            if isinstance(tensor, torch.Tensor):
+                tensor_device = tensor.device
+                device_type_match = device.type == tensor_device.type
+                device_index_match = (
+                    tensor_device.index is None or
+                    device.index is None or
+                    tensor_device.index == device.index
+                )
+                if not (device_type_match and device_index_match):
+                    raise ValueError(
+                        f"Buffer device mismatch for parameter {param}: "
+                        f"expected {device}, got {tensor_device}"
+                    )
 
     cpdef void _check_buffer_dtype(self, list tensor_list):
         for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items():
-            if isinstance(tensor_list[buffer_idx], torch.Tensor):
-                if tensor_list[buffer_idx].dtype != torch_dtype:
-                    raise ValueError(f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
+            tensor = tensor_list[buffer_idx]
+            if isinstance(tensor, torch.Tensor) and tensor.dtype != torch_dtype:
+                raise ValueError(
+                    f"Buffer dtype mismatch for parameter {param}: "
+                    f"expected {torch_dtype}, got {tensor.dtype}"
+                )
 
     cpdef void _check_static_shape(self, list tensor_list):
         for param, (buffer_idx, shape_list) in self.static_shape_map.items():
-            if isinstance(tensor_list[buffer_idx], torch.Tensor):
-                for shape_idx, shape in shape_list:
-                    if tensor_list[buffer_idx].shape[shape_idx] != shape:
-                        raise ValueError(f"Static shape mismatch for parameter {param}: expected {shape} at index {shape_idx}, got {tensor_list[buffer_idx].shape}")
+            tensor = tensor_list[buffer_idx]
+            if isinstance(tensor, torch.Tensor):
+                for shape_idx, expected_shape in shape_list:
+                    actual_shape = tensor.shape[shape_idx]
+                    if actual_shape != expected_shape:
+                        raise ValueError(
+                            f"Static shape mismatch for parameter {param}: "
+                            f"expected {expected_shape} at index {shape_idx}, "
+                            f"got {actual_shape}"
+                        )
 
     cpdef forward(self, list inputs, int64_t stream = -1, bint skip_tensor_validation = False):
         # Validate input dimensions and prepare for kernel execution
@@ -149,7 +165,7 @@ cdef class CythonKernelWrapper:
         call_args = []
         for i, tensor in enumerate(tensor_list):
             if isinstance(tensor, torch.Tensor):
-                if not skip_tensor_validation and not tensor.is_contiguous():
+                if not tensor.is_contiguous():
                     raise ValueError(f"Input tensor at index {i} must be contiguous")
                 call_args.append(ctypes.c_void_p(tensor.data_ptr()))
             elif isinstance(tensor, (int, float, bool)):
@@ -103,14 +103,16 @@ def get_annotated_mod(
     return dispatch[model_type](mod)
 
 
-def pythonic_expr(expr: tvm.tir.Expr) -> str:
+def pythonic_expr(expr: tvm.tir.PrimExpr) -> str:
+    if not isinstance(expr, tvm.tir.PrimExpr):
+        return str(expr)
     python_str = ""
     node_to_str_map = {}  # Stores string representation for each node
 
     def _pythonic_visitor(node):
         if isinstance(node, tvm.tir.Var):
             s = node.name
-        elif isinstance(node, tvm.tir.IntImm):
+        elif isinstance(node, (tvm.tir.IntImm, tvm.tir.FloatImm)):
             # Integer constant: use value directly (ignore type)
             s = str(node.value)
         elif isinstance(node, tvm.tir.Cast):