Commit 44508e59 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Update dtype handling in KernelParam and CythonKernelWrapper (#582)

- Modified `KernelParam.from_var` to map Torch data types to a more appropriate format.
- Enhanced `CythonKernelWrapper` to support additional tensor types and ensure proper conversion of tensor dtypes to C types, improving error handling for unsupported types.
parent 05fc9cd5
......@@ -54,7 +54,8 @@ class KernelParam:
Returns:
KernelParam instance representing a scalar (empty shape)
"""
return cls(var.dtype, [])
dtype = map_torch_type(var.dtype)
return cls(dtype, [])
def is_scalar(self) -> bool:
"""
......
......@@ -134,25 +134,32 @@ cdef class CythonKernelWrapper:
tensor = inputs[ins_idx]
ins_idx += 1
tensor_list.append(tensor)
# Convert tensor pointers to C void pointers for kernel call
cdef dict dtype_to_ctype = {
torch.float16: ctypes.c_float,
torch.float32: ctypes.c_float,
torch.float64: ctypes.c_double,
torch.int8: ctypes.c_int8,
torch.int16: ctypes.c_int16,
torch.int32: ctypes.c_int32,
torch.int64: ctypes.c_int64,
}
call_args = []
for i in range(len(tensor_list)):
tensor = tensor_list[i]
for i, tensor in enumerate(tensor_list):
if isinstance(tensor, torch.Tensor):
if not skip_tensor_validation and not tensor.is_contiguous():
raise ValueError(f"Input tensor at index {i} must be contiguous")
call_args.append(ctypes.c_void_p(tensor.data_ptr()))
elif isinstance(tensor, int):
# Dynamic symbolics which are passed as integer arguments
elif isinstance(tensor, (int, float, bool)):
if i in self.ptr_map:
call_args.append(ctypes.c_void_p(tensor))
else:
call_args.append(tensor)
elif isinstance(tensor, float):
call_args.append(ctypes.c_float(tensor))
elif isinstance(tensor, bool):
call_args.append(ctypes.c_bool(tensor))
dtype = self.param_dtypes[i]
if dtype not in dtype_to_ctype:
raise ValueError(f"Unsupported tensor dtype: {dtype}")
call_args.append(dtype_to_ctype[dtype](tensor))
else:
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment