"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "3feffddd1e8381209a48c24587e36e030051499f"
Commit a47846ee authored by a870a's avatar a870a
Browse files

Fix segmentation cropping and feature map restore issues during inference

parent b0504dc6
......@@ -20,16 +20,21 @@ import numpy as np
from loguru import logger
from nndet.core.boxes.ops import permute_boxes, expand_to_boxes
from nndet.preprocessing.resampling import resample_data_or_seg, get_do_separate_z, get_lowres_axis
def restore_detection(boxes: np.ndarray,
transpose_backward: Sequence[int],
original_spacing: Sequence[float],
spacing_after_resampling: Sequence[float],
crop_bbox: Sequence[Tuple[int, int]],
**kwargs,
) -> np.ndarray:
from nndet.preprocessing.resampling import (
resample_data_or_seg,
get_do_separate_z,
get_lowres_axis,
)
def restore_detection(
boxes: np.ndarray,
transpose_backward: Sequence[int],
original_spacing: Sequence[float],
spacing_after_resampling: Sequence[float],
crop_bbox: Sequence[Tuple[int, int]],
**kwargs,
) -> np.ndarray:
"""
Restore boxes from preprocessed space into original space
......@@ -61,17 +66,18 @@ def restore_detection(boxes: np.ndarray,
return boxes_original
def restore_fmap(fmap: np.ndarray,
transpose_backward: Sequence[int],
original_spacing: Sequence[float],
spacing_after_resampling: Sequence[float],
original_size_before_cropping: Sequence[int],
size_after_cropping: Sequence[int],
crop_bbox: Optional[Sequence[Tuple[int, int]]] = None,
interpolation_order: int = 3,
interpolation_order_z: int = 0,
do_separate_z: bool = None,
) -> np.ndarray:
def restore_fmap(
fmap: np.ndarray,
transpose_backward: Sequence[int],
original_spacing: Sequence[float],
spacing_after_resampling: Sequence[float],
original_size_before_cropping: Sequence[int],
size_after_cropping: Sequence[int],
crop_bbox: Optional[Sequence[Tuple[int, int]]] = None,
interpolation_order: int = 3,
interpolation_order_z: int = 0,
do_separate_z: bool = None,
) -> np.ndarray:
"""
Restore feature map from preprocessed space into original space
......@@ -101,13 +107,21 @@ def restore_fmap(fmap: np.ndarray,
resampled_spacing = spacing_after_resampling[transpose_backward]
if np.any([i != j for i, j in zip(fmap_transposed.shape[1:], size_after_cropping)]):
lowres_axis = _get_lowres_axes(original_spacing, resampled_spacing,
do_separate_z=do_separate_z)
logger.info(f"Resampling: do separate z: {do_separate_z}; lowres axis: {lowres_axis}")
fmap_old_spacing = resample_data_or_seg(fmap_transposed, size_after_cropping, is_seg=False,
axis=lowres_axis, order=interpolation_order,
do_separate_z=do_separate_z,
order_z=interpolation_order_z)
lowres_axis = _get_lowres_axes(
original_spacing, resampled_spacing, do_separate_z=do_separate_z
)
logger.info(
f"Resampling: do separate z: {do_separate_z}; lowres axis: {lowres_axis}"
)
fmap_old_spacing = resample_data_or_seg(
fmap_transposed,
size_after_cropping,
is_seg=False,
axis=lowres_axis,
order=interpolation_order,
do_separate_z=do_separate_z,
order_z=interpolation_order_z,
)
else:
logger.info(f"Resampling: no resampling necessary")
fmap_old_spacing = fmap_transposed
......@@ -118,19 +132,28 @@ def restore_fmap(fmap: np.ndarray,
for c in range(len(crop_bbox)):
crop_bbox[c][1] = np.min(
(crop_bbox[c][0] + fmap_old_spacing.shape[c + 1], original_size_before_cropping[c]))
(
crop_bbox[c][0] + fmap_old_spacing.shape[c + 1],
original_size_before_cropping[c],
)
)
# _slices = [...] + [slice(b[0], b[1]) for b in crop_bbox]
# tmp[_slices] = fmap_old_spacing
_slices = [slice(b[0], b[1]) for b in crop_bbox]
tmp[(..., *_slices)] = fmap_old_spacing
_slices = [...] + [slice(b[0], b[1]) for b in crop_bbox]
tmp[_slices] = fmap_old_spacing
fmap_original = tmp
else:
fmap_original = fmap_old_spacing
return fmap_original
def _get_lowres_axes(original_spacing: Sequence[float],
resampled_spacing: Sequence[float],
do_separate_z: bool) -> Union[Sequence[int], None]:
def _get_lowres_axes(
original_spacing: Sequence[float],
resampled_spacing: Sequence[float],
do_separate_z: bool,
) -> Union[Sequence[int], None]:
"""
Dynamically determine lowres axes
......
......@@ -24,8 +24,10 @@ from skimage.measure import regionprops
import SimpleITK as sitk
def center_crop_object_mask(mask: np.ndarray, cshape: typing.Union[tuple, int],
) -> typing.List[tuple]:
def center_crop_object_mask(
mask: np.ndarray,
cshape: typing.Union[tuple, int],
) -> typing.List[tuple]:
"""
Creates indices to crop patches around individual objects in mask
......@@ -56,8 +58,7 @@ def center_crop_object_mask(mask: np.ndarray, cshape: typing.Union[tuple, int],
cshape = tuple([cshape] * mask.ndim)
if mask.ndim != len(cshape):
raise TypeError("Size of crops needs to be defined for "
"every dimension")
raise TypeError("Size of crops needs to be defined for " "every dimension")
if any(np.subtract(mask.shape, cshape) < 0):
raise TypeError("Patches must be smaller than data.")
......@@ -65,16 +66,21 @@ def center_crop_object_mask(mask: np.ndarray, cshape: typing.Union[tuple, int],
# no objects in mask
return []
all_centroids = [i['centroid'] for i in regionprops(mask.astype(np.int32))]
all_centroids = [i["centroid"] for i in regionprops(mask.astype(np.int32))]
crops = []
for centroid in all_centroids:
crops.append(tuple(slice(int(c) - (s // 2), int(c) + (s // 2))
for c, s in zip(centroid, cshape)))
crops.append(
tuple(
slice(int(c) - (s // 2), int(c) + (s // 2))
for c, s in zip(centroid, cshape)
)
)
return crops
def center_crop_object_seg(seg: np.ndarray, cshape: typing.Union[tuple, int],
**kwargs) -> typing.List[tuple]:
def center_crop_object_seg(
seg: np.ndarray, cshape: typing.Union[tuple, int], **kwargs
) -> typing.List[tuple]:
"""
Creates indices to crop patches around individual objects in segmentation.
Objects are determined by region growing with connected threshold.
......@@ -124,13 +130,15 @@ def create_mask_from_seg(seg: np.ndarray) -> typing.Tuple[np.ndarray, list]:
# choose one seed in segmentation
seed = np.transpose(np.nonzero(_seg))[0]
# invert coordinates for sitk
seed_sitk = tuple(seed[:: -1].tolist())
seed_sitk = tuple(seed[::-1].tolist())
seed = tuple(seed)
# region growing
seg_con = sitk.ConnectedThreshold(_seg_sitk,
seedList=[seed_sitk],
lower=int(_seg[seed]),
upper=int(_seg[seed]))
seg_con = sitk.ConnectedThreshold(
_seg_sitk,
seedList=[seed_sitk],
lower=int(_seg[seed]),
upper=int(_seg[seed]),
)
seg_con = sitk.GetArrayFromImage(seg_con).astype(bool)
# add object to mask
......@@ -146,13 +154,14 @@ def create_mask_from_seg(seg: np.ndarray) -> typing.Tuple[np.ndarray, list]:
return _mask, _obj_cls
def create_grid(cshape: typing.Union[typing.Sequence[int], int],
dshape: typing.Sequence[int],
overlap: typing.Union[typing.Sequence[int], int] = 0,
mode='fixed',
center_boarder: bool = False,
**kwargs,
) -> typing.List[typing.Tuple[slice]]:
def create_grid(
cshape: typing.Union[typing.Sequence[int], int],
dshape: typing.Sequence[int],
overlap: typing.Union[typing.Sequence[int], int] = 0,
mode="fixed",
center_boarder: bool = False,
**kwargs,
) -> typing.List[typing.Tuple[slice]]:
"""
Create indices for a grid
......@@ -205,29 +214,33 @@ def create_grid(cshape: typing.Union[typing.Sequence[int], int],
# check shapes
if len(cshape) != len(dshape):
raise TypeError(
"cshape and dshape must be defined for same dimensionality.")
raise TypeError("cshape and dshape must be defined for same dimensionality.")
if len(overlap) != len(dshape):
raise TypeError(
"overlap and dshape must be defined for same dimensionality.")
raise TypeError("overlap and dshape must be defined for same dimensionality.")
if any(np.subtract(dshape, cshape) < 0):
axes = np.nonzero(np.subtract(dshape, cshape) < 0)
logger.warning(f"Found patch size which is bigger than data: data {dshape} patch {cshape}")
logger.warning(
f"Found patch size which is bigger than data: data {dshape} patch {cshape}"
)
if any(np.subtract(cshape, overlap) < 0):
raise TypeError("Overlap must be smaller than size of patches.")
grid_slices = [_mode_fn[mode](psize, dlim, ov, **kwargs)
for psize, dlim, ov in zip(cshape, dshape, overlap)]
grid_slices = [
_mode_fn[mode](psize, dlim, ov, **kwargs)
for psize, dlim, ov in zip(cshape, dshape, overlap)
]
if center_boarder:
for idx, (psize, dlim, ov) in enumerate(zip(cshape, dshape, overlap)):
lower_bound_start = int(-0.5 * psize)
upper_bound_start = dlim - int(0.5 * psize)
grid_slices[idx] = tuple([
slice(lower_bound_start, lower_bound_start + psize),
*grid_slices[idx],
slice(upper_bound_start, upper_bound_start + psize),
])
grid_slices[idx] = tuple(
[
slice(lower_bound_start, lower_bound_start + psize),
*grid_slices[idx],
slice(upper_bound_start, upper_bound_start + psize),
]
)
if slices_3d is not None:
grid_slices = [tuple([slice(i, i + 1) for i in range(slices_3d)])] + grid_slices
......@@ -235,7 +248,9 @@ def create_grid(cshape: typing.Union[typing.Sequence[int], int],
return grid
def _fixed_slices(psize: int, dlim: int, overlap: int, start: int = 0) -> typing.Tuple[slice]:
def _fixed_slices(
psize: int, dlim: int, overlap: int, start: int = 0
) -> typing.Tuple[slice]:
"""
Creates fixed slicing of a single axis. Only last patch exceeds dlim.
......@@ -286,13 +301,12 @@ def _symmetric_slices(psize: int, dlim: int, overlap: int) -> typing.Tuple[slice
return _fixed_slices(psize, dlim, overlap, start=start)
def save_get_crop(data: np.ndarray,
crop: typing.Sequence[slice],
mode: str = "shift",
**kwargs,
) -> typing.Tuple[np.ndarray,
typing.Tuple[int],
typing.Tuple[slice]]:
def save_get_crop(
data: np.ndarray,
crop: typing.Sequence[slice],
mode: str = "shift",
**kwargs,
) -> typing.Tuple[np.ndarray, typing.Tuple[int], typing.Tuple[slice]]:
"""
Safely extract crops from data
......@@ -318,9 +332,8 @@ def save_get_crop(data: np.ndarray,
interpreted like they were outside the lower boundary!
"""
if len(crop) > data.ndim:
raise TypeError(
"crop must have smaller or same dimensionality as data.")
if mode == 'shift':
raise TypeError("crop must have smaller or same dimensionality as data.")
if mode == "shift":
# move slices if necessary
return _shifted_crop(data, crop)
else:
......@@ -328,11 +341,10 @@ def save_get_crop(data: np.ndarray,
return _padded_crop(data, crop, mode, **kwargs)
def _shifted_crop(data: np.ndarray,
crop: typing.Sequence[slice],
) -> typing.Tuple[np.ndarray,
typing.Tuple[int],
typing.Tuple[slice]]:
def _shifted_crop(
data: np.ndarray,
crop: typing.Sequence[slice],
) -> typing.Tuple[np.ndarray, typing.Tuple[int], typing.Tuple[slice]]:
"""
Created shifted crops to handle borders
......@@ -366,16 +378,20 @@ def _shifted_crop(data: np.ndarray,
if new_slice.stop > dshape[axis + idx]:
raise RuntimeError(
"Patch is bigger than entire data. shift "
"is not supported in this case.")
"is not supported in this case."
)
shifted_crop.append(new_slice)
elif crop_dim.stop > dshape[axis + idx]:
new_slice = \
slice(crop_dim.start - (crop_dim.stop - dshape[axis + idx]),
dshape[axis + idx], crop_dim.step)
new_slice = slice(
crop_dim.start - (crop_dim.stop - dshape[axis + idx]),
dshape[axis + idx],
crop_dim.step,
)
if new_slice.start < 0:
raise RuntimeError(
"Patch is bigger than entire data. shift "
"is not supported in this case.")
"is not supported in this case."
)
shifted_crop.append(new_slice)
else:
shifted_crop.append(crop_dim)
......@@ -383,13 +399,12 @@ def _shifted_crop(data: np.ndarray,
return data[tuple([..., *shifted_crop])], origin, shifted_crop
def _padded_crop(data: np.ndarray,
crop: typing.Sequence[slice],
mode: str,
**kwargs,
) -> typing.Tuple[np.ndarray,
typing.Tuple[int],
typing.Tuple[slice]]:
def _padded_crop(
data: np.ndarray,
crop: typing.Sequence[slice],
mode: str,
**kwargs,
) -> typing.Tuple[np.ndarray, typing.Tuple[int], typing.Tuple[slice]]:
"""
Extract patch from data and pad accordingly
......@@ -429,8 +444,14 @@ def _padded_crop(data: np.ndarray,
padding.append((lower_pad, upper_pad))
clipped_crop.append(slice(lower_bound, upper_bound, crop_dim.step))
origin = [int(x.start) for x in crop]
return (np.pad(data[tuple([..., *clipped_crop])], pad_width=padding, mode=mode, **kwargs),
origin,
clipped_crop,
)
# return (np.pad(data[tuple([..., *clipped_crop])], pad_width=padding, mode=mode, **kwargs),
# origin,
# clipped_crop,
# )
return (
np.pad(
data[tuple([..., *clipped_crop])], pad_width=padding, mode=mode, **kwargs
),
origin,
crop,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment