"vscode:/vscode.git/clone" did not exist on "190e2cdd0b55d289136a177638942e1cd1b2d457"
Unverified Commit 54c70406 authored by Andres Martinez's avatar Andres Martinez Committed by GitHub
Browse files

Merge pull request #320 from MIC-DKFZ/0001_crop_fixes

Fix segmentation cropping and feature map restore issues during inference
parents b0504dc6 a47846ee
...@@ -20,16 +20,21 @@ import numpy as np ...@@ -20,16 +20,21 @@ import numpy as np
from loguru import logger from loguru import logger
from nndet.core.boxes.ops import permute_boxes, expand_to_boxes from nndet.core.boxes.ops import permute_boxes, expand_to_boxes
from nndet.preprocessing.resampling import resample_data_or_seg, get_do_separate_z, get_lowres_axis from nndet.preprocessing.resampling import (
resample_data_or_seg,
get_do_separate_z,
def restore_detection(boxes: np.ndarray, get_lowres_axis,
transpose_backward: Sequence[int], )
original_spacing: Sequence[float],
spacing_after_resampling: Sequence[float],
crop_bbox: Sequence[Tuple[int, int]], def restore_detection(
**kwargs, boxes: np.ndarray,
) -> np.ndarray: transpose_backward: Sequence[int],
original_spacing: Sequence[float],
spacing_after_resampling: Sequence[float],
crop_bbox: Sequence[Tuple[int, int]],
**kwargs,
) -> np.ndarray:
""" """
Restore boxes from preprocessed space into original space Restore boxes from preprocessed space into original space
...@@ -61,17 +66,18 @@ def restore_detection(boxes: np.ndarray, ...@@ -61,17 +66,18 @@ def restore_detection(boxes: np.ndarray,
return boxes_original return boxes_original
def restore_fmap(fmap: np.ndarray, def restore_fmap(
transpose_backward: Sequence[int], fmap: np.ndarray,
original_spacing: Sequence[float], transpose_backward: Sequence[int],
spacing_after_resampling: Sequence[float], original_spacing: Sequence[float],
original_size_before_cropping: Sequence[int], spacing_after_resampling: Sequence[float],
size_after_cropping: Sequence[int], original_size_before_cropping: Sequence[int],
crop_bbox: Optional[Sequence[Tuple[int, int]]] = None, size_after_cropping: Sequence[int],
interpolation_order: int = 3, crop_bbox: Optional[Sequence[Tuple[int, int]]] = None,
interpolation_order_z: int = 0, interpolation_order: int = 3,
do_separate_z: bool = None, interpolation_order_z: int = 0,
) -> np.ndarray: do_separate_z: bool = None,
) -> np.ndarray:
""" """
Restore feature map from preprocessed space into original space Restore feature map from preprocessed space into original space
...@@ -101,13 +107,21 @@ def restore_fmap(fmap: np.ndarray, ...@@ -101,13 +107,21 @@ def restore_fmap(fmap: np.ndarray,
resampled_spacing = spacing_after_resampling[transpose_backward] resampled_spacing = spacing_after_resampling[transpose_backward]
if np.any([i != j for i, j in zip(fmap_transposed.shape[1:], size_after_cropping)]): if np.any([i != j for i, j in zip(fmap_transposed.shape[1:], size_after_cropping)]):
lowres_axis = _get_lowres_axes(original_spacing, resampled_spacing, lowres_axis = _get_lowres_axes(
do_separate_z=do_separate_z) original_spacing, resampled_spacing, do_separate_z=do_separate_z
logger.info(f"Resampling: do separate z: {do_separate_z}; lowres axis: {lowres_axis}") )
fmap_old_spacing = resample_data_or_seg(fmap_transposed, size_after_cropping, is_seg=False, logger.info(
axis=lowres_axis, order=interpolation_order, f"Resampling: do separate z: {do_separate_z}; lowres axis: {lowres_axis}"
do_separate_z=do_separate_z, )
order_z=interpolation_order_z) fmap_old_spacing = resample_data_or_seg(
fmap_transposed,
size_after_cropping,
is_seg=False,
axis=lowres_axis,
order=interpolation_order,
do_separate_z=do_separate_z,
order_z=interpolation_order_z,
)
else: else:
logger.info(f"Resampling: no resampling necessary") logger.info(f"Resampling: no resampling necessary")
fmap_old_spacing = fmap_transposed fmap_old_spacing = fmap_transposed
...@@ -118,19 +132,28 @@ def restore_fmap(fmap: np.ndarray, ...@@ -118,19 +132,28 @@ def restore_fmap(fmap: np.ndarray,
for c in range(len(crop_bbox)): for c in range(len(crop_bbox)):
crop_bbox[c][1] = np.min( crop_bbox[c][1] = np.min(
(crop_bbox[c][0] + fmap_old_spacing.shape[c + 1], original_size_before_cropping[c])) (
crop_bbox[c][0] + fmap_old_spacing.shape[c + 1],
original_size_before_cropping[c],
)
)
# _slices = [...] + [slice(b[0], b[1]) for b in crop_bbox]
# tmp[_slices] = fmap_old_spacing
_slices = [slice(b[0], b[1]) for b in crop_bbox]
tmp[(..., *_slices)] = fmap_old_spacing
_slices = [...] + [slice(b[0], b[1]) for b in crop_bbox]
tmp[_slices] = fmap_old_spacing
fmap_original = tmp fmap_original = tmp
else: else:
fmap_original = fmap_old_spacing fmap_original = fmap_old_spacing
return fmap_original return fmap_original
def _get_lowres_axes(original_spacing: Sequence[float], def _get_lowres_axes(
resampled_spacing: Sequence[float], original_spacing: Sequence[float],
do_separate_z: bool) -> Union[Sequence[int], None]: resampled_spacing: Sequence[float],
do_separate_z: bool,
) -> Union[Sequence[int], None]:
""" """
Dynamically determine lowres axes Dynamically determine lowres axes
......
...@@ -24,8 +24,10 @@ from skimage.measure import regionprops ...@@ -24,8 +24,10 @@ from skimage.measure import regionprops
import SimpleITK as sitk import SimpleITK as sitk
def center_crop_object_mask(mask: np.ndarray, cshape: typing.Union[tuple, int], def center_crop_object_mask(
) -> typing.List[tuple]: mask: np.ndarray,
cshape: typing.Union[tuple, int],
) -> typing.List[tuple]:
""" """
Creates indices to crop patches around individual objects in mask Creates indices to crop patches around individual objects in mask
...@@ -56,8 +58,7 @@ def center_crop_object_mask(mask: np.ndarray, cshape: typing.Union[tuple, int], ...@@ -56,8 +58,7 @@ def center_crop_object_mask(mask: np.ndarray, cshape: typing.Union[tuple, int],
cshape = tuple([cshape] * mask.ndim) cshape = tuple([cshape] * mask.ndim)
if mask.ndim != len(cshape): if mask.ndim != len(cshape):
raise TypeError("Size of crops needs to be defined for " raise TypeError("Size of crops needs to be defined for " "every dimension")
"every dimension")
if any(np.subtract(mask.shape, cshape) < 0): if any(np.subtract(mask.shape, cshape) < 0):
raise TypeError("Patches must be smaller than data.") raise TypeError("Patches must be smaller than data.")
...@@ -65,16 +66,21 @@ def center_crop_object_mask(mask: np.ndarray, cshape: typing.Union[tuple, int], ...@@ -65,16 +66,21 @@ def center_crop_object_mask(mask: np.ndarray, cshape: typing.Union[tuple, int],
# no objects in mask # no objects in mask
return [] return []
all_centroids = [i['centroid'] for i in regionprops(mask.astype(np.int32))] all_centroids = [i["centroid"] for i in regionprops(mask.astype(np.int32))]
crops = [] crops = []
for centroid in all_centroids: for centroid in all_centroids:
crops.append(tuple(slice(int(c) - (s // 2), int(c) + (s // 2)) crops.append(
for c, s in zip(centroid, cshape))) tuple(
slice(int(c) - (s // 2), int(c) + (s // 2))
for c, s in zip(centroid, cshape)
)
)
return crops return crops
def center_crop_object_seg(seg: np.ndarray, cshape: typing.Union[tuple, int], def center_crop_object_seg(
**kwargs) -> typing.List[tuple]: seg: np.ndarray, cshape: typing.Union[tuple, int], **kwargs
) -> typing.List[tuple]:
""" """
Creates indices to crop patches around individual objects in segmentation. Creates indices to crop patches around individual objects in segmentation.
Objects are determined by region growing with connected threshold. Objects are determined by region growing with connected threshold.
...@@ -124,13 +130,15 @@ def create_mask_from_seg(seg: np.ndarray) -> typing.Tuple[np.ndarray, list]: ...@@ -124,13 +130,15 @@ def create_mask_from_seg(seg: np.ndarray) -> typing.Tuple[np.ndarray, list]:
# choose one seed in segmentation # choose one seed in segmentation
seed = np.transpose(np.nonzero(_seg))[0] seed = np.transpose(np.nonzero(_seg))[0]
# invert coordinates for sitk # invert coordinates for sitk
seed_sitk = tuple(seed[:: -1].tolist()) seed_sitk = tuple(seed[::-1].tolist())
seed = tuple(seed) seed = tuple(seed)
# region growing # region growing
seg_con = sitk.ConnectedThreshold(_seg_sitk, seg_con = sitk.ConnectedThreshold(
seedList=[seed_sitk], _seg_sitk,
lower=int(_seg[seed]), seedList=[seed_sitk],
upper=int(_seg[seed])) lower=int(_seg[seed]),
upper=int(_seg[seed]),
)
seg_con = sitk.GetArrayFromImage(seg_con).astype(bool) seg_con = sitk.GetArrayFromImage(seg_con).astype(bool)
# add object to mask # add object to mask
...@@ -146,13 +154,14 @@ def create_mask_from_seg(seg: np.ndarray) -> typing.Tuple[np.ndarray, list]: ...@@ -146,13 +154,14 @@ def create_mask_from_seg(seg: np.ndarray) -> typing.Tuple[np.ndarray, list]:
return _mask, _obj_cls return _mask, _obj_cls
def create_grid(cshape: typing.Union[typing.Sequence[int], int], def create_grid(
dshape: typing.Sequence[int], cshape: typing.Union[typing.Sequence[int], int],
overlap: typing.Union[typing.Sequence[int], int] = 0, dshape: typing.Sequence[int],
mode='fixed', overlap: typing.Union[typing.Sequence[int], int] = 0,
center_boarder: bool = False, mode="fixed",
**kwargs, center_boarder: bool = False,
) -> typing.List[typing.Tuple[slice]]: **kwargs,
) -> typing.List[typing.Tuple[slice]]:
""" """
Create indices for a grid Create indices for a grid
...@@ -205,29 +214,33 @@ def create_grid(cshape: typing.Union[typing.Sequence[int], int], ...@@ -205,29 +214,33 @@ def create_grid(cshape: typing.Union[typing.Sequence[int], int],
# check shapes # check shapes
if len(cshape) != len(dshape): if len(cshape) != len(dshape):
raise TypeError( raise TypeError("cshape and dshape must be defined for same dimensionality.")
"cshape and dshape must be defined for same dimensionality.")
if len(overlap) != len(dshape): if len(overlap) != len(dshape):
raise TypeError( raise TypeError("overlap and dshape must be defined for same dimensionality.")
"overlap and dshape must be defined for same dimensionality.")
if any(np.subtract(dshape, cshape) < 0): if any(np.subtract(dshape, cshape) < 0):
axes = np.nonzero(np.subtract(dshape, cshape) < 0) axes = np.nonzero(np.subtract(dshape, cshape) < 0)
logger.warning(f"Found patch size which is bigger than data: data {dshape} patch {cshape}") logger.warning(
f"Found patch size which is bigger than data: data {dshape} patch {cshape}"
)
if any(np.subtract(cshape, overlap) < 0): if any(np.subtract(cshape, overlap) < 0):
raise TypeError("Overlap must be smaller than size of patches.") raise TypeError("Overlap must be smaller than size of patches.")
grid_slices = [_mode_fn[mode](psize, dlim, ov, **kwargs) grid_slices = [
for psize, dlim, ov in zip(cshape, dshape, overlap)] _mode_fn[mode](psize, dlim, ov, **kwargs)
for psize, dlim, ov in zip(cshape, dshape, overlap)
]
if center_boarder: if center_boarder:
for idx, (psize, dlim, ov) in enumerate(zip(cshape, dshape, overlap)): for idx, (psize, dlim, ov) in enumerate(zip(cshape, dshape, overlap)):
lower_bound_start = int(-0.5 * psize) lower_bound_start = int(-0.5 * psize)
upper_bound_start = dlim - int(0.5 * psize) upper_bound_start = dlim - int(0.5 * psize)
grid_slices[idx] = tuple([ grid_slices[idx] = tuple(
slice(lower_bound_start, lower_bound_start + psize), [
*grid_slices[idx], slice(lower_bound_start, lower_bound_start + psize),
slice(upper_bound_start, upper_bound_start + psize), *grid_slices[idx],
]) slice(upper_bound_start, upper_bound_start + psize),
]
)
if slices_3d is not None: if slices_3d is not None:
grid_slices = [tuple([slice(i, i + 1) for i in range(slices_3d)])] + grid_slices grid_slices = [tuple([slice(i, i + 1) for i in range(slices_3d)])] + grid_slices
...@@ -235,7 +248,9 @@ def create_grid(cshape: typing.Union[typing.Sequence[int], int], ...@@ -235,7 +248,9 @@ def create_grid(cshape: typing.Union[typing.Sequence[int], int],
return grid return grid
def _fixed_slices(psize: int, dlim: int, overlap: int, start: int = 0) -> typing.Tuple[slice]: def _fixed_slices(
psize: int, dlim: int, overlap: int, start: int = 0
) -> typing.Tuple[slice]:
""" """
Creates fixed slicing of a single axis. Only last patch exceeds dlim. Creates fixed slicing of a single axis. Only last patch exceeds dlim.
...@@ -286,13 +301,12 @@ def _symmetric_slices(psize: int, dlim: int, overlap: int) -> typing.Tuple[slice ...@@ -286,13 +301,12 @@ def _symmetric_slices(psize: int, dlim: int, overlap: int) -> typing.Tuple[slice
return _fixed_slices(psize, dlim, overlap, start=start) return _fixed_slices(psize, dlim, overlap, start=start)
def save_get_crop(data: np.ndarray, def save_get_crop(
crop: typing.Sequence[slice], data: np.ndarray,
mode: str = "shift", crop: typing.Sequence[slice],
**kwargs, mode: str = "shift",
) -> typing.Tuple[np.ndarray, **kwargs,
typing.Tuple[int], ) -> typing.Tuple[np.ndarray, typing.Tuple[int], typing.Tuple[slice]]:
typing.Tuple[slice]]:
""" """
Safely extract crops from data Safely extract crops from data
...@@ -318,9 +332,8 @@ def save_get_crop(data: np.ndarray, ...@@ -318,9 +332,8 @@ def save_get_crop(data: np.ndarray,
interpreted like they were outside the lower boundary! interpreted like they were outside the lower boundary!
""" """
if len(crop) > data.ndim: if len(crop) > data.ndim:
raise TypeError( raise TypeError("crop must have smaller or same dimensionality as data.")
"crop must have smaller or same dimensionality as data.") if mode == "shift":
if mode == 'shift':
# move slices if necessary # move slices if necessary
return _shifted_crop(data, crop) return _shifted_crop(data, crop)
else: else:
...@@ -328,11 +341,10 @@ def save_get_crop(data: np.ndarray, ...@@ -328,11 +341,10 @@ def save_get_crop(data: np.ndarray,
return _padded_crop(data, crop, mode, **kwargs) return _padded_crop(data, crop, mode, **kwargs)
def _shifted_crop(data: np.ndarray, def _shifted_crop(
crop: typing.Sequence[slice], data: np.ndarray,
) -> typing.Tuple[np.ndarray, crop: typing.Sequence[slice],
typing.Tuple[int], ) -> typing.Tuple[np.ndarray, typing.Tuple[int], typing.Tuple[slice]]:
typing.Tuple[slice]]:
""" """
Created shifted crops to handle borders Created shifted crops to handle borders
...@@ -366,16 +378,20 @@ def _shifted_crop(data: np.ndarray, ...@@ -366,16 +378,20 @@ def _shifted_crop(data: np.ndarray,
if new_slice.stop > dshape[axis + idx]: if new_slice.stop > dshape[axis + idx]:
raise RuntimeError( raise RuntimeError(
"Patch is bigger than entire data. shift " "Patch is bigger than entire data. shift "
"is not supported in this case.") "is not supported in this case."
)
shifted_crop.append(new_slice) shifted_crop.append(new_slice)
elif crop_dim.stop > dshape[axis + idx]: elif crop_dim.stop > dshape[axis + idx]:
new_slice = \ new_slice = slice(
slice(crop_dim.start - (crop_dim.stop - dshape[axis + idx]), crop_dim.start - (crop_dim.stop - dshape[axis + idx]),
dshape[axis + idx], crop_dim.step) dshape[axis + idx],
crop_dim.step,
)
if new_slice.start < 0: if new_slice.start < 0:
raise RuntimeError( raise RuntimeError(
"Patch is bigger than entire data. shift " "Patch is bigger than entire data. shift "
"is not supported in this case.") "is not supported in this case."
)
shifted_crop.append(new_slice) shifted_crop.append(new_slice)
else: else:
shifted_crop.append(crop_dim) shifted_crop.append(crop_dim)
...@@ -383,13 +399,12 @@ def _shifted_crop(data: np.ndarray, ...@@ -383,13 +399,12 @@ def _shifted_crop(data: np.ndarray,
return data[tuple([..., *shifted_crop])], origin, shifted_crop return data[tuple([..., *shifted_crop])], origin, shifted_crop
def _padded_crop(data: np.ndarray, def _padded_crop(
crop: typing.Sequence[slice], data: np.ndarray,
mode: str, crop: typing.Sequence[slice],
**kwargs, mode: str,
) -> typing.Tuple[np.ndarray, **kwargs,
typing.Tuple[int], ) -> typing.Tuple[np.ndarray, typing.Tuple[int], typing.Tuple[slice]]:
typing.Tuple[slice]]:
""" """
Extract patch from data and pad accordingly Extract patch from data and pad accordingly
...@@ -429,8 +444,14 @@ def _padded_crop(data: np.ndarray, ...@@ -429,8 +444,14 @@ def _padded_crop(data: np.ndarray,
padding.append((lower_pad, upper_pad)) padding.append((lower_pad, upper_pad))
clipped_crop.append(slice(lower_bound, upper_bound, crop_dim.step)) clipped_crop.append(slice(lower_bound, upper_bound, crop_dim.step))
origin = [int(x.start) for x in crop] origin = [int(x.start) for x in crop]
return (np.pad(data[tuple([..., *clipped_crop])], pad_width=padding, mode=mode, **kwargs), # return (np.pad(data[tuple([..., *clipped_crop])], pad_width=padding, mode=mode, **kwargs),
origin, # origin,
clipped_crop, # clipped_crop,
) # )
return (
np.pad(
data[tuple([..., *clipped_crop])], pad_width=padding, mode=mode, **kwargs
),
origin,
crop,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment