Commit 3381ed43 authored by lucasb-eyer's avatar lucasb-eyer
Browse files

Check for valid data in `addPairwiseBilateral` and bail.

parent a9d7be19
...@@ -101,6 +101,10 @@ d.addPairwiseGaussian(sxy=3, compat=3) ...@@ -101,6 +101,10 @@ d.addPairwiseGaussian(sxy=3, compat=3)
d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=im, compat=10) d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=im, compat=10)
``` ```
An important caveat is that `addPairwiseBilateral` only works for RGB images, i.e. three channels.
If your data is of different type than this simple but common case, you'll need to compute your
own pairwise energy using `utils.create_pairwise_bilateral`; see the [generic non-2D case](https://github.com/lucasb-eyer/pydensecrf#generic-non-2d) for details.
### Compatibilities ### Compatibilities
The `compat` argument can be any of the following: The `compat` argument can be any of the following:
......
...@@ -90,3 +90,5 @@ cdef class DenseCRF: ...@@ -90,3 +90,5 @@ cdef class DenseCRF:
cdef class DenseCRF2D(DenseCRF): cdef class DenseCRF2D(DenseCRF):
cdef c_DenseCRF2D *_this2d cdef c_DenseCRF2D *_this2d
cdef int _w
cdef int _h
...@@ -96,6 +96,11 @@ cdef class DenseCRF2D(DenseCRF): ...@@ -96,6 +96,11 @@ cdef class DenseCRF2D(DenseCRF):
if type(self) is DenseCRF2D: if type(self) is DenseCRF2D:
self._this = self._this2d = new c_DenseCRF2D(w, h, nlabels) self._this = self._this2d = new c_DenseCRF2D(w, h, nlabels)
# Unfortunately, self._this2d.W_ and .H_ are protected in C++ and thus
# we cannot access them from here for sanity-checks, so keep our own...
self._w = w
self._h = h
def addPairwiseGaussian(self, sxy, compat, KernelType kernel=DIAG_KERNEL, NormalizationType normalization=NORMALIZE_SYMMETRIC): def addPairwiseGaussian(self, sxy, compat, KernelType kernel=DIAG_KERNEL, NormalizationType normalization=NORMALIZE_SYMMETRIC):
if isinstance(sxy, Number): if isinstance(sxy, Number):
sxy = (sxy, sxy) sxy = (sxy, sxy)
...@@ -109,6 +114,11 @@ cdef class DenseCRF2D(DenseCRF): ...@@ -109,6 +114,11 @@ cdef class DenseCRF2D(DenseCRF):
if isinstance(srgb, Number): if isinstance(srgb, Number):
srgb = (srgb, srgb, srgb) srgb = (srgb, srgb, srgb)
if rgbim.shape[0] != self._h or rgbim.shape[1] != self._w:
raise ValueError("Bad shape for pairwise bilateral (Need {}, got {})".format((self._h, self._w, 3), rgbim.shape))
if rgbim.shape[2] != 3:
raise ValueError("addPairwiseBilateral only works for RGB images. For other types, use `utils.create_pairwise_bilateral` to construct your own pairwise energy and add it through `addPairwiseEnergy`.")
self._this2d.addPairwiseBilateral( self._this2d.addPairwiseBilateral(
sxy[0], sxy[1], srgb[0], srgb[1], srgb[2], &rgbim[0,0,0], _labelcomp(compat), kernel, normalization sxy[0], sxy[1], srgb[0], srgb[1], srgb[2], &rgbim[0,0,0], _labelcomp(compat), kernel, normalization
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment