"...git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "7a9b93f4b62901c8eab76581300142754b3afd87"
Commit 7b474fbe authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix device type validation for input tensors (#586)

* Enhancement: Update `pythonic_expr` to accept `tvm.tir.PrimExpr` and improve type handling

- Modified the `pythonic_expr` function to check for `tvm.tir.PrimExpr` and fall back to `str()` for anything else, so non-PrimExpr inputs no longer break expression formatting.
- Refactored the device and dtype checks in `CythonKernelWrapper` for clearer control flow and error messages; a device index is now compared only when both the expected and the actual device specify one.

* Enhancement: Refine `pythonic_expr` function to support additional expression types

- Updated the `pythonic_expr` function to accept `tvm.tir.PrimExpr` and to render both integer and float immediates by value, so float constants are formatted the same way integer constants already were.
parent fecc8336
```diff
@@ -66,27 +66,43 @@ cdef class CythonKernelWrapper:
         return self
 
     cpdef void _check_buffer_device(self, list tensor_list):
-        if isinstance(tensor_list[0], torch.Tensor):
-            tensor_list_device_type = tensor_list[0].device.type
         for param, (buffer_idx, device) in self.buffer_device_map.items():
-            if isinstance(tensor_list[buffer_idx], torch.Tensor):
-                tensor_device = tensor_list[buffer_idx].device
-                if (tensor_list_device_type != device.type or
-                        (tensor_device.index is not None and device.index is not None and tensor_device.index != device.index)):
-                    raise ValueError(f"Buffer device mismatch for parameter {param}: expected {device}, got {tensor_device}")
+            tensor = tensor_list[buffer_idx]
+            if isinstance(tensor, torch.Tensor):
+                tensor_device = tensor.device
+                device_type_match = device.type == tensor_device.type
+                device_index_match = (
+                    tensor_device.index is None or
+                    device.index is None or
+                    tensor_device.index == device.index
+                )
+                if not (device_type_match and device_index_match):
+                    raise ValueError(
+                        f"Buffer device mismatch for parameter {param}: "
+                        f"expected {device}, got {tensor_device}"
+                    )
 
     cpdef void _check_buffer_dtype(self, list tensor_list):
         for param, (buffer_idx, torch_dtype) in self.buffer_dtype_map.items():
-            if isinstance(tensor_list[buffer_idx], torch.Tensor):
-                if tensor_list[buffer_idx].dtype != torch_dtype:
-                    raise ValueError(f"Buffer dtype mismatch for parameter {param}: expected {torch_dtype}, got {tensor_list[buffer_idx].dtype}")
+            tensor = tensor_list[buffer_idx]
+            if isinstance(tensor, torch.Tensor) and tensor.dtype != torch_dtype:
+                raise ValueError(
+                    f"Buffer dtype mismatch for parameter {param}: "
+                    f"expected {torch_dtype}, got {tensor.dtype}"
+                )
 
     cpdef void _check_static_shape(self, list tensor_list):
         for param, (buffer_idx, shape_list) in self.static_shape_map.items():
-            if isinstance(tensor_list[buffer_idx], torch.Tensor):
-                for shape_idx, shape in shape_list:
-                    if tensor_list[buffer_idx].shape[shape_idx] != shape:
-                        raise ValueError(f"Static shape mismatch for parameter {param}: expected {shape} at index {shape_idx}, got {tensor_list[buffer_idx].shape}")
+            tensor = tensor_list[buffer_idx]
+            if isinstance(tensor, torch.Tensor):
+                for shape_idx, expected_shape in shape_list:
+                    actual_shape = tensor.shape[shape_idx]
+                    if actual_shape != expected_shape:
+                        raise ValueError(
+                            f"Static shape mismatch for parameter {param}: "
+                            f"expected {expected_shape} at index {shape_idx}, "
+                            f"got {actual_shape}"
+                        )
 
     cpdef forward(self, list inputs, int64_t stream = -1, bint skip_tensor_validation = False):
         # Validate input dimensions and prepare for kernel execution
@@ -149,7 +165,7 @@ cdef class CythonKernelWrapper:
         call_args = []
         for i, tensor in enumerate(tensor_list):
             if isinstance(tensor, torch.Tensor):
-                if not skip_tensor_validation and not tensor.is_contiguous():
+                if not tensor.is_contiguous():
                     raise ValueError(f"Input tensor at index {i} must be contiguous")
                 call_args.append(ctypes.c_void_p(tensor.data_ptr()))
             elif isinstance(tensor, (int, float, bool)):
```
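The semantic core of the device-check rewrite above is that a missing device index now acts as a wildcard: device types must always agree, while indices are compared only when both the expected and the actual device pin one. Below is a minimal pure-Python sketch of that matching rule (plain `torch` outside the Cython wrapper; `devices_match` and its argument names are illustrative, not part of the codebase):

```python
import torch

def devices_match(expected: torch.device, actual: torch.device) -> bool:
    # Device types ("cuda", "cpu", ...) must always agree.
    if expected.type != actual.type:
        return False
    # Indices are only compared when both devices specify one, so a bare
    # "cuda" (index None) is compatible with "cuda:0", "cuda:1", etc.
    if expected.index is None or actual.index is None:
        return True
    return expected.index == actual.index

assert devices_match(torch.device("cuda"), torch.device("cuda:1"))
assert devices_match(torch.device("cuda:0"), torch.device("cuda:0"))
assert not devices_match(torch.device("cuda:0"), torch.device("cuda:1"))
assert not devices_match(torch.device("cpu"), torch.device("cuda:0"))
```

Under this rule a kernel whose expected device is a bare `cuda` accepts a tensor on any CUDA device, while an explicit `cuda:0` checked against `cuda:1` still raises the mismatch error.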
```diff
@@ -103,14 +103,16 @@ def get_annotated_mod(
     return dispatch[model_type](mod)
 
-def pythonic_expr(expr: tvm.tir.Expr) -> str:
+def pythonic_expr(expr: tvm.tir.PrimExpr) -> str:
+    if not isinstance(expr, tvm.tir.PrimExpr):
+        return str(expr)
     python_str = ""
     node_to_str_map = {}  # Stores string representation for each node
 
     def _pythonic_visitor(node):
         if isinstance(node, tvm.tir.Var):
             s = node.name
-        elif isinstance(node, tvm.tir.IntImm):
+        elif isinstance(node, (tvm.tir.IntImm, tvm.tir.FloatImm)):
             # Integer constant: use value directly (ignore type)
             s = str(node.value)
         elif isinstance(node, tvm.tir.Cast):
```
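As a quick illustration of the `pythonic_expr` change: integer and float immediates are now both rendered by value, and anything that is not a `tvm.tir.PrimExpr` falls back to `str()` instead of going through the visitor. A hedged sketch, assuming a standard TVM install and that `pythonic_expr` is imported from the patched module (its path is not shown in this diff); the exact output for composite expressions depends on the rest of the visitor:

```python
import tvm

n = tvm.tir.Var("n", "int32")

# Immediates are rendered by value, dropping the dtype annotation.
print(pythonic_expr(tvm.tir.IntImm("int64", 4)))        # -> "4"
print(pythonic_expr(tvm.tir.FloatImm("float32", 0.5)))  # -> "0.5"

# Variables keep their names; a composite expression prints as
# something like "n * 4".
print(pythonic_expr(n * 4))

# Non-PrimExpr inputs now return their plain str() form instead of
# raising inside the visitor.
print(pythonic_expr("already_a_string"))
```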