"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "01d155c96943acf7e3dec4a5ace7cc6704a02b27"
Commit 9096c7bf authored by lucasb-eyer's avatar lucasb-eyer
Browse files

Fix segfault on wrong compat dtype.

The exception is now correctly propagated instead.
parent 1c40e00f
...@@ -8,7 +8,7 @@ import eigen ...@@ -8,7 +8,7 @@ import eigen
cimport eigen cimport eigen
cdef LabelCompatibility* _labelcomp(compat): cdef LabelCompatibility* _labelcomp(compat) except NULL:
if isinstance(compat, Number): if isinstance(compat, Number):
return new PottsCompatibility(compat) return new PottsCompatibility(compat)
elif memoryview(compat).ndim == 1: elif memoryview(compat).ndim == 1:
...@@ -17,6 +17,7 @@ cdef LabelCompatibility* _labelcomp(compat): ...@@ -17,6 +17,7 @@ cdef LabelCompatibility* _labelcomp(compat):
return new MatrixCompatibility(eigen.c_matrixXf(compat)) return new MatrixCompatibility(eigen.c_matrixXf(compat))
else: else:
raise ValueError("LabelCompatibility of dimension >2 not meaningful.") raise ValueError("LabelCompatibility of dimension >2 not meaningful.")
return NULL # Important for the exception(s) to propagate!
cdef class Unary: cdef class Unary:
......
import numpy as np import numpy as np
import densecrf as dcrf import pydensecrf.densecrf as dcrf
# TODO: Make this real unit-tests some time in the future...
# Tests for specific issues
###########################
# Via e-mail: crash when non-float32 compat
d = dcrf.DenseCRF2D(10,10,2)
d.setUnaryEnergy(np.ones((2,10*10), dtype=np.float32))
compat = np.array([1.0, 2.0])
try:
d.addPairwiseBilateral(sxy=(3,3), srgb=(3,3,3), rgbim=np.zeros((10,10,3), np.uint8), compat=compat)
d.inference(2)
raise TypeError("Didn't raise an exception, but should because compat dtypes don't match!!")
except ValueError:
pass # That's what we want!
# The following is not a really good unittest, but was the first tests.
###########################
# d = densecrf.PyDenseCRF2D(3, 2, 3) # d = densecrf.PyDenseCRF2D(3, 2, 3)
# U = np.full((3,6), 0.1, dtype=np.float32) # U = np.full((3,6), 0.1, dtype=np.float32)
...@@ -25,3 +45,4 @@ d.setUnaryEnergy(-np.log(Up)) ...@@ -25,3 +45,4 @@ d.setUnaryEnergy(-np.log(Up))
d.addPairwiseBilateral(2, 2, img, 3) d.addPairwiseBilateral(2, 2, img, 3)
# d.addPairwiseBilateral(2, 2, img, 3) # d.addPairwiseBilateral(2, 2, img, 3)
np.argmax(d.inference(10), axis=0).reshape(10,10) np.argmax(d.inference(10), axis=0).reshape(10,10)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment