Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yaoyuping
nnDetection
Commits
a47846ee
"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "3feffddd1e8381209a48c24587e36e030051499f"
Commit
a47846ee
authored
Jun 26, 2025
by
a870a
Browse files
Fix segmentation cropping and feature map restore issues during inference
parent
b0504dc6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
144 additions
and
100 deletions
+144
-100
nndet/inference/restore.py
nndet/inference/restore.py
+57
-34
nndet/io/patching.py
nndet/io/patching.py
+87
-66
No files found.
nndet/inference/restore.py
View file @
a47846ee
...
...
@@ -20,16 +20,21 @@ import numpy as np
from
loguru
import
logger
from
nndet.core.boxes.ops
import
permute_boxes
,
expand_to_boxes
from
nndet.preprocessing.resampling
import
resample_data_or_seg
,
get_do_separate_z
,
get_lowres_axis
def
restore_detection
(
boxes
:
np
.
ndarray
,
transpose_backward
:
Sequence
[
int
],
original_spacing
:
Sequence
[
float
],
spacing_after_resampling
:
Sequence
[
float
],
crop_bbox
:
Sequence
[
Tuple
[
int
,
int
]],
**
kwargs
,
)
->
np
.
ndarray
:
from
nndet.preprocessing.resampling
import
(
resample_data_or_seg
,
get_do_separate_z
,
get_lowres_axis
,
)
def
restore_detection
(
boxes
:
np
.
ndarray
,
transpose_backward
:
Sequence
[
int
],
original_spacing
:
Sequence
[
float
],
spacing_after_resampling
:
Sequence
[
float
],
crop_bbox
:
Sequence
[
Tuple
[
int
,
int
]],
**
kwargs
,
)
->
np
.
ndarray
:
"""
Restore boxes from preprocessed space into original space
...
...
@@ -61,17 +66,18 @@ def restore_detection(boxes: np.ndarray,
return
boxes_original
def
restore_fmap
(
fmap
:
np
.
ndarray
,
transpose_backward
:
Sequence
[
int
],
original_spacing
:
Sequence
[
float
],
spacing_after_resampling
:
Sequence
[
float
],
original_size_before_cropping
:
Sequence
[
int
],
size_after_cropping
:
Sequence
[
int
],
crop_bbox
:
Optional
[
Sequence
[
Tuple
[
int
,
int
]]]
=
None
,
interpolation_order
:
int
=
3
,
interpolation_order_z
:
int
=
0
,
do_separate_z
:
bool
=
None
,
)
->
np
.
ndarray
:
def
restore_fmap
(
fmap
:
np
.
ndarray
,
transpose_backward
:
Sequence
[
int
],
original_spacing
:
Sequence
[
float
],
spacing_after_resampling
:
Sequence
[
float
],
original_size_before_cropping
:
Sequence
[
int
],
size_after_cropping
:
Sequence
[
int
],
crop_bbox
:
Optional
[
Sequence
[
Tuple
[
int
,
int
]]]
=
None
,
interpolation_order
:
int
=
3
,
interpolation_order_z
:
int
=
0
,
do_separate_z
:
bool
=
None
,
)
->
np
.
ndarray
:
"""
Restore feature map from preprocessed space into original space
...
...
@@ -101,13 +107,21 @@ def restore_fmap(fmap: np.ndarray,
resampled_spacing
=
spacing_after_resampling
[
transpose_backward
]
if
np
.
any
([
i
!=
j
for
i
,
j
in
zip
(
fmap_transposed
.
shape
[
1
:],
size_after_cropping
)]):
lowres_axis
=
_get_lowres_axes
(
original_spacing
,
resampled_spacing
,
do_separate_z
=
do_separate_z
)
logger
.
info
(
f
"Resampling: do separate z:
{
do_separate_z
}
; lowres axis:
{
lowres_axis
}
"
)
fmap_old_spacing
=
resample_data_or_seg
(
fmap_transposed
,
size_after_cropping
,
is_seg
=
False
,
axis
=
lowres_axis
,
order
=
interpolation_order
,
do_separate_z
=
do_separate_z
,
order_z
=
interpolation_order_z
)
lowres_axis
=
_get_lowres_axes
(
original_spacing
,
resampled_spacing
,
do_separate_z
=
do_separate_z
)
logger
.
info
(
f
"Resampling: do separate z:
{
do_separate_z
}
; lowres axis:
{
lowres_axis
}
"
)
fmap_old_spacing
=
resample_data_or_seg
(
fmap_transposed
,
size_after_cropping
,
is_seg
=
False
,
axis
=
lowres_axis
,
order
=
interpolation_order
,
do_separate_z
=
do_separate_z
,
order_z
=
interpolation_order_z
,
)
else
:
logger
.
info
(
f
"Resampling: no resampling necessary"
)
fmap_old_spacing
=
fmap_transposed
...
...
@@ -118,19 +132,28 @@ def restore_fmap(fmap: np.ndarray,
for
c
in
range
(
len
(
crop_bbox
)):
crop_bbox
[
c
][
1
]
=
np
.
min
(
(
crop_bbox
[
c
][
0
]
+
fmap_old_spacing
.
shape
[
c
+
1
],
original_size_before_cropping
[
c
]))
(
crop_bbox
[
c
][
0
]
+
fmap_old_spacing
.
shape
[
c
+
1
],
original_size_before_cropping
[
c
],
)
)
# _slices = [...] + [slice(b[0], b[1]) for b in crop_bbox]
# tmp[_slices] = fmap_old_spacing
_slices
=
[
slice
(
b
[
0
],
b
[
1
])
for
b
in
crop_bbox
]
tmp
[(...,
*
_slices
)]
=
fmap_old_spacing
_slices
=
[...]
+
[
slice
(
b
[
0
],
b
[
1
])
for
b
in
crop_bbox
]
tmp
[
_slices
]
=
fmap_old_spacing
fmap_original
=
tmp
else
:
fmap_original
=
fmap_old_spacing
return
fmap_original
def
_get_lowres_axes
(
original_spacing
:
Sequence
[
float
],
resampled_spacing
:
Sequence
[
float
],
do_separate_z
:
bool
)
->
Union
[
Sequence
[
int
],
None
]:
def
_get_lowres_axes
(
original_spacing
:
Sequence
[
float
],
resampled_spacing
:
Sequence
[
float
],
do_separate_z
:
bool
,
)
->
Union
[
Sequence
[
int
],
None
]:
"""
Dynamically determine lowres axes
...
...
nndet/io/patching.py
View file @
a47846ee
...
...
@@ -24,8 +24,10 @@ from skimage.measure import regionprops
import
SimpleITK
as
sitk
def
center_crop_object_mask
(
mask
:
np
.
ndarray
,
cshape
:
typing
.
Union
[
tuple
,
int
],
)
->
typing
.
List
[
tuple
]:
def
center_crop_object_mask
(
mask
:
np
.
ndarray
,
cshape
:
typing
.
Union
[
tuple
,
int
],
)
->
typing
.
List
[
tuple
]:
"""
Creates indices to crop patches around individual objects in mask
...
...
@@ -56,8 +58,7 @@ def center_crop_object_mask(mask: np.ndarray, cshape: typing.Union[tuple, int],
cshape
=
tuple
([
cshape
]
*
mask
.
ndim
)
if
mask
.
ndim
!=
len
(
cshape
):
raise
TypeError
(
"Size of crops needs to be defined for "
"every dimension"
)
raise
TypeError
(
"Size of crops needs to be defined for "
"every dimension"
)
if
any
(
np
.
subtract
(
mask
.
shape
,
cshape
)
<
0
):
raise
TypeError
(
"Patches must be smaller than data."
)
...
...
@@ -65,16 +66,21 @@ def center_crop_object_mask(mask: np.ndarray, cshape: typing.Union[tuple, int],
# no objects in mask
return
[]
all_centroids
=
[
i
[
'
centroid
'
]
for
i
in
regionprops
(
mask
.
astype
(
np
.
int32
))]
all_centroids
=
[
i
[
"
centroid
"
]
for
i
in
regionprops
(
mask
.
astype
(
np
.
int32
))]
crops
=
[]
for
centroid
in
all_centroids
:
crops
.
append
(
tuple
(
slice
(
int
(
c
)
-
(
s
//
2
),
int
(
c
)
+
(
s
//
2
))
for
c
,
s
in
zip
(
centroid
,
cshape
)))
crops
.
append
(
tuple
(
slice
(
int
(
c
)
-
(
s
//
2
),
int
(
c
)
+
(
s
//
2
))
for
c
,
s
in
zip
(
centroid
,
cshape
)
)
)
return
crops
def
center_crop_object_seg
(
seg
:
np
.
ndarray
,
cshape
:
typing
.
Union
[
tuple
,
int
],
**
kwargs
)
->
typing
.
List
[
tuple
]:
def
center_crop_object_seg
(
seg
:
np
.
ndarray
,
cshape
:
typing
.
Union
[
tuple
,
int
],
**
kwargs
)
->
typing
.
List
[
tuple
]:
"""
Creates indices to crop patches around individual objects in segmentation.
Objects are determined by region growing with connected threshold.
...
...
@@ -124,13 +130,15 @@ def create_mask_from_seg(seg: np.ndarray) -> typing.Tuple[np.ndarray, list]:
# choose one seed in segmentation
seed
=
np
.
transpose
(
np
.
nonzero
(
_seg
))[
0
]
# invert coordinates for sitk
seed_sitk
=
tuple
(
seed
[::
-
1
].
tolist
())
seed_sitk
=
tuple
(
seed
[::
-
1
].
tolist
())
seed
=
tuple
(
seed
)
# region growing
seg_con
=
sitk
.
ConnectedThreshold
(
_seg_sitk
,
seedList
=
[
seed_sitk
],
lower
=
int
(
_seg
[
seed
]),
upper
=
int
(
_seg
[
seed
]))
seg_con
=
sitk
.
ConnectedThreshold
(
_seg_sitk
,
seedList
=
[
seed_sitk
],
lower
=
int
(
_seg
[
seed
]),
upper
=
int
(
_seg
[
seed
]),
)
seg_con
=
sitk
.
GetArrayFromImage
(
seg_con
).
astype
(
bool
)
# add object to mask
...
...
@@ -146,13 +154,14 @@ def create_mask_from_seg(seg: np.ndarray) -> typing.Tuple[np.ndarray, list]:
return
_mask
,
_obj_cls
def
create_grid
(
cshape
:
typing
.
Union
[
typing
.
Sequence
[
int
],
int
],
dshape
:
typing
.
Sequence
[
int
],
overlap
:
typing
.
Union
[
typing
.
Sequence
[
int
],
int
]
=
0
,
mode
=
'fixed'
,
center_boarder
:
bool
=
False
,
**
kwargs
,
)
->
typing
.
List
[
typing
.
Tuple
[
slice
]]:
def
create_grid
(
cshape
:
typing
.
Union
[
typing
.
Sequence
[
int
],
int
],
dshape
:
typing
.
Sequence
[
int
],
overlap
:
typing
.
Union
[
typing
.
Sequence
[
int
],
int
]
=
0
,
mode
=
"fixed"
,
center_boarder
:
bool
=
False
,
**
kwargs
,
)
->
typing
.
List
[
typing
.
Tuple
[
slice
]]:
"""
Create indices for a grid
...
...
@@ -205,29 +214,33 @@ def create_grid(cshape: typing.Union[typing.Sequence[int], int],
# check shapes
if
len
(
cshape
)
!=
len
(
dshape
):
raise
TypeError
(
"cshape and dshape must be defined for same dimensionality."
)
raise
TypeError
(
"cshape and dshape must be defined for same dimensionality."
)
if
len
(
overlap
)
!=
len
(
dshape
):
raise
TypeError
(
"overlap and dshape must be defined for same dimensionality."
)
raise
TypeError
(
"overlap and dshape must be defined for same dimensionality."
)
if
any
(
np
.
subtract
(
dshape
,
cshape
)
<
0
):
axes
=
np
.
nonzero
(
np
.
subtract
(
dshape
,
cshape
)
<
0
)
logger
.
warning
(
f
"Found patch size which is bigger than data: data
{
dshape
}
patch
{
cshape
}
"
)
logger
.
warning
(
f
"Found patch size which is bigger than data: data
{
dshape
}
patch
{
cshape
}
"
)
if
any
(
np
.
subtract
(
cshape
,
overlap
)
<
0
):
raise
TypeError
(
"Overlap must be smaller than size of patches."
)
grid_slices
=
[
_mode_fn
[
mode
](
psize
,
dlim
,
ov
,
**
kwargs
)
for
psize
,
dlim
,
ov
in
zip
(
cshape
,
dshape
,
overlap
)]
grid_slices
=
[
_mode_fn
[
mode
](
psize
,
dlim
,
ov
,
**
kwargs
)
for
psize
,
dlim
,
ov
in
zip
(
cshape
,
dshape
,
overlap
)
]
if
center_boarder
:
for
idx
,
(
psize
,
dlim
,
ov
)
in
enumerate
(
zip
(
cshape
,
dshape
,
overlap
)):
lower_bound_start
=
int
(
-
0.5
*
psize
)
upper_bound_start
=
dlim
-
int
(
0.5
*
psize
)
grid_slices
[
idx
]
=
tuple
([
slice
(
lower_bound_start
,
lower_bound_start
+
psize
),
*
grid_slices
[
idx
],
slice
(
upper_bound_start
,
upper_bound_start
+
psize
),
])
grid_slices
[
idx
]
=
tuple
(
[
slice
(
lower_bound_start
,
lower_bound_start
+
psize
),
*
grid_slices
[
idx
],
slice
(
upper_bound_start
,
upper_bound_start
+
psize
),
]
)
if
slices_3d
is
not
None
:
grid_slices
=
[
tuple
([
slice
(
i
,
i
+
1
)
for
i
in
range
(
slices_3d
)])]
+
grid_slices
...
...
@@ -235,7 +248,9 @@ def create_grid(cshape: typing.Union[typing.Sequence[int], int],
return
grid
def
_fixed_slices
(
psize
:
int
,
dlim
:
int
,
overlap
:
int
,
start
:
int
=
0
)
->
typing
.
Tuple
[
slice
]:
def
_fixed_slices
(
psize
:
int
,
dlim
:
int
,
overlap
:
int
,
start
:
int
=
0
)
->
typing
.
Tuple
[
slice
]:
"""
Creates fixed slicing of a single axis. Only last patch exceeds dlim.
...
...
@@ -286,13 +301,12 @@ def _symmetric_slices(psize: int, dlim: int, overlap: int) -> typing.Tuple[slice
return
_fixed_slices
(
psize
,
dlim
,
overlap
,
start
=
start
)
def
save_get_crop
(
data
:
np
.
ndarray
,
crop
:
typing
.
Sequence
[
slice
],
mode
:
str
=
"shift"
,
**
kwargs
,
)
->
typing
.
Tuple
[
np
.
ndarray
,
typing
.
Tuple
[
int
],
typing
.
Tuple
[
slice
]]:
def
save_get_crop
(
data
:
np
.
ndarray
,
crop
:
typing
.
Sequence
[
slice
],
mode
:
str
=
"shift"
,
**
kwargs
,
)
->
typing
.
Tuple
[
np
.
ndarray
,
typing
.
Tuple
[
int
],
typing
.
Tuple
[
slice
]]:
"""
Safely extract crops from data
...
...
@@ -318,9 +332,8 @@ def save_get_crop(data: np.ndarray,
interpreted like they were outside the lower boundary!
"""
if
len
(
crop
)
>
data
.
ndim
:
raise
TypeError
(
"crop must have smaller or same dimensionality as data."
)
if
mode
==
'shift'
:
raise
TypeError
(
"crop must have smaller or same dimensionality as data."
)
if
mode
==
"shift"
:
# move slices if necessary
return
_shifted_crop
(
data
,
crop
)
else
:
...
...
@@ -328,11 +341,10 @@ def save_get_crop(data: np.ndarray,
return
_padded_crop
(
data
,
crop
,
mode
,
**
kwargs
)
def
_shifted_crop
(
data
:
np
.
ndarray
,
crop
:
typing
.
Sequence
[
slice
],
)
->
typing
.
Tuple
[
np
.
ndarray
,
typing
.
Tuple
[
int
],
typing
.
Tuple
[
slice
]]:
def
_shifted_crop
(
data
:
np
.
ndarray
,
crop
:
typing
.
Sequence
[
slice
],
)
->
typing
.
Tuple
[
np
.
ndarray
,
typing
.
Tuple
[
int
],
typing
.
Tuple
[
slice
]]:
"""
Created shifted crops to handle borders
...
...
@@ -366,16 +378,20 @@ def _shifted_crop(data: np.ndarray,
if
new_slice
.
stop
>
dshape
[
axis
+
idx
]:
raise
RuntimeError
(
"Patch is bigger than entire data. shift "
"is not supported in this case."
)
"is not supported in this case."
)
shifted_crop
.
append
(
new_slice
)
elif
crop_dim
.
stop
>
dshape
[
axis
+
idx
]:
new_slice
=
\
slice
(
crop_dim
.
start
-
(
crop_dim
.
stop
-
dshape
[
axis
+
idx
]),
dshape
[
axis
+
idx
],
crop_dim
.
step
)
new_slice
=
slice
(
crop_dim
.
start
-
(
crop_dim
.
stop
-
dshape
[
axis
+
idx
]),
dshape
[
axis
+
idx
],
crop_dim
.
step
,
)
if
new_slice
.
start
<
0
:
raise
RuntimeError
(
"Patch is bigger than entire data. shift "
"is not supported in this case."
)
"is not supported in this case."
)
shifted_crop
.
append
(
new_slice
)
else
:
shifted_crop
.
append
(
crop_dim
)
...
...
@@ -383,13 +399,12 @@ def _shifted_crop(data: np.ndarray,
return
data
[
tuple
([...,
*
shifted_crop
])],
origin
,
shifted_crop
def
_padded_crop
(
data
:
np
.
ndarray
,
crop
:
typing
.
Sequence
[
slice
],
mode
:
str
,
**
kwargs
,
)
->
typing
.
Tuple
[
np
.
ndarray
,
typing
.
Tuple
[
int
],
typing
.
Tuple
[
slice
]]:
def
_padded_crop
(
data
:
np
.
ndarray
,
crop
:
typing
.
Sequence
[
slice
],
mode
:
str
,
**
kwargs
,
)
->
typing
.
Tuple
[
np
.
ndarray
,
typing
.
Tuple
[
int
],
typing
.
Tuple
[
slice
]]:
"""
Extract patch from data and pad accordingly
...
...
@@ -429,8 +444,14 @@ def _padded_crop(data: np.ndarray,
padding
.
append
((
lower_pad
,
upper_pad
))
clipped_crop
.
append
(
slice
(
lower_bound
,
upper_bound
,
crop_dim
.
step
))
origin
=
[
int
(
x
.
start
)
for
x
in
crop
]
return
(
np
.
pad
(
data
[
tuple
([...,
*
clipped_crop
])],
pad_width
=
padding
,
mode
=
mode
,
**
kwargs
),
origin
,
clipped_crop
,
)
# return (np.pad(data[tuple([..., *clipped_crop])], pad_width=padding, mode=mode, **kwargs),
# origin,
# clipped_crop,
# )
return
(
np
.
pad
(
data
[
tuple
([...,
*
clipped_crop
])],
pad_width
=
padding
,
mode
=
mode
,
**
kwargs
),
origin
,
crop
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment