Unverified Commit fff24aee authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Add stride index validation in CythonKernelWrapper (#743)

* Introduced an assertion to ensure that the stride index is within the valid range of tensor dimensions in `cython_wrapper.pyx`.
* This change prevents potential out-of-bounds errors when accessing tensor dimensions, enhancing the robustness of the code.
parent 72be4909
...@@ -123,6 +123,11 @@ cdef class CythonKernelWrapper: ...@@ -123,6 +123,11 @@ cdef class CythonKernelWrapper:
# otherwise, maybe torch.data_ptr() for T.ptr inputs # otherwise, maybe torch.data_ptr() for T.ptr inputs
continue continue
for stride_idx, expected_stride in strides_list: for stride_idx, expected_stride in strides_list:
# Ensure the stride index is within the valid range of tensor dimensions
# (stride_idx should be less than the number of dimensions of the tensor)
assert stride_idx < tensor.dim(), f"Stride index {stride_idx} out of bounds for tensor with {tensor.dim()} dimensions"
if tensor.shape[stride_idx] == 1:
continue
actual_stride = tensor.stride(stride_idx) actual_stride = tensor.stride(stride_idx)
if actual_stride != expected_stride: if actual_stride != expected_stride:
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment