"src/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "ede9eaa3c493d42eab8cb3e749fcf27e9ace60d2"
Commit 44508e59 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Update dtype handling in KernelParam and CythonKernelWrapper (#582)

- Modified `KernelParam.from_var` to convert the Torch data type via `map_torch_type` before constructing the scalar parameter.
- Enhanced `CythonKernelWrapper` to accept `int`, `float`, and `bool` scalar arguments and convert each one to the C type matching its declared parameter dtype, raising a clear error for unsupported dtypes.
parent 05fc9cd5
@@ -54,7 +54,8 @@ class KernelParam:
         Returns:
             KernelParam instance representing a scalar (empty shape)
         """
-        return cls(var.dtype, [])
+        dtype = map_torch_type(var.dtype)
+        return cls(dtype, [])
 
     def is_scalar(self) -> bool:
         """
...
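For context, here is a minimal sketch of the scalar path this hunk changes, assuming `KernelParam` is essentially a (dtype, shape) pair and that `map_torch_type` normalizes a Torch dtype before it is stored. The mapping table, helper, and class below are illustrative stand-ins, not the project's actual implementation, and the `var` is simply any object with a Torch `.dtype` attribute:

```python
from dataclasses import dataclass
from typing import List

import torch

# Illustrative stand-in for map_torch_type: normalize a torch dtype to the
# representation the parameter stores (assumed here to be a plain string).
_TORCH_TO_NAME = {
    torch.float16: "float16",
    torch.float32: "float32",
    torch.int32: "int32",
    torch.int64: "int64",
}


def map_torch_type_sketch(dtype):
    return _TORCH_TO_NAME.get(dtype, str(dtype))


@dataclass
class KernelParamSketch:
    dtype: str
    shape: List[int]

    @classmethod
    def from_var(cls, var):
        # Scalars carry an empty shape; the dtype is normalized first,
        # mirroring the change in the hunk above.
        return cls(map_torch_type_sketch(var.dtype), [])


param = KernelParamSketch.from_var(torch.tensor(3.0, dtype=torch.float16))
print(param)  # KernelParamSketch(dtype='float16', shape=[])
```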
@@ -134,25 +134,32 @@ cdef class CythonKernelWrapper:
                 tensor = inputs[ins_idx]
                 ins_idx += 1
             tensor_list.append(tensor)
         # Convert tensor pointers to C void pointers for kernel call
+        cdef dict dtype_to_ctype = {
+            torch.float16: ctypes.c_float,
+            torch.float32: ctypes.c_float,
+            torch.float64: ctypes.c_double,
+            torch.int8: ctypes.c_int8,
+            torch.int16: ctypes.c_int16,
+            torch.int32: ctypes.c_int32,
+            torch.int64: ctypes.c_int64,
+        }
         call_args = []
-        for i in range(len(tensor_list)):
-            tensor = tensor_list[i]
+        for i, tensor in enumerate(tensor_list):
             if isinstance(tensor, torch.Tensor):
                 if not skip_tensor_validation and not tensor.is_contiguous():
                     raise ValueError(f"Input tensor at index {i} must be contiguous")
                 call_args.append(ctypes.c_void_p(tensor.data_ptr()))
-            elif isinstance(tensor, int):
+            elif isinstance(tensor, (int, float, bool)):
+                # Dynamic symbolics which are passed as integer arguments
                 if i in self.ptr_map:
                     call_args.append(ctypes.c_void_p(tensor))
                 else:
-                    call_args.append(tensor)
-            elif isinstance(tensor, float):
-                call_args.append(ctypes.c_float(tensor))
-            elif isinstance(tensor, bool):
-                call_args.append(ctypes.c_bool(tensor))
+                    dtype = self.param_dtypes[i]
+                    if dtype not in dtype_to_ctype:
+                        raise ValueError(f"Unsupported tensor dtype: {dtype}")
+                    call_args.append(dtype_to_ctype[dtype](tensor))
             else:
                 raise ValueError(f"Unsupported tensor type: {type(tensor)}")
...
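Below is a minimal pure-Python sketch of the scalar-conversion rule the new loop implements. `param_dtypes` and `ptr_map` are stand-ins for the wrapper's per-parameter metadata, and the table mirrors the `dtype_to_ctype` dict in the hunk above:

```python
import ctypes

import torch

# Mirrors the dtype_to_ctype table from the diff above.
DTYPE_TO_CTYPE = {
    torch.float16: ctypes.c_float,
    torch.float32: ctypes.c_float,
    torch.float64: ctypes.c_double,
    torch.int8: ctypes.c_int8,
    torch.int16: ctypes.c_int16,
    torch.int32: ctypes.c_int32,
    torch.int64: ctypes.c_int64,
}


def convert_scalar_arg(value, index, param_dtypes, ptr_map):
    """Box a Python scalar the way the wrapper does before the kernel call."""
    if index in ptr_map:
        # Dynamic-shape symbols are passed through as raw addresses.
        return ctypes.c_void_p(value)
    dtype = param_dtypes[index]
    if dtype not in DTYPE_TO_CTYPE:
        raise ValueError(f"Unsupported tensor dtype: {dtype}")
    return DTYPE_TO_CTYPE[dtype](value)


# Example: a float argument bound to a float16 parameter is boxed as c_float,
# since ctypes has no native half-precision type.
arg = convert_scalar_arg(1.5, 0, {0: torch.float16}, ptr_map={})
print(type(arg).__name__, arg.value)  # c_float 1.5
```

Note the float16 entry maps to `ctypes.c_float`: half-precision scalars are widened to 32-bit floats at the call boundary because ctypes offers no half type.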