Commit 7246044d authored by mibaumgartner's avatar mibaumgartner

Merge remote-tracking branch 'origin/master' into main

parents fcec502f 6f4c3333
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import shutil
import numpy as np
import SimpleITK as sitk
from pathlib import Path
from typing import Dict, List, Sequence, Optional
from nndet.io.paths import Pathlike, get_case_ids_from_dir
from loguru import logger
from sklearn.model_selection import train_test_split
from nndet.io.load import save_json
from nndet.utils.clustering import seg2instances, remove_classes, reorder_classes
__all__ = ["maybe_split_4d_nifti", "instances_from_segmentation", "sitk_copy_metadata"]
def maybe_split_4d_nifti(source_file: Path, output_folder: Path):
"""
    Process a single nifti file
    if 3D file: copies the file to the target location and appends _0000 to indicate the channel
    if 4D file: splits it into multiple 3D files and appends _0000, _0001, ... suffixes to indicate channels
    Args:
        source_file (Path): path to source file
        output_folder (Path): path to target directory
    Raises:
        TypeError: if the data is not 3D or 4D
"""
img_itk = sitk.ReadImage(str(source_file))
dim = img_itk.GetDimension()
filename = source_file.name
if dim == 3:
# -7 cuts the .nii.gz part
shutil.copy(str(source_file), str(output_folder / (filename[:-7] + "_0000.nii.gz")))
return
elif dim == 4:
imgs_splitted = split_4d_itk(img_itk)
for idx, img in enumerate(imgs_splitted):
            sitk.WriteImage(img, str(output_folder / (filename[:-7] + f"_{idx:04d}.nii.gz")))
else:
raise TypeError(f"Unexpected dimensionality: {dim} of file {source_file}, cannot split")
def split_4d_itk(img_itk: sitk.Image) -> List[sitk.Image]:
"""
    Helper function to split a 4D itk image into multiple 3D images
Args:
img_itk: 4D input image
Returns:
List[sitk.Image]: 3d output images
"""
img_npy = sitk.GetArrayFromImage(img_itk)
spacing = img_itk.GetSpacing()
origin = img_itk.GetOrigin()
direction = np.array(img_itk.GetDirection()).reshape(4, 4)
spacing = tuple(list(spacing[:-1]))
assert len(spacing) == 3
origin = tuple(list(origin[:-1]))
assert len(origin) == 3
direction = tuple(direction[:-1, :-1].reshape(-1))
assert len(direction) == 9
images_new = []
    for t in range(img_npy.shape[0]):
img = img_npy[t]
images_new.append(
create_itk_image_spatial_props(img, spacing, origin, direction))
return images_new
def create_itk_image_spatial_props(
data: np.ndarray, spacing: Sequence[float], origin: Sequence[float],
direction: Sequence[Sequence[float]]) -> sitk.Image:
"""
Create new sitk image and set spatial tags
Args:
        data: image data
        spacing: target spacing
        origin: target origin
        direction: target direction (flattened)
Returns:
sitk.Image: new image
"""
data_itk = sitk.GetImageFromArray(data)
data_itk.SetSpacing(spacing)
data_itk.SetOrigin(origin)
data_itk.SetDirection(direction)
return data_itk
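
def _demo_split_4d_itk() -> List[sitk.Image]:
    """
    Hedged usage sketch: split a synthetic two-channel 4D image into two 3D
    images; assumes a SimpleITK build with 4D image support.
    """
    arr = np.random.rand(2, 4, 5, 6).astype(np.float32)
    img4d = sitk.GetImageFromArray(arr)
    imgs3d = split_4d_itk(img4d)
    assert len(imgs3d) == 2 and imgs3d[0].GetDimension() == 3
    return imgs3d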
def sitk_copy_metadata(img_source: sitk.Image, img_target: sitk.Image) -> sitk.Image:
"""
    Copy metadata (spacing, origin, direction) from source to target image
    Deprecated: calling this raises a RuntimeError; use
    ``sitk.Image.CopyInformation`` instead
    Args:
img_source: source image
img_target: target image
Returns:
SimpleITK.Image: target image with copied metadata
"""
raise RuntimeError("Deprecated")
spacing = img_source.GetSpacing()
img_target.SetSpacing(spacing)
origin = img_source.GetOrigin()
img_target.SetOrigin(origin)
direction = img_source.GetDirection()
img_target.SetDirection(direction)
return img_target
def instances_from_segmentation(source_file: Path, output_folder: Path,
                                rm_classes: Optional[Sequence[int]] = None,
                                ro_classes: Optional[Dict[int, int]] = None,
subtract_one_of_classes: bool = True,
fg_vs_bg: bool = False,
file_name: Optional[str] = None
):
"""
    1. Optionally removes classes from the segmentation
       (e.g. organ segmentations which are not useful for detection)
    2. Optionally reorders the segmentation indices
    3. Converts the semantic segmentation to instance segmentations via
       connected components
Args:
source_file: path to semantic segmentation file
output_folder: folder where processed file will be saved
rm_classes: classes to remove from semantic segmentation
ro_classes: reorder classes before instances are generated
subtract_one_of_classes: subtracts one from the classes
in the instance mapping (detection networks assume
that classes start from 0)
fg_vs_bg: map all foreground classes to a single class to run
foreground vs background detection task.
file_name: name of saved file (without file type!)
"""
if subtract_one_of_classes and fg_vs_bg:
logger.info("subtract_one_of_classes will be ignored because fg_vs_bg is "
"active and all foreground classes ill be mapped to 0")
seg_itk = sitk.ReadImage(str(source_file))
seg_npy = sitk.GetArrayFromImage(seg_itk)
if rm_classes is not None:
seg_npy = remove_classes(seg_npy, rm_classes)
if ro_classes is not None:
seg_npy = reorder_classes(seg_npy, ro_classes)
instances, instance_classes = seg2instances(seg_npy)
if fg_vs_bg:
num_instances_check = len(instance_classes)
seg_npy[seg_npy > 0] = 1
instances, instance_classes = seg2instances(seg_npy)
num_instances = len(instance_classes)
if num_instances != num_instances_check:
logger.warning(f"Lost instance: Found {num_instances} instances before "
f"fg_vs_bg but {num_instances_check} instances after it")
if subtract_one_of_classes:
for key in instance_classes.keys():
instance_classes[key] -= 1
if fg_vs_bg:
for key in instance_classes.keys():
instance_classes[key] = 0
    seg_itk_new = sitk.GetImageFromArray(instances)
    # sitk_copy_metadata is deprecated and raises; copy spacing, origin and
    # direction with SimpleITK's builtin instead
    seg_itk_new.CopyInformation(seg_itk)
if file_name is None:
suffix_length = sum(map(len, source_file.suffixes))
file_name = source_file.name[:-suffix_length]
save_json({"instances": instance_classes}, output_folder / f"{file_name}.json")
sitk.WriteImage(seg_itk_new, str(output_folder / f"{file_name}.nii.gz"))
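
# Usage sketch for instances_from_segmentation (hypothetical paths and
# classes): drop organ class 3, convert the remaining semantic labels to
# instances via connected components and write <case>.nii.gz + <case>.json:
#
#   instances_from_segmentation(
#       Path("/data/labels_semantic/case00.nii.gz"),
#       Path("/data/labelsTr"),
#       rm_classes=(3,),
#   )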
def create_test_split(splitted_dir: Pathlike,
num_modalities: int,
test_size: float = 0.3,
random_state: int = 0,
shuffle: bool = True,
):
"""
Helper function to create an artificial test split from the splitted data
Args:
splitted_dir: path to directory with splitted data. `imagesTr` and
`labelsTr` need to exist beforehand. `imagesTs` and `labelsTs`
will be created automatically.
num_modalities: number of modalities
test_size: size of test set, needs to be a value between 0 and 1
        random_state: seed for splitting
shuffle: shuffle data
"""
images_tr = Path(splitted_dir) / "imagesTr"
labels_tr = Path(splitted_dir) / "labelsTr"
images_ts = Path(splitted_dir) / "imagesTs"
labels_ts = Path(splitted_dir) / "labelsTs"
if not images_tr.is_dir():
raise ValueError(f"No dir with training images found {images_tr}")
if not labels_tr.is_dir():
raise ValueError(f"No dir with training labels found {labels_tr}")
images_ts.mkdir(parents=True, exist_ok=True)
labels_ts.mkdir(parents=True, exist_ok=True)
case_ids = sorted(get_case_ids_from_dir(images_tr, remove_modality=True))
logger.info(f"Found {len(case_ids)} to split")
train_ids, test_ids = train_test_split(
case_ids, test_size=test_size, random_state=random_state, shuffle=shuffle)
logger.info(f"Using {train_ids} for training and {test_ids} for testing.")
for cid in test_ids:
for modality in range(num_modalities):
shutil.move(images_tr / f"{cid}_{modality:04d}.nii.gz",
images_ts / f"{cid}_{modality:04d}.nii.gz")
shutil.move(labels_tr / f"{cid}.nii.gz", labels_ts / f"{cid}.nii.gz")
if (labels_tr / f"{cid}.json").is_file():
shutil.move(labels_tr / f"{cid}.json", labels_ts / f"{cid}.json")
from nndet.io.transforms.base import (
AbstractTransform,
Compose,
)
from nndet.io.transforms.instances import (
Instances2Boxes,
Instances2Segmentation,
FindInstances,
)
from nndet.io.transforms.utils import (
AddProps2Data,
NoOp,
FilterKeys,
)
from nndet.io.transforms.spatial import (
Mirror,
)
from typing import Any, Sequence
import torch
class AbstractTransform(torch.nn.Module):
def __init__(self, grad: bool = False, **kwargs):
"""
Args:
grad: enable gradient computation inside transformation
"""
super().__init__()
self.grad = grad
def __call__(self, *args, **kwargs) -> Any:
"""
Call super class with correct torch context
Args:
*args: forwarded positional arguments
**kwargs: forwarded keyword arguments
Returns:
Any: transformed data
"""
if self.grad:
context = torch.enable_grad()
else:
context = torch.no_grad()
with context:
return super().__call__(*args, **kwargs)
class Compose(AbstractTransform):
def __init__(self, *transforms):
"""
Compose multiple transforms to one
Args:
transforms: transformations to compose
"""
super().__init__(grad=False)
if len(transforms) == 1 and isinstance(transforms[0], Sequence):
transforms = transforms[0]
self.transforms = torch.nn.ModuleList(list(transforms))
def forward(self, **batch):
"""
Augment batch
"""
for t in self.transforms:
batch = t(**batch)
return batch
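
def _demo_compose() -> dict:
    """
    Hedged usage sketch: compose two trivial transforms that operate on a
    batch dict. AddOne is a hypothetical transform defined only for this demo.
    """
    class AddOne(AbstractTransform):
        def forward(self, **batch):
            batch["data"] = batch["data"] + 1
            return batch

    pipeline = Compose(AddOne(), AddOne())
    out = pipeline(data=torch.zeros(2, 1, 4, 4))
    assert torch.all(out["data"] == 2)
    return out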
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import numpy as np
from torch import Tensor
from typing import Dict, Union, Sequence, Tuple, Optional
from nndet.io.transforms.base import AbstractTransform
class FindInstances(AbstractTransform):
    def __init__(self, instance_key: str, save_key: str = "present_instances", **kwargs):
        """
        Find the instance ids present in each sample of a batch
        Args:
            instance_key: key where instance segmentation is located
            save_key: key where the per-sample instance ids are saved
        """
        super().__init__(grad=False)
        self.instance_key = instance_key
        self.save_key = save_key
def forward(self, **data) -> dict:
present_instances = []
for instance_element in data[self.instance_key].split(1):
tmp = instance_element.to(dtype=torch.int).unique(sorted=True)
tmp = tmp[tmp > 0]
present_instances.append(tmp)
data[self.save_key] = present_instances
return data
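
def _demo_find_instances() -> dict:
    """
    Hedged usage sketch: collect the instance ids present in each sample of a
    batched instance map (values are illustrative).
    """
    batch = {"target": torch.tensor([[[0, 1], [2, 0]], [[0, 0], [0, 3]]])}
    out = FindInstances(instance_key="target")(**batch)
    assert [t.tolist() for t in out["present_instances"]] == [[1, 2], [3]]
    return out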
class Instances2Boxes(AbstractTransform):
def __init__(self, instance_key: str, map_key: str,
box_key: str, class_key: str, grad: bool = False,
present_instances: Optional[str] = None,
**kwargs):
"""
Convert instance segmentation to bounding boxes
        Args:
            instance_key: key where instance segmentation is located
            map_key: key where mapping from instances to classes is located
                (should be a dict which maps keys (instances) to items (classes))
            box_key: key where boxes should be saved
            class_key: key where classes of instances will be saved
            grad: enable gradient computation inside transformation
            present_instances: key where precomputed present instances are
                saved. If None, present instances are computed on the fly.
"""
super().__init__(grad=grad, **kwargs)
self.class_key = class_key
self.box_key = box_key
self.map_key = map_key
self.instance_key = instance_key
self.present_instances = present_instances
def forward(self, **data) -> dict:
"""
Extract boxes from instances
Args:
**data: batch dict
Returns:
dict: processed batch
"""
data[self.box_key] = []
data[self.class_key] = []
for batch_idx, instance_element in enumerate(data[self.instance_key].split(1)):
_present_instances = data[self.present_instances][batch_idx] if self.present_instances is not None else None
_boxes, instance_idx = instances_to_boxes(
instance_element, instance_element.ndim - 2, instances=_present_instances)
_classes = get_instance_class_from_properties(
instance_idx, data[self.map_key][batch_idx])
_classes = _classes.to(device=_boxes.device)
data[self.box_key].append(_boxes)
data[self.class_key].append(_classes)
return data
def instances_to_boxes(seg: Tensor,
dim: int = None,
instances: Optional[Sequence[int]] = None,
) -> Tuple[Tensor, Tensor]:
"""
    Convert instance segmentation to bounding boxes (not batched)
    Args:
        seg: instance segmentation of individual classes [..., dims]
        dim: number of spatial dimensions to create bounding box for
            (always starting from the last dimension). If None, all
            dimensions are used
        instances: optionally restrict the computation to these instance
            ids; if None, all instance ids > 0 found in `seg` are used
    Returns:
        Tensor: bounding boxes (x1, y1, x2, y2, (z1, z2)) [N, dim * 2]
        Tensor: instance ids corresponding to the boxes [N]
"""
if dim is None:
dim = seg.ndim
boxes = []
_seg = seg.detach()
if instances is None:
instances = _seg.unique(sorted=True)
instances = instances[instances > 0]
for _idx in instances:
instance_idx = (_seg == _idx).nonzero(as_tuple=False)
        _mins = instance_idx[:, -dim:].min(dim=0)[0]
        _maxs = instance_idx[:, -dim:].max(dim=0)[0]
box = [_mins[-dim] - 1, _mins[(-dim) + 1] - 1, _maxs[-dim] + 1, _maxs[(-dim) + 1] + 1]
if dim > 2:
box = box + [_mins[(-dim) + 2] - 1, _maxs[(-dim) + 2] + 1]
boxes.append(torch.tensor(box))
if boxes:
boxes = torch.stack(boxes)
else:
boxes = torch.tensor([[]])
return boxes.to(dtype=torch.float, device=seg.device), instances
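
def _demo_instances_to_boxes() -> Tensor:
    """
    Hedged usage sketch: two instances in a small 2D map with a leading
    channel dimension yield one box per instance id.
    """
    seg = torch.zeros(1, 8, 8, dtype=torch.long)
    seg[0, 1:3, 1:4] = 1
    seg[0, 5:7, 5:8] = 2
    boxes, ids = instances_to_boxes(seg, dim=2)
    assert boxes.shape == (2, 4) and ids.tolist() == [1, 2]
    return boxes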
def instances_to_boxes_np(
seg: np.ndarray,
dim: int = None,
instances: Optional[Sequence[int]] = None,
) -> Tuple[np.ndarray, np.ndarray]:
"""
    Convert instance segmentation to bounding boxes (not batched)
    Args:
        seg: instance segmentation of individual classes [..., dims]
        dim: number of spatial dimensions to create bounding box for
            (always starting from the last dimension). If None, all
            dimensions are used
        instances: optionally restrict the computation to these instance
            ids; if None, all instance ids > 0 found in `seg` are used
    Returns:
        np.ndarray: bounding boxes (x1, y1, x2, y2, (z1, z2)) [N, dim * 2]
        np.ndarray: instance ids corresponding to the boxes [N]
"""
if dim is None:
dim = seg.ndim
boxes = []
if instances is None:
instances = np.unique(seg)
instances = instances[instances > 0]
for _idx in instances:
instance_idx = np.stack(np.nonzero(seg == _idx), axis=1)
_mins = np.min(instance_idx[:, -dim:], axis=0)
_maxs = np.max(instance_idx[:, -dim:], axis=0)
box = [_mins[-dim] - 1, _mins[(-dim) + 1] - 1, _maxs[-dim] + 1, _maxs[(-dim) + 1] + 1]
if dim > 2:
box = box + [_mins[(-dim) + 2] - 1, _maxs[(-dim) + 2] + 1]
boxes.append(np.array(box))
if boxes:
boxes = np.stack(boxes)
else:
boxes = np.array([[]])
return boxes, instances
def get_instance_class_from_properties(
instance_idx: torch.Tensor, map_dict: Dict[str, Union[str, int]]) -> Tensor:
"""
    Extract instance classes from mapping dict
    Args:
        instance_idx: instance ids present in segmentation
map_dict: dict mapping instance ids (keys) to classes
Returns:
Tensor: extracted instance classes
"""
instance_idx, _ = instance_idx.sort()
classes = [int(map_dict[str(int(idx.detach().item()))]) for idx in instance_idx]
return torch.tensor(classes, device=instance_idx.device)
def get_instance_class_from_properties_seq(
instance_idx: Sequence, map_dict: Dict[str, Union[str, int]]) -> Sequence:
"""
    Extract instance classes from mapping dict
    Args:
        instance_idx: instance ids present in segmentation
map_dict: dict mapping instance ids (keys) to classes
Returns:
Sequence[int]: extracted instance classes
"""
instance_idx = sorted(instance_idx)
classes = [int(map_dict[str(int(idx))]) for idx in instance_idx]
return classes
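
def _demo_instance_class_mapping() -> None:
    """
    Hedged usage sketch: map instance ids to classes via a string-keyed dict,
    mirroring the json files written by instances_from_segmentation.
    """
    mapping = {"1": 0, "2": 1}
    assert get_instance_class_from_properties_seq([2, 1], mapping) == [0, 1]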
class Instances2Segmentation(AbstractTransform):
def __init__(self, instance_key: str, map_key: str, seg_key: str = None,
add_background: bool = True, grad: bool = False,
present_instances: Optional[str] = None,
):
"""
Convert instances to semantic segmentation
Args:
instance_key: key where instance segmentation is located
map_key: key where mapping from instances to classes is located
seg_key: key where segmentation should be saved; If None, the
instance key will be overwritten
add_background: adds +1 to classes from mapping for background
grad: enable gradient propagation through transformation
            present_instances: key where precomputed present instances are
                saved. If None, present instances are computed on the fly.
"""
super().__init__(grad=grad)
self.add_background = add_background
self.seg_key = seg_key if seg_key is not None else instance_key
self.map_key = map_key
self.instance_key = instance_key
self.present_instances = present_instances
def forward(self, **data) -> dict:
"""
Convert instance segmentation to semantic segmentation
Args:
**data: batch dict
Returns:
dict: processed batch
"""
semantic = torch.zeros_like(data[self.instance_key])
_present_instances = data[self.present_instances] if self.present_instances is not None else None
for batch_idx in range(semantic.shape[0]):
instances_to_segmentation(data[self.instance_key][batch_idx],
data[self.map_key][batch_idx],
add_background=self.add_background,
                                      instance_idx=(_present_instances[batch_idx]
                                                    if _present_instances is not None else None),
out=semantic[batch_idx])
data[self.seg_key] = semantic
return data
def instances_to_segmentation(instances: Tensor,
mapping: Dict[str, Union[str, int]],
add_background: bool = True,
instance_idx: Optional[Sequence[int]] = None,
out: Tensor = None) -> Tensor:
"""
Convert instances to semantic segmentation
Args:
instances: instance segmentation; foreground classes > 0; [dims]
mapping: mapping from each instance to class
        add_background: adds +1 to classes from mapping for background.
            Should be enabled if classes in mapping start from zero and
            disabled otherwise
out: optional output tensor where results are saved
instance_idx: precomputed instance ids present in sample. If None
the instances ids will be computed
Returns:
Tensor: semantic segmentation
"""
mapping = {int(key): int(item) for key, item in mapping.items()}
if out is None:
out = torch.zeros_like(instances)
if instance_idx is None:
instance_idx = instances.unique(sorted=True)
instance_idx = instance_idx[instance_idx > 0]
for instance_id in instance_idx:
_cls = mapping[instance_id.item()]
if add_background:
_cls += 1
out[instances == instance_id] = _cls
return out
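
def _demo_instances_to_segmentation() -> Tensor:
    """
    Hedged usage sketch: two instances of the same semantic class (class 0 in
    the mapping) become foreground label 1 after add_background shifts by +1.
    """
    instances = torch.tensor([[0, 1, 1], [0, 0, 2]])
    semantic = instances_to_segmentation(instances, {"1": 0, "2": 0})
    assert semantic.max() == 1
    return semantic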
def instances_to_segmentation_np(instances: np.ndarray,
mapping: Dict[Union[str, int], Union[str, int]],
add_background: bool = True,
out: np.ndarray = None) -> np.ndarray:
"""
Convert instances to semantic segmentation
Args:
instances: instance segmentation; foreground classes > 0; [dims]
mapping: mapping from each instance to class
        add_background: adds +1 to classes from mapping for background.
            Should be enabled if classes in mapping start from zero and
            disabled otherwise
        out: optional output array where results are saved
    Returns:
        np.ndarray: semantic segmentation
"""
mapping = {int(key): int(item) for key, item in mapping.items()}
if out is None:
out = np.zeros_like(instances)
instance_idx = np.unique(instances)
instance_idx = instance_idx[instance_idx > 0]
for instance_id in instance_idx:
_cls = mapping[instance_id]
if add_background:
_cls += 1
out[instances == instance_id] = _cls
return out
def get_bbox_np(seg: np.ndarray,
map_dict: Optional[Dict[Union[str, int], Union[str, int]]] = None,
**kwargs,
) -> dict:
"""
Get bounding boxes and mapping from instances to classes
Args:
seg: instance segmentation [1, dims]
        map_dict: mapping from instance ids to classes
Returns:
dict: extracted boxes and classes
`boxes` (np.ndarray): bounding boxes [N, dims * 2]
`classes` (np.ndarray): classes (in same order as boxes) [N]
"""
if map_dict is not None:
map_dict = {str(key): str(item) for key, item in map_dict.items()}
result = {}
boxes, instance_idx = instances_to_boxes_np(seg[0], **kwargs)
result["boxes"] = boxes
if map_dict is not None:
box_classes = get_instance_class_from_properties_seq(instance_idx, map_dict)
result["classes"] = np.array(box_classes)
return result
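
def _demo_get_bbox_np() -> dict:
    """
    Hedged usage sketch: seg carries a leading channel axis [1, dims];
    map_dict maps instance ids to classes as in the per-case json files.
    """
    seg = np.zeros((1, 8, 8), dtype=np.int64)
    seg[0, 2:4, 2:5] = 1
    result = get_bbox_np(seg, map_dict={1: 0})
    assert result["boxes"].shape == (1, 4) and result["classes"].tolist() == [0]
    return result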
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from torch import Tensor
from typing import Sequence, List
from nndet.io.transforms.base import AbstractTransform
class Mirror(AbstractTransform):
def __init__(self, keys: Sequence[str], dims: Sequence[int],
point_keys: Sequence[str] = (), box_keys: Sequence[str] = (),
grad: bool = False):
"""
Mirror Transform
Args:
keys: keys to mirror (first key must correspond to data for
shape information) expected shape [N, C, dims]
dims: dimensions to mirror (starting from the first spatial
dimension)
point_keys: keys where points for transformation are located
[N, dims]
box_keys: keys where boxes are located; following format
needs to be used (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]
grad: enable gradient computation inside transformation
"""
super().__init__(grad=grad)
self.dims = dims
self.keys = keys
self.point_keys = point_keys
self.box_keys = box_keys
def forward(self, **data) -> dict:
"""
        Mirror data, boxes and points along the configured dimensions
        Args:
            data: dict with data
        Returns:
            dict: dict with transformed data
"""
for key in self.keys:
data[key] = mirror(data[key], self.dims)
data_shape = data[self.keys[0]].shape
data_shapes = [tuple(data_shape[2:])] * data_shape[0]
for key in self.box_keys:
points = [boxes2points(b) for b in data[key]]
points = mirror_points(points, self.dims, data_shapes)
data[key] = [points2boxes(p) for p in points]
for key in self.point_keys:
data[key] = mirror_points(data[key], self.dims, data_shapes)
return data
def invert(self, **data) -> dict:
"""
Revert mirroring
Args:
**data: dict with data
Returns:
dict with re-transformed data
"""
return self(**data)
def mirror(data: torch.Tensor, dims: Sequence[int]) -> torch.Tensor:
"""
Mirror data at dims
    Args:
        data: input data [N, C, spatial dims]
        dims: dimensions to mirror starting from spatial dims
            e.g. dims=(0,) mirrors the first spatial dimension
    Returns:
torch.Tensor: tensor with mirrored dimensions
"""
dims = [d + 2 for d in dims]
return data.flip(dims)
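
def _demo_mirror() -> None:
    """
    Hedged usage sketch: flip the first spatial dimension of a [N, C, H, W]
    batch (values are illustrative).
    """
    data = torch.arange(4.).view(1, 1, 2, 2)
    flipped = mirror(data, dims=(0,))
    assert flipped[0, 0].tolist() == [[2., 3.], [0., 1.]]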
def mirror_points(points: Sequence[torch.Tensor], dims: Sequence[int],
data_shapes: Sequence[Sequence[int]]) -> List[torch.Tensor]:
"""
Mirror points along given dimensions
Args:
points: points per batch element [N, dims]
dims: dimensions to mirror
data_shapes: shape of data
Returns:
Tensor: transformed points [N, dims]
"""
cartesian_dims = points[0].shape[1]
homogeneous_points = points_to_homogeneous(points)
transformed = []
for points_per_image, data_shape in zip(homogeneous_points, data_shapes):
matrix = nd_mirror_matrix(cartesian_dims, dims, data_shape).to(points_per_image)
transformed.append(points_per_image @ matrix.transpose(0, 1))
return points_to_cartesian(transformed)
def nd_mirror_matrix(cartesian_dims: int, mirror_dims: Sequence[int],
data_shape: Sequence[int]) -> torch.Tensor:
"""
    Create n-dimensional matrix for mirroring
Args:
cartesian_dims: number of cartesian dimensions
mirror_dims: dimensions to mirror
data_shape: shape of image
Returns:
        Tensor: matrix for mirroring in homogeneous coordinates,
[cartesian_dims + 1, cartesian_dims + 1]
"""
mirror_dims = tuple(mirror_dims)
data_shape = list(data_shape)
homogeneous_dims = cartesian_dims + 1
mat = torch.eye(homogeneous_dims, dtype=torch.float)
# reflection
mat[[mirror_dims] * 2] = -1
# add data shape to axis which were reflected
self_tensor = torch.zeros(cartesian_dims, dtype=torch.float)
    index_tensor = torch.tensor(mirror_dims, dtype=torch.long)
src_tensor = torch.tensor([1] * len(mirror_dims), dtype=torch.float)
offset_mask = self_tensor.scatter_(0, index_tensor, src_tensor)
mat[:-1, -1] = offset_mask * torch.tensor(data_shape)
return mat
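
def _demo_nd_mirror_matrix() -> None:
    """
    Hedged worked example: mirroring the first of two cartesian dims for an
    image of shape (10, 20) maps a point x to 10 - x along that axis.
    """
    mat = nd_mirror_matrix(cartesian_dims=2, mirror_dims=(0,), data_shape=(10, 20))
    point = torch.tensor([3., 7., 1.])  # homogeneous coordinates
    assert (mat @ point).tolist() == [7., 7., 1.]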
def points_to_homogeneous(points: Sequence[torch.Tensor]) -> List[torch.Tensor]:
"""
Transforms points from cartesian to homogeneous coordinates
Args:
points: list of points to transform [N, dims] where N is the number
of points and dims is the number of spatial dimensions
    Returns:
torch.Tensor: the batch of points in homogeneous coordinates [N, dim + 1]
"""
return [torch.cat([p, torch.ones(p.shape[0], 1).to(p)], dim=1) for p in points]
def points_to_cartesian(points: Sequence[torch.Tensor]) -> List[torch.Tensor]:
"""
Transforms points in homogeneous coordinates back to cartesian
coordinates.
Args:
points: homogeneous points [N, in_dims], N number of points,
in_dims number of input dimensions (spatial dimensions + 1)
Returns:
        List[Tensor]: cartesian points [N, in_dims - 1] = [N, dims]
"""
return [p[..., :-1] / p[..., -1][:, None] for p in points]
def boxes2points(boxes: Tensor) -> Tensor:
"""
Convert boxes to points
Args:
boxes: (x1, y1, x2, y2, (z1, z2))[N, dims *2]
Returns:
Tensor: points [N * 2, dims]
"""
if boxes.shape[1] == 4:
idx0 = [0, 1]
idx1 = [2, 3]
else:
idx0 = [0, 1, 4]
idx1 = [2, 3, 5]
points0 = boxes[:, idx0]
points1 = boxes[:, idx1]
return torch.cat([points0, points1], dim=0)
def points2boxes(points: Tensor) -> Tensor:
"""
Convert points to boxes
Args:
        points: points ordered as produced by :func:`boxes2points`, i.e.
            the first N/2 entries are one corner per box and the last N/2
            entries the opposite corner; format (x, y(, z)) [N, dims]
Returns:
Tensor: bounding boxes [N / 2, dims * 2]
"""
if points.nelement() > 0:
points0, points1 = points.split(points.shape[0] // 2)
boxes = torch.zeros(points.shape[0] // 2, points.shape[1] * 2).to(
device=points.device, dtype=points.dtype)
boxes[:, 0] = torch.min(points0[:, 0], points1[:, 0])
boxes[:, 1] = torch.min(points0[:, 1], points1[:, 1])
boxes[:, 2] = torch.max(points0[:, 0], points1[:, 0])
boxes[:, 3] = torch.max(points0[:, 1], points1[:, 1])
if boxes.shape[1] == 6:
boxes[:, 4] = torch.min(points0[:, 2], points1[:, 2])
boxes[:, 5] = torch.max(points0[:, 2], points1[:, 2])
return boxes
else:
return torch.tensor([]).view(-1, points.shape[1] * 2).to(points)
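
def _demo_boxes_points_roundtrip() -> None:
    """
    Hedged usage sketch: boxes2points/points2boxes round-trip for 2D boxes in
    (x1, y1, x2, y2) format.
    """
    boxes = torch.tensor([[1., 2., 3., 4.], [0., 0., 5., 5.]])
    restored = points2boxes(boxes2points(boxes))
    assert torch.equal(restored, boxes)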
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Hashable, Mapping, Sequence
from nndet.io.transforms.base import AbstractTransform
class AddProps2Data(AbstractTransform):
def __init__(self, props_key: str, key_mapping: Mapping[str, str], **kwargs):
"""
Move properties from property dict to data dict
        Args:
            props_key: key where the per-sample property dicts are located
            key_mapping: maps properties (key) to new keys in data dict (item)
"""
super().__init__(grad=False, **kwargs)
self.key_mapping = key_mapping
self.props_key = props_key
def forward(self, **data) -> dict:
"""
Move keys from properties to data
Args:
**data: batch dict
Returns:
dict: updated batch
"""
props = data[self.props_key]
for source, target in self.key_mapping.items():
data[target] = [p[source] for p in props]
return data
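
def _demo_add_props2data() -> dict:
    """
    Hedged usage sketch: lift a per-sample property ("inst" below is a
    hypothetical key) into a top-level batch key.
    """
    batch = {"properties": [{"inst": {"1": 0}}, {"inst": {"1": 1}}]}
    transform = AddProps2Data(props_key="properties", key_mapping={"inst": "target_map"})
    out = transform(**batch)
    assert out["target_map"] == [{"1": 0}, {"1": 1}]
    return out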
class NoOp(AbstractTransform):
def __init__(self, grad: bool = False):
"""
Forward input without change
Args:
grad: propagate gradient through transformation
"""
super().__init__(grad=grad)
def forward(self, **data) -> dict:
"""
NoOp
"""
return data
def invert(self, **data) -> dict:
"""
NoOp
"""
return data
class FilterKeys(AbstractTransform):
    def __init__(self, keys: Sequence[Hashable]):
        """
        Filter batch dict and only keep the given keys
        Args:
            keys: keys to keep
        """
        super().__init__(grad=False)
self.keys = keys
def forward(self, **data) -> dict:
return {k: data[k] for k in self.keys}
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import List
from loguru import logger
from collections import OrderedDict
from pathlib import Path
from nndet.io.load import load_pickle
from nndet.io.paths import get_case_ids_from_dir, get_case_id_from_path, Pathlike
def get_np_paths_from_dir(directory: Pathlike) -> List[str]:
"""
    First looks for npy files inside the directory. If no files are
    found, it looks for npz files.
Args:
directory: path to folder
Raises:
        RuntimeError: raised if neither npy nor npz files are found
Returns:
List[str]: paths to files
"""
case_paths = get_case_ids_from_dir(
Path(directory), remove_modality=False, join=True, pattern="*.npy")
if not case_paths:
logger.info(f"Did not find any npy files, looking for npz files. Folder: {directory}")
case_paths = get_case_ids_from_dir(
Path(directory), remove_modality=False, join=True, pattern="*.npz")
if not case_paths:
logger.error(f"Did not find any npz files.")
raise RuntimeError(f"Did not find any npz files. Folder: {directory}")
case_paths = [f for f in case_paths if "_seg" not in f]
case_paths.sort()
return case_paths
def load_dataset(folder: Pathlike) -> dict:
"""
Load dataset (path and properties, NOT the actual data) and
save them into dict by their path
Args:
folder: folder to look for data
Raises:
RuntimeError: data needs to be provided in npy or npz format
Returns:
dict: loaded data
"""
folder = Path(folder)
case_identifiers = get_np_paths_from_dir(folder)
dataset = OrderedDict()
for c in case_identifiers:
dataset[c] = OrderedDict()
dataset[c]['data_file'] = str(folder / f"{c}.npy")
dataset[c]['seg_file'] = str(folder / f"{c}_seg.npy")
dataset[c]['properties_file'] = str(folder / f"{c}.pkl")
dataset[c]['boxes_file'] = str(folder / f"{c}_boxes.pkl")
return dataset
def load_dataset_id(folder: Pathlike) -> dict:
"""
Load dataset (path and properties, NOT the actual data) and
save them into dict by their identifier
Args:
folder: folder to look for data
Raises:
RuntimeError: data needs to be provided in npy or npz format
Returns:
dict: loaded data
"""
folder = Path(folder)
case_paths = get_np_paths_from_dir(folder)
case_ids = [get_case_id_from_path(c, remove_modality=False) for c in case_paths]
dataset = OrderedDict()
for c in case_ids:
dataset[c] = OrderedDict()
        dataset[c]['data_file'] = str(folder / f"{c}.npy")
dataset[c]['seg_file'] = str(folder / f"{c}_seg.npy")
dataset[c]['properties_file'] = str(folder / f"{c}.pkl")
dataset[c]['boxes_file'] = str(folder / f"{c}_boxes.pkl")
return dataset
from nndet.losses.classification import focal_loss_with_logits, FocalLossWithLogits
from nndet.losses.regression import SmoothL1Loss, smooth_l1_loss, GIoULoss
from nndet.losses.segmentation import SoftDiceLoss
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
__all__ = ["reduction_helper"]
def reduction_helper(data: torch.Tensor, reduction: str) -> torch.Tensor:
"""
Helper to collapse data with different modes
Args:
data: data to collapse
        reduction: type of reduction. One of `mean`, `sum`, `none` or None
Returns:
Tensor: reduced data
"""
if reduction == 'mean':
return torch.mean(data)
if reduction == 'none' or reduction is None:
return data
if reduction == 'sum':
return torch.sum(data)
raise AttributeError('Reduction parameter unknown.')
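
def _demo_reduction_helper() -> None:
    """
    Hedged usage sketch: the three supported reduction modes.
    """
    x = torch.tensor([1., 2., 3.])
    assert reduction_helper(x, "sum") == 6.
    assert reduction_helper(x, "mean") == 2.
    assert torch.equal(reduction_helper(x, None), x)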
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch import Tensor
from loguru import logger
from nndet.losses.base import reduction_helper
from nndet.utils import make_onehot_batch
def one_hot_smooth(data,
num_classes: int,
smoothing: float = 0.0,
):
targets = torch.empty(size=(*data.shape, num_classes), device=data.device)\
.fill_(smoothing / num_classes)\
.scatter_(-1, data.long().unsqueeze(-1), 1. - smoothing)
return targets
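
def _demo_one_hot_smooth() -> None:
    """
    Hedged worked example: with smoothing=0.1 and 2 classes the hot entry
    becomes 1 - smoothing = 0.9 and the cold entry smoothing / num_classes = 0.05.
    """
    targets = one_hot_smooth(torch.tensor([0, 1]), num_classes=2, smoothing=0.1)
    assert torch.allclose(targets, torch.tensor([[0.9, 0.05], [0.05, 0.9]]))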
@torch.jit.script
def focal_loss_with_logits(
logits: torch.Tensor,
target: torch.Tensor, gamma: float,
alpha: float = -1,
reduction: str = "mean",
) -> torch.Tensor:
"""
Focal loss
https://arxiv.org/abs/1708.02002
Args:
logits: predicted logits [N, dims]
target: (float) binary targets [N, dims]
gamma: balance easy and hard examples in focal loss
        alpha: balance positive and negative samples [0, 1] (increasing
            alpha increases the weight of foreground classes (better recall))
reduction: 'mean'|'sum'|'none'
mean: mean of loss over entire batch
sum: sum of loss over entire batch
none: no reduction
Returns:
torch.Tensor: loss
    See Also:
:class:`BFocalLossWithLogits`, :class:`FocalLossWithLogits`
"""
bce_loss = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
p = torch.sigmoid(logits)
pt = (p * target + (1 - p) * (1 - target))
focal_term = (1. - pt).pow(gamma)
loss = focal_term * bce_loss
if alpha >= 0:
alpha_t = (alpha * target + (1 - alpha) * (1 - target))
loss = alpha_t * loss
return reduction_helper(loss, reduction=reduction)
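
def _demo_focal_loss() -> None:
    """
    Hedged worked example: with gamma=0 and alpha=-1 the focal term is 1 and
    alpha weighting is disabled, so the loss reduces to plain BCE with logits.
    """
    logits = torch.tensor([0.5, -1.0, 2.0])
    target = torch.tensor([1.0, 0.0, 1.0])
    focal = focal_loss_with_logits(logits, target, gamma=0.0, alpha=-1.0)
    bce = F.binary_cross_entropy_with_logits(logits, target)
    assert torch.allclose(focal, bce)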
class FocalLossWithLogits(nn.Module):
def __init__(self,
gamma: float = 2,
alpha: float = -1,
reduction: str = "sum",
loss_weight: float = 1.,
):
"""
Focal loss with multiple classes (uses one hot encoding and sigmoid)
Args:
gamma: balance easy and hard examples in focal loss
            alpha: balance positive and negative samples [0, 1] (increasing
                alpha increases the weight of foreground classes (better recall))
reduction: 'mean'|'sum'|'none'
mean: mean of loss over entire batch
sum: sum of loss over entire batch
none: no reduction
loss_weight: scalar to balance multiple losses
"""
super().__init__()
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self,
logits: torch.Tensor,
targets: torch.Tensor,
) -> torch.Tensor:
"""
Compute loss
Args:
logits: predicted logits [N, C, dims], where N is the batch size,
C number of classes, dims are arbitrary spatial dimensions
(background classes should be located at channel 0 if
ignore background is enabled)
targets: targets encoded as numbers [N, dims], where N is the
batch size, dims are arbitrary spatial dimensions
Returns:
torch.Tensor: loss
"""
n_classes = logits.shape[1] + 1
target_onehot = make_onehot_batch(targets, n_classes=n_classes).float()
target_onehot = target_onehot[:, 1:]
return self.loss_weight * focal_loss_with_logits(
logits, target_onehot,
gamma=self.gamma,
alpha=self.alpha,
reduction=self.reduction,
)
class BCEWithLogitsLossOneHot(torch.nn.BCEWithLogitsLoss):
def __init__(self,
*args,
num_classes: int,
smoothing: float = 0.0,
loss_weight: float = 1.,
**kwargs,
):
"""
BCE loss with one hot encoding of targets
Args:
num_classes: number of classes
smoothing: label smoothing
loss_weight: scalar to balance multiple losses
"""
super().__init__(*args, **kwargs)
self.smoothing = smoothing
if smoothing > 0:
logger.info(f"Running label smoothing with smoothing: {smoothing}")
self.num_classes = num_classes
self.loss_weight = loss_weight
def forward(self,
input: Tensor,
target: Tensor,
) -> Tensor:
"""
Compute bce loss based on one hot encoding
Args:
input: logits for all foreground classes [N, C]
N is the number of anchors, and C is the number of foreground
classes
            target: target classes. 0 is treated as background, >0 are
                treated as foreground classes. [N], where N is the number
                of anchors
Returns:
Tensor: final loss
"""
target_one_hot = one_hot_smooth(
target, num_classes=self.num_classes + 1, smoothing=self.smoothing) # [N, C + 1]
target_one_hot = target_one_hot[:, 1:] # background is implicitly encoded
return self.loss_weight * super().forward(input, target_one_hot.float())
class CrossEntropyLoss(torch.nn.CrossEntropyLoss):
def __init__(self,
*args,
loss_weight: float = 1.,
**kwargs,
) -> None:
"""
Same as CE from pytorch with additional loss weight for uniform API
"""
super().__init__(*args, **kwargs)
self.loss_weight = loss_weight
def forward(self,
input: Tensor,
target: Tensor,
) -> Tensor:
"""
Same as CE from pytorch
"""
return self.loss_weight * super().forward(input, target)
from typing import Optional
import torch
__all__ = ["SmoothL1Loss", "smooth_l1_loss"]
from nndet.core.boxes.ops import generalized_box_iou
from nndet.losses.base import reduction_helper
class SmoothL1Loss(torch.nn.Module):
def __init__(self,
beta: float,
reduction: Optional[str] = None,
loss_weight: float = 1.,
):
"""
Module wrapper for functional
Args:
beta (float): L1 to L2 change point.
For beta values < 1e-5, L1 loss is computed.
reduction (str): 'none' | 'mean' | 'sum'
'none': No reduction will be applied to the output.
'mean': The output will be averaged.
                'sum': The output will be summed.
            loss_weight (float): scalar to balance multiple losses
See Also:
:func:`smooth_l1_loss`
"""
super().__init__()
self.reduction = reduction
self.beta = beta
self.loss_weight = loss_weight
def forward(self, inp: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss
Args:
inp (torch.Tensor): predicted tensor (same shape as target)
target (torch.Tensor): target tensor
Returns:
Tensor: computed loss
"""
return self.loss_weight * reduction_helper(smooth_l1_loss(inp, target, self.beta), self.reduction)
def smooth_l1_loss(inp, target, beta: float):
"""
From https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/smooth_l1_loss.py
Smooth L1 loss defined in the Fast R-CNN paper as:
| 0.5 * x ** 2 / beta if abs(x) < beta
smoothl1(x) = |
| abs(x) - 0.5 * beta otherwise,
where x = input - target.
Smooth L1 loss is related to Huber loss, which is defined as:
| 0.5 * x ** 2 if abs(x) < beta
huber(x) = |
| beta * (abs(x) - 0.5 * beta) otherwise
Smooth L1 loss is equal to huber(x) / beta. This leads to the following
differences:
- As beta -> 0, Smooth L1 loss converges to L1 loss, while Huber loss
converges to a constant 0 loss.
- As beta -> +inf, Smooth L1 converges to a constant 0 loss, while Huber loss
converges to L2 loss.
- For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant
slope of 1. For Huber loss, the slope of the L1 segment is beta.
Smooth L1 loss can be seen as exactly L1 loss, but with the abs(x) < beta
portion replaced with a quadratic function such that at abs(x) = beta, its
slope is 1. The quadratic segment smooths the L1 loss near x = 0.
Args:
inp (Tensor): input tensor of any shape
target (Tensor): target value tensor with the same shape as input
beta (float): L1 to L2 change point.
For beta values < 1e-5, L1 loss is computed.
    Returns:
        Tensor: The elementwise loss; reduction is handled by the
            :class:`SmoothL1Loss` wrapper.
Note:
PyTorch's builtin "Smooth L1 loss" implementation does not actually
implement Smooth L1 loss, nor does it implement Huber loss. It implements
the special case of both in which they are equal (beta=1).
See: https://pytorch.org/docs/stable/nn.html#torch.nn.SmoothL1Loss.
"""
if beta < 1e-5:
# if beta == 0, then torch.where will result in nan gradients when
# the chain rule is applied due to pytorch implementation details
# (the False branch "0.5 * n ** 2 / 0" has an incoming gradient of
# zeros, rather than "no gradient"). To avoid this issue, we define
# small values of beta to be exactly l1 loss.
loss = torch.abs(inp - target)
else:
n = torch.abs(inp - target)
cond = n < beta
loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)
return loss
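
def _demo_smooth_l1() -> None:
    """
    Hedged worked example for beta=1: |x| = 0.1 falls in the quadratic region
    (0.5 * 0.1 ** 2 / 1 = 0.005), |x| = 2 in the linear region (2 - 0.5 = 1.5).
    """
    inp = torch.tensor([0.1, 2.0])
    loss = smooth_l1_loss(inp, torch.zeros(2), beta=1.0)
    assert torch.allclose(loss, torch.tensor([0.005, 1.5]))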
class GIoULoss(torch.nn.Module):
def __init__(self,
reduction: Optional[str] = None,
eps: float = 1e-7,
loss_weight: float = 1.,
):
"""
Generalized IoU Loss
`Generalized Intersection over Union: A Metric and A Loss for Bounding
Box Regression` https://arxiv.org/abs/1902.09630
Args:
eps: small constant for numerical stability
Notes:
Original paper uses lambda=10 to balance regression and cls losses
for PASCAL VOC and COCO (not tuned for coco)
`End-to-End Object Detection with Transformers` https://arxiv.org/abs/2005.12872
"Our enhanced Faster-RCNN+ baselines use GIoU [38] loss along with
the standard l1 loss for bounding box regression. We performed a grid search
to find the best weights for the losses and the final models use only GIoU loss
with weights 20 and 1 for box and proposal regression tasks respectively"
"""
super().__init__()
self.eps = eps
self.reduction = reduction
self.loss_weight = loss_weight
def forward(self, pred_boxes: torch.Tensor, target_boxes: torch.Tensor) -> torch.Tensor:
"""
Compute generalized iou loss
Args:
pred_boxes: predicted boxes (x1, y1, x2, y2, (z1, z2)) [N, dim * 2]
target_boxes: target boxes (x1, y1, x2, y2, (z1, z2)) [N, dim * 2]
Returns:
Tensor: loss
"""
loss = reduction_helper(
torch.diag(generalized_box_iou(pred_boxes, target_boxes, eps=self.eps),
diagonal=0),
reduction=self.reduction)
return self.loss_weight * -1 * loss
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from loguru import logger
import torch
import torch.nn as nn
from torch import Tensor
from typing import Callable
def one_hot_smooth_batch(data, num_classes: int, smoothing: float = 0.0):
shape = data.shape
targets = torch.empty(size=(shape[0], num_classes, *shape[1:]), device=data.device)\
.fill_(smoothing / num_classes)\
.scatter_(1, data.long().unsqueeze(1), 1. - smoothing)
return targets
def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
"""
    net_output must be (b, c, x, y(, z))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z))
:param net_output:
:param gt:
:param axes:
:param mask: mask must be 1 for valid pixels and 0 for invalid pixels
:param square: if True then fp, tp and fn will be squared before summation
:return:
"""
if axes is None:
axes = tuple(range(2, len(net_output.size())))
shp_x = net_output.shape
shp_y = gt.shape
with torch.no_grad():
if len(shp_x) != len(shp_y):
gt = gt.view((shp_y[0], 1, *shp_y[1:]))
if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt
else:
gt = gt.long()
y_onehot = torch.zeros(shp_x)
if net_output.device.type == "cuda":
y_onehot = y_onehot.cuda(net_output.device.index)
y_onehot.scatter_(1, gt, 1)
tp = net_output * y_onehot
fp = net_output * (1 - y_onehot)
fn = (1 - net_output) * y_onehot
if mask is not None:
tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
if square:
tp = tp ** 2
fp = fp ** 2
fn = fn ** 2
tp = tp.sum(dim=axes, keepdim=False)
fp = fp.sum(dim=axes, keepdim=False)
fn = fn.sum(dim=axes, keepdim=False)
return tp, fp, fn
class SoftDiceLoss(nn.Module):
def __init__(self,
nonlin: Callable = None,
batch_dice: bool = False,
do_bg: bool = False,
smooth_nom: float = 1e-5,
smooth_denom: float = 1e-5,
):
"""
Soft dice loss
Args:
            nonlin: nonlinearity applied to predictions (e.g. softmax).
                Defaults to None.
            batch_dice: treat batch as pseudo volume. Defaults to False.
            do_bg: include background in dice computation. Defaults to False.
            smooth_nom: smoothing added to the numerator
            smooth_denom: smoothing added to the denominator
"""
super().__init__()
self.do_bg = do_bg
self.batch_dice = batch_dice
self.nonlin = nonlin
self.smooth_nom = smooth_nom
self.smooth_denom = smooth_denom
logger.info(f"Running batch dice {self.batch_dice} and "
f"do bg {self.do_bg} in dice loss.")
def forward(self,
inp: torch.Tensor,
target: torch.Tensor,
loss_mask: torch.Tensor=None,
):
"""
Compute loss
Args:
inp (torch.Tensor): predictions
target (torch.Tensor): ground truth
loss_mask ([torch.Tensor], optional): binary mask. Defaults to None.
Returns:
torch.Tensor: soft dice loss
"""
shp_x = inp.shape
if self.batch_dice:
axes = [0] + list(range(2, len(shp_x)))
else:
axes = list(range(2, len(shp_x)))
if self.nonlin is not None:
inp = self.nonlin(inp)
tp, fp, fn = get_tp_fp_fn(inp, target, axes, loss_mask, False)
nominator = 2 * tp + self.smooth_nom
denominator = 2 * tp + fp + fn + self.smooth_denom
dc = nominator / denominator
if not self.do_bg:
if self.batch_dice:
dc = dc[1:]
else:
dc = dc[:, 1:]
dc = dc.mean()
return 1 - dc
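
def _demo_soft_dice() -> None:
    """
    Hedged usage sketch: perfect one-hot predictions (background plus one
    foreground class) yield a dice of ~1, i.e. a loss close to 0.
    """
    pred = torch.zeros(1, 2, 4, 4)
    target = torch.zeros(1, 1, 4, 4, dtype=torch.long)
    target[0, 0, :2] = 1
    pred[0, 1, :2] = 1
    pred[0, 0, 2:] = 1
    loss = SoftDiceLoss(nonlin=None)(pred, target)
    assert loss < 1e-3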
class TopKLoss(torch.nn.CrossEntropyLoss):
def __init__(self,
topk: float,
loss_weight: float = 1.,
**kwargs,
):
"""
        Uses the topk fraction of entries to compute the CE loss
        (expects pre softmax logits!)
        Args:
            topk: fraction of all entries to use for loss computation,
                needs to be in [0, 1]
loss_weight: scalar to balance multiple losses
"""
if "reduction" in kwargs:
raise ValueError("Reduction is not supported in TopKLoss."
"This will always return the mean!")
super().__init__(
reduction="none",
**kwargs,
)
if topk < 0 or topk > 1:
raise ValueError("topk needs to be in the range [0, 1].")
self.topk = topk
self.loss_weight = loss_weight
def forward(self, input: Tensor, target: Tensor) -> Tensor:
"""
Compute CE loss and uses mean of topk percent of the entries
Args:
input: logits for all foreground classes [N, C, *]
target: target classes. 0 is treated as background, >0 are
treated as foreground classes. [N, *]
Returns:
Tensor: final loss
"""
losses = super().forward(input, target)
k = int(losses.numel() * self.topk)
return self.loss_weight * losses.view(-1).topk(k=k, sorted=False)[0].mean()
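
def _demo_topk_loss() -> None:
    """
    Hedged worked example: with topk=0.5 only the hardest half of the
    per-entry CE losses contributes to the mean.
    """
    logits = torch.tensor([[5.0, -5.0], [0.0, 0.0]])
    target = torch.tensor([0, 0])
    loss = TopKLoss(topk=0.5)(logits, target)
    per_entry = torch.nn.functional.cross_entropy(logits, target, reduction="none")
    assert torch.allclose(loss, per_entry.max())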
class TopKLossSigmoid(torch.nn.BCEWithLogitsLoss):
def __init__(self,
num_classes: int,
topk: float,
smoothing: float = 0.0,
loss_weight: float = 1.,
**kwargs,
):
"""
        Uses the topk fraction of entries to compute the BCE loss with one
        hot encoding (supports multiple classes, expects pre sigmoid logits!)
        Args:
            num_classes: number of classes
            topk: fraction of all entries to use for loss computation,
                needs to be in [0, 1]
smoothing: label smoothing
loss_weight: scalar to balance multiple losses
"""
if "reduction" in kwargs:
raise ValueError("Reduction is not supported in TopKLoss."
"This will always return the mean!")
super().__init__(
reduction="none",
**kwargs,
)
self.smoothing = smoothing
if smoothing > 0:
logger.info(f"Running label smoothing with smoothing: {smoothing}")
self.num_classes = num_classes
self.topk = topk
self.loss_weight = loss_weight
def forward(self, input: Tensor, target: Tensor) -> Tensor:
"""
Compute BCE loss based on one hot encoding of foreground(!) classes
and uses mean of topk percent of the entries
Args:
input: logits for all foreground(!) classes [N, C, *]
target: target classes [N, *]. Targets will be encoded with one
hot and 0 is treated as the background class and removed.
Returns:
Tensor: final loss
"""
target_one_hot = one_hot_smooth_batch(
            target, num_classes=self.num_classes + 1, smoothing=self.smoothing)  # [N, C + 1, *]
target_one_hot = target_one_hot[:, 1:] # background is implicitly encoded
losses = super().forward(input, target_one_hot.float())
k = int(losses.numel() * self.topk)
return self.loss_weight * losses.view(-1).topk(k=k, sorted=False)[0].mean()
from nndet.planning.analyzer import DatasetAnalyzer
from nndet.planning.experiment import PLANNER_REGISTRY
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import annotations
from os import PathLike
import pickle
from pathlib import Path
from typing import Dict, Sequence, Callable
from nndet.io.paths import get_case_ids_from_dir
class DatasetAnalyzer:
def __init__(self,
cropped_output_dir: PathLike,
preprocessed_output_dir: PathLike,
data_info: dict,
num_processes: int,
overwrite: bool = True,
):
"""
Class to analyse a dataset
:func:`analyze_dataset` saves result into `dataset_properties.pkl`
Args:
            cropped_output_dir: path to directory where prepared/cropped data is saved
            preprocessed_output_dir: path to directory where preprocessed
                data and the computed properties are saved
data_info: additional information about the data
`modalities`: numeric dict which maps modalities to strings (e.g. `CT`)
`labels`: numeric dict which maps segmentation to classes
num_processes: number of processes to use for analysis
overwrite: overwrite existing properties
"""
self.cropped_output_dir = Path(cropped_output_dir)
self.cropped_data_dir = self.cropped_output_dir / "imagesTr"
self.preprocessed_output_dir = Path(preprocessed_output_dir)
self.save_dir = self.preprocessed_output_dir / "properties"
self.save_dir.mkdir(parents=True, exist_ok=True)
self.num_processes = num_processes
self.overwrite = overwrite
self.sizes = self.spacings = None
self.data_info = data_info
self.case_ids = sorted(get_case_ids_from_dir(
self.cropped_output_dir / "imagesTr", pattern="*.npz", remove_modality=False))
self.props_per_case_file = self.save_dir / "props_per_case.pkl"
self.intensity_properties_file = self.save_dir / "intensity_properties.pkl"
def analyze_dataset(self,
properties: Sequence[Callable[[DatasetAnalyzer], Dict]],
) -> Dict:
"""
Analyze dataset
Result is also saved in cropped_output_dir as `dataset_properties.pkl`
Args:
properties: properties to analyze over dataset
Returns:
Dict: filled with computed results
"""
props = {"dim": self.data_info["dim"]}
for property_fn in properties:
props.update(property_fn(self))
with open(self.save_dir / "dataset_properties.pkl", "wb") as f:
pickle.dump(props, f)
return props
from abc import ABC, abstractmethod
from typing import TypeVar
class ArchitecturePlanner(ABC):
def __init__(self, **kwargs):
"""
Plan architecture and training hyperparameters (batch size and patch size)
"""
for key, item in kwargs.items():
setattr(self, key, item)
@abstractmethod
def plan(self, *args, **kwargs) -> dict:
"""
Plan architecture and training parameters
Args:
*args: positional arguments determined by Planner
**kwargs: keyword arguments determined by Planner
Returns:
dict: training and architecture information
`patch_size` (Sequence[int]): patch size
`batch_size` (int): batch size for training
`architecture` (dict): dictionary with all parameters needed for the final model
"""
raise NotImplementedError
def approximate_vram(self):
"""
Approximate vram usage of model for planning
"""
pass
def get_planner_id(self) -> str:
"""
Create identifier for this planner
Returns:
str: identifier
"""
return self.__class__.__name__
ArchitecturePlannerType = TypeVar('ArchitecturePlannerType', bound=ArchitecturePlanner)
from nndet.planning.architecture.boxes.base import BaseBoxesPlanner
from nndet.planning.architecture.boxes.c002 import BoxC002
import os
from pathlib import Path
from abc import abstractmethod
from typing import Type, Dict, Sequence, List, Callable, Tuple
import torch
import numpy as np
from tqdm import tqdm
from loguru import logger
from torchvision.models.detection.rpn import AnchorGenerator
from nnunet.experiment_planning.common_utils import get_pool_and_conv_props
from nndet.io.load import load_pickle
from nndet.arch.abstract import AbstractModel
from nndet.planning.estimator import MemoryEstimator, MemoryEstimatorDetection
from nndet.planning.architecture.abstract import ArchitecturePlanner
from nndet.core.boxes import (
get_anchor_generator,
expand_to_boxes,
box_center,
box_size,
compute_anchors_for_strides,
box_iou,
box_size_np,
box_area_np,
permute_boxes,
)
from nndet.planning.architecture.boxes.utils import (
fixed_anchor_init,
scale_with_abs_strides,
)
class BaseBoxesPlanner(ArchitecturePlanner):
def __init__(self,
preprocessed_output_dir: os.PathLike,
save_dir: os.PathLike,
network_cls: Type[AbstractModel] = None,
estimator: MemoryEstimator = None,
**kwargs,
):
"""
Plan the architecture for training
Args:
            preprocessed_output_dir: directory which contains the
                preprocessed data and the dataset properties
            save_dir: directory to save plans and analysis plots
            network_cls: network class to plan the architecture for
            estimator: estimator to approximate memory requirements
"""
super().__init__(**kwargs)
self.preprocessed_output_dir = Path(preprocessed_output_dir)
self.save_dir = Path(save_dir)
self.save_dir.mkdir(parents=True, exist_ok=True)
self.network_cls = network_cls
self.estimator = estimator
self.dataset_properties = load_pickle(
self.preprocessed_output_dir / "properties" / 'dataset_properties.pkl')
# parameters initialized from process properties
self.all_boxes: np.ndarray = None
self.all_ious: np.ndarray = None
self.class_ious: Dict[str, np.ndarray] = None
self.num_instances: Dict[int, int] = None
self.dim: int = None
self.architecture_kwargs: dict = {}
self.transpose_forward = None
def process_properties(self, **kwargs):
"""
Load dataset properties and extract information
"""
assert self.transpose_forward is not None
boxes = [case["boxes"] for case_id, case
in self.dataset_properties["instance_props_per_patient"].items()]
self.all_boxes = np.concatenate([b for b in boxes if not isinstance(b, list) and b.size > 0], axis=0)
self.all_boxes = permute_boxes(self.all_boxes, dims=self.transpose_forward)
self.all_ious = self.dataset_properties["all_ious"]
self.class_ious = self.dataset_properties["class_ious"]
self.num_instances = self.dataset_properties["num_instances"]
self.num_instances_per_case = {case_id: sum(case["num_instances"].values())
for case_id, case in self.dataset_properties["instance_props_per_patient"].items()}
self.dim = self.dataset_properties["dim"]
self.architecture_kwargs["classifier_classes"] = \
len(self.dataset_properties["class_dct"])
self.architecture_kwargs["seg_classes"] = \
self.architecture_kwargs["classifier_classes"]
self.architecture_kwargs["in_channels"] = \
len(self.dataset_properties["modalities"])
self.architecture_kwargs["dim"] = \
self.dataset_properties["dim"]
def plot_box_distribution(self, **kwargs):
"""
        Plot histogram with ground truth bounding box distribution for
        all axes
"""
try:
import matplotlib.pyplot as plt
except ImportError:
plt = None
logger.error("Failed to import matplotlib continue anyway.")
if plt is not None:
if isinstance(self.all_boxes, list):
_boxes = np.concatenate(
[b for b in self.all_boxes if not isinstance(b, list) and b.size > 0], axis=0)
dists = box_size_np(_boxes)
else:
dists = box_size_np(self.all_boxes)
for axis in range(dists.shape[1]):
dist = dists[:, axis]
plt.hist(dist, bins=100)
plt.savefig(
self.save_dir / f'bbox_sizes_axis_{axis}.png')
plt.xscale('log')
plt.savefig(
self.save_dir / f'bbox_sizes_axis_{axis}_xlog.png')
plt.yscale('log')
plt.savefig(
self.save_dir / f'bbox_sizes_axis_{axis}_xylog.png')
plt.close()
def plot_box_area_distribution(self, **kwargs):
"""
Plot histogram of areas of all ground truth boxes
"""
try:
import matplotlib.pyplot as plt
except ImportError:
plt = None
logger.error("Failed to import matplotlib continue anyway.")
if plt is not None:
if isinstance(self.all_boxes, list):
_boxes = np.concatenate(
[b for b in self.all_boxes if not isinstance(b, list) and b.size > 0], axis=0)
area = box_area_np(_boxes)
else:
area = box_area_np(self.all_boxes)
plt.hist(area, bins=100)
plt.savefig(self.save_dir / f'box_areas.png')
plt.xscale('log')
plt.savefig(self.save_dir / f'box_areas_xlog.png')
plt.yscale('log')
plt.savefig(self.save_dir / f'box_areas_xylog.png')
plt.close()
def plot_class_distribution(self, **kwargs):
try:
import matplotlib.pyplot as plt
except ImportError:
plt = None
logger.error("Failed to import matplotlib continue anyway.")
if plt is not None:
num_instances_dict = self.dataset_properties["num_instances"]
num_instances = []
classes = []
for key, item in num_instances_dict.items():
num_instances.append(item)
classes.append(str(key))
ind = np.arange(len(num_instances))
plt.bar(ind, num_instances)
plt.xlabel("Classes")
plt.ylabel("Num Instances")
plt.xticks(ind, classes)
plt.savefig(self.save_dir / f'num_classes.png')
plt.yscale('log')
plt.savefig(self.save_dir / f'num_classes_ylog.png')
plt.close()
def plot_instance_distribution(self, **kwargs):
try:
import matplotlib.pyplot as plt
except ImportError:
plt = None
logger.error("Failed to import matplotlib continue anyway.")
if plt is not None:
num_instances_per_case = list(self.num_instances_per_case.values())
plt.hist(num_instances_per_case, bins=100, range=(0, 100))
plt.savefig(self.save_dir / f'instances_per_case.png')
plt.close()
plt.hist(num_instances_per_case, bins=30, range=(0, 30))
plt.savefig(self.save_dir / f'instances_per_case_0_30.png')
plt.close()
plt.hist(num_instances_per_case, bins=11, range=(0, 11))
plt.savefig(self.save_dir / f'instances_per_case_0_10.png')
plt.close()
@abstractmethod
def _plan_anchors(self) -> dict:
"""
Plan anchors hyperparameters
"""
raise NotImplementedError
@abstractmethod
def _plan_architecture(self) -> Sequence[int]:
"""
Plan architecture
"""
raise NotImplementedError
def plan(self, **kwargs) -> dict:
"""
Plan architecture and training params
"""
for key, item in kwargs.items():
setattr(self, key, item)
self.create_default_settings()
if self.all_boxes is None:
self.process_properties(**kwargs)
self.plot_box_area_distribution(**kwargs)
self.plot_box_distribution(**kwargs)
self.plot_class_distribution(**kwargs)
self.plot_instance_distribution(**kwargs)
return {}
def create_default_settings(self):
pass
def compute_class_weights(self) -> List[float]:
"""
        Compute classification weighting for imbalanced datasets
        (background samples get weight 1 / (num_classes + 1) and foreground
        classes are weighted with (1 - 1 / (num_classes + 1)) * (1 - n_i / n_all),
        where n_i is the number of samples of class i and n_all
        is the number of all ground truth samples)
Returns:
List[float]: weights
"""
num_instances_dict = self.dataset_properties["num_instances"]
num_classes = len(num_instances_dict)
num_instances = [0] * num_classes
for key, item in num_instances_dict.items():
num_instances[int(key)] = int(item)
bg_weight = 1 / (num_classes + 1)
remaining_weight = 1 - bg_weight
weights = [remaining_weight * (1 - ni / sum(num_instances)) for ni in num_instances]
return [bg_weight] + weights
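    # Worked example of the weighting above (hypothetical counts, a sketch
    # of the formula, not taken from a real dataset): for
    # num_instances = {"0": 30, "1": 10} we get num_classes = 2,
    # bg_weight = 1 / 3 and remaining_weight = 2 / 3, hence
    #   weight_0 = (2 / 3) * (1 - 30 / 40) = 0.1667
    #   weight_1 = (2 / 3) * (1 - 10 / 40) = 0.5
    # and compute_class_weights() returns [0.3333, 0.1667, 0.5].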
def get_planner_id(self) -> str:
"""
Create identifier for this planner. If available append
:attr:`plan_tag` to the base name
Returns:
str: identifier
"""
base = super().get_planner_id()
if hasattr(self, "plan_tag"):
base = base + getattr(self, "plan_tag")
return base
class BoxC001(BaseBoxesPlanner):
def __init__(self,
preprocessed_output_dir: os.PathLike,
save_dir: os.PathLike,
network_cls: Callable,
estimator: MemoryEstimator = MemoryEstimatorDetection(),
model_cfg: dict = None,
**kwargs,
):
"""
Plan training architecture with heuristics
Args:
preprocessed_output_dir: base preprocessed directory to
access properties and save analysis files
save_dir: directory to save analysis plots
network_cls: constructor of network to plan
estimator: estimate GPU memory requirements for specific GPU
architectures. Defaults to MemoryEstimatorDetection().
"""
super().__init__(
preprocessed_output_dir=preprocessed_output_dir,
save_dir=save_dir,
network_cls=network_cls,
estimator=estimator,
**kwargs,
)
self.additional_params = {}
if model_cfg is None:
model_cfg = {}
self.model_cfg = model_cfg
self.plan_anchor_for_estimation = fixed_anchor_init(self.dim)
def create_default_settings(self):
"""
Generate some default settings for the architecture
"""
# MAX_NUM_FILTERS_2D, MAX_NUM_FILTERS_3D from nnUNet
self.architecture_kwargs["max_channels"] = 480 if self.dim == 2 else 320
# BASE_NUM_FEATURES_3D from nnUNet
self.architecture_kwargs["start_channels"] = 32
# DEFAULT_BATCH_SIZE_3D from nnUNet
self.batch_size = 32 if self.dim == 2 else 2
self.max_num_pool = 999
self.min_feature_map_size = 4
self.min_decoder_level = 2
self.num_decoder_level = 4
self.architecture_kwargs["fpn_channels"] = \
self.architecture_kwargs["start_channels"] * 2
self.architecture_kwargs["head_channels"] = \
self.architecture_kwargs["fpn_channels"]
def plan(self,
target_spacing_transposed: Sequence[float],
median_shape_transposed: Sequence[float],
transpose_forward: Sequence[int],
mode: str = "3d",
) -> dict:
"""
Plan network architecture, anchors, patch size and batch size
Args:
target_spacing_transposed: spacing after data is transposed and resampled
median_shape_transposed: median shape after data is
transposed and resampled
transpose_forward: new ordering of axes for forward pass
mode: mode to use for planning (this planner only supports 3d!)
Returns:
dict: training and architecture information
See Also:
        :meth:`_plan_architecture`, :meth:`_plan_anchors`
"""
super().plan(
transpose_forward=transpose_forward,
target_spacing_transposed=target_spacing_transposed,
median_shape_transposed=median_shape_transposed,
)
self.architecture_kwargs["class_weight"] = self.compute_class_weights()
patch_size = self._plan_architecture(
transpose_forward=transpose_forward,
target_spacing_transposed=target_spacing_transposed,
target_median_shape_transposed=median_shape_transposed,
)
anchors = self._plan_anchors(
target_spacing_transposed=target_spacing_transposed,
median_shape_transposed=median_shape_transposed,
transpose_forward=transpose_forward,
)
plan = {"patch_size": patch_size,
"batch_size": self.batch_size,
"architecture": {
"arch_name": self.network_cls.__name__,
**self.architecture_kwargs
},
"anchors": anchors,
}
logger.info(f"Using architecture plan: \n{plan}")
return plan
def _plan_anchors(self, **kwargs) -> dict:
"""
Optimize anchors
"""
boxes_np_full = self.all_boxes.astype(np.float32)
boxes_np = self.filter_boxes(boxes_np_full)
logger.info(f"Filtered {boxes_np_full.shape[0] - boxes_np.shape[0]} "
f"boxes, {boxes_np.shape[0]} boxes remaining for anchor "
"planning.")
boxes_torch = torch.from_numpy(boxes_np).float()
boxes_torch = boxes_torch - expand_to_boxes(box_center(boxes_torch))
anchor_generator = get_anchor_generator(self.dim, s_param=True)
rel_strides = self.architecture_kwargs["strides"]
filt_rel_strides = [[1] * self.dim, *rel_strides]
filt_rel_strides = [filt_rel_strides[i] for i in self.architecture_kwargs["decoder_levels"]]
strides = np.cumprod(filt_rel_strides, axis=0) / np.asarray(rel_strides[0])
params = self.find_anchors(boxes_torch, strides.astype(np.int32), anchor_generator)
scaled_params = {key: scale_with_abs_strides(item, strides, dim_idx) for dim_idx, (key, item) in enumerate(params.items())}
logger.info(f"Determined Anchors: {params}; Results in params: {scaled_params}")
self.anchors = scaled_params
self.anchors["stride"] = 1
return self.anchors
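    # Illustration of the stride computation above (hypothetical values):
    # with rel_strides = [[2, 2, 2]] * 5 and decoder_levels = (2, 3, 4, 5),
    # filt_rel_strides selects [[2, 2, 2]] * 4, the cumulative product yields
    # [[2, 2, 2], [4, 4, 4], [8, 8, 8], [16, 16, 16]] and dividing by
    # rel_strides[0] gives per-axis strides of 1, 2, 4 and 8, i.e. the
    # stride of each decoder level relative to the first decoder level.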
@staticmethod
def filter_boxes(boxes_np: np.ndarray,
upper_percentile: float = 99.5,
                     lower_percentile: float = 0.5,
) -> np.ndarray:
"""
Determine upper and lower percentiles of bounding box sizes for each
axis and remove boxes which are outside the specified range
Args:
boxes_np (np.ndarray): bounding boxes [N, dim * 2](x1, y1, x2, y2, (z1, z2))
upper_percentile: percentile for upper boundary. Defaults to 99.5.
            lower_percentile: percentile for lower boundary. Defaults to 0.5.
Returns:
np.ndarray: filtered boxes
See Also:
:func:`np.percentile`
"""
mask = np.ones(boxes_np.shape[0]).astype(bool)
box_sizes = box_size_np(boxes_np)
for ax in range(box_sizes.shape[1]):
ax_sizes = box_sizes[:, ax]
upper_th = np.percentile(ax_sizes, upper_percentile)
lower_th = np.percentile(ax_sizes, lower_percentile)
            ax_mask = (ax_sizes < upper_th) & (ax_sizes > lower_th)
            mask = mask & ax_mask
        return boxes_np[mask]
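    # Example of the percentile filter (hypothetical sizes, a sketch only):
    # for ten boxes with sizes 1, 2, ..., 10 along one axis, np.percentile
    # yields lower_th ~ 1.045 and upper_th ~ 9.955, so the strict
    # comparisons drop the smallest and the largest box and eight remain.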
def find_anchors(self,
boxes_torch: torch.Tensor,
strides: Sequence[Sequence[int]],
anchor_generator: AnchorGenerator,
) -> Dict[str, Sequence[int]]:
"""
Find anchors which maximize iou over dataset
Args:
boxes_torch: filtered ground truth boxes
strides (Sequence[Sequence[int]]): strides of network to compute
anchor sizes of lower levels
            anchor_generator (AnchorGenerator): anchor generator used to
                generate the anchors
        Returns:
            Dict[str, Sequence[int]]: parameterization of anchors
                `width` (Sequence[float]): width values for bounding boxes
                `height` (Sequence[float]): height values for bounding boxes
                (`depth` (Sequence[float]): depth values for bounding boxes)
"""
import nevergrad as ng
dim = int(boxes_torch.shape[1] // 2)
sizes = box_size(boxes_torch)
maxs = sizes.max(dim=0)[0]
best_iou = 0
        # alternative candidate optimizers: TBPSA, PSO
        # run TwoPointsDE three times and keep the best result
        for algo in ["TwoPointsDE", "TwoPointsDE", "TwoPointsDE"]:
_best_iou = 0
params = []
for axis in range(dim):
# TODO: find better initialization
anchor_init = self.get_anchor_init(boxes_torch)
p = ng.p.Array(init=np.asarray(anchor_init[axis]))
p.set_integer_casting()
# p.set_bounds(1, maxs[axis].item())
p.set_bounds(lower=1)
params.append(p)
instrum = ng.p.Instrumentation(*params)
optimizer = ng.optimizers.registry[algo](
parametrization=instrum, budget=5000, num_workers=1)
with torch.no_grad():
pbar = tqdm(range(optimizer.budget), f"Anchor Opt {algo}")
for _ in pbar:
x = optimizer.ask()
anchors = anchor_generator.generate_anchors(*x.args)
anchors = compute_anchors_for_strides(
anchors, strides=strides, cat=True)
                    # TODO: add checks if GPU is available and has enough VRAM
iou = box_iou(boxes_torch.cuda(), anchors.cuda()) # boxes x anchors
mean_iou = iou.max(dim=1)[0].mean().cpu()
optimizer.tell(x, -mean_iou.item())
pbar.set_postfix(mean_iou=mean_iou)
_best_iou = mean_iou
if _best_iou > best_iou:
best_iou = _best_iou
recommendation = optimizer.provide_recommendation().value[0]
return {key: list(val) for key, val in zip(["width", "height", "depth"], recommendation)}
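    # The returned parameterization is keyed per spatial axis, e.g. for 3D
    # (hypothetical values, not a reference result):
    # {"width": [2, 4, 8], "height": [3, 6, 12], "depth": [2, 3, 4]};
    # these sizes are subsequently scaled to the lower resolution levels
    # via scale_with_abs_strides in _plan_anchors.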
def get_anchor_init(self, boxes: torch.Tensor) -> Sequence[Sequence[int]]:
"""
Initialize anchors sizes for optimization
Args:
boxes: scales and transposed boxes
Returns:
Sequence[Sequence[int]]: anchor initialization
"""
return [(2, 4, 8)] * 3
def _plan_architecture(self,
target_spacing_transposed: Sequence[float],
target_median_shape_transposed: Sequence[float],
**kwargs,
) -> Sequence[int]:
"""
Plan patchsize and main aspects of the architecture
        Fills entries in :attr:`self.architecture_kwargs`:
`conv_kernels`
`strides`
`decoder_levels`
Args:
target_spacing_transposed: spacing after data is transposed and resampled
target_median_shape_transposed: median shape after data is
transposed and resampled
Returns:
Sequence[int]: patch size to use for training
"""
self.estimator.batch_size = self.batch_size
patch_size = np.asarray(self._get_initial_patch_size(
target_spacing_transposed, target_median_shape_transposed))
first_run = True
while True:
            if not first_run:
                patch_size = self._decrease_patch_size(
                    patch_size, target_median_shape_transposed, pooling, must_be_divisible_by)
num_pool_per_axis, pooling, convs, patch_size, must_be_divisible_by = \
self.plan_pool_and_conv_pool_late(patch_size, target_spacing_transposed)
self.architecture_kwargs["conv_kernels"] = convs
self.architecture_kwargs["strides"] = pooling
num_resolutions = len(self.architecture_kwargs["conv_kernels"])
decoder_levels_start = min(max(0, num_resolutions - self.num_decoder_level), self.min_decoder_level)
self.architecture_kwargs["decoder_levels"] = \
tuple([i for i in range(decoder_levels_start, num_resolutions)])
            logger.debug(self.architecture_kwargs["decoder_levels"])
            logger.debug(self.get_anchors_for_estimation())
_, fits_in_mem = self.estimator.estimate(
min_shape=must_be_divisible_by,
target_shape=patch_size,
in_channels=self.architecture_kwargs["in_channels"],
network=self.network_cls.from_config_plan(
model_cfg=self.model_cfg,
plan_arch=self.architecture_kwargs,
plan_anchors=self.get_anchors_for_estimation()),
optimizer_cls=torch.optim.Adam,
)
if fits_in_mem:
break
first_run = False
logger.info(f"decoder levels: {self.architecture_kwargs['decoder_levels']}; \n"
f"pooling strides: {self.architecture_kwargs['strides']}; \n"
f"kernel sizes: {self.architecture_kwargs['conv_kernels']}; \n"
f"patch size: {patch_size}; \n")
return patch_size
def _decrease_patch_size(self,
patch_size: np.ndarray,
target_median_shape_transposed: np.ndarray,
pooling: Sequence[Sequence[int]],
must_be_divisible_by: Sequence[int],
) -> np.ndarray:
"""
        Decrease the axis which is largest relative to the median shape.
        If the bottleneck feature map of an axis is larger than the minimum
        feature map size, the reduction is the divisor required by the
        computed pooling strides, otherwise half of that divisor.
Args:
patch_size: current patch size
target_median_shape_transposed: median shape of dataset
correctly transposed
pooling: pooling kernels of network
must_be_divisible_by: necessary divisor per axis
Returns:
np.ndarray: new patch size
"""
argsrt = np.argsort(patch_size / target_median_shape_transposed)[::-1]
pool_fct_per_axis = np.prod(pooling, 0)
bottleneck_size_per_axis = patch_size / pool_fct_per_axis
reduction = []
for i in range(len(patch_size)):
if bottleneck_size_per_axis[i] > self.min_feature_map_size:
reduction.append(must_be_divisible_by[i])
else:
                reduction.append(must_be_divisible_by[i] // 2)
patch_size[argsrt[0]] -= reduction[argsrt[0]]
return patch_size
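    # Worked example (hypothetical values, a sketch only): for
    # patch_size = [224, 192] and five poolings of stride 2 per axis,
    # pool_fct_per_axis = [32, 32] and the bottleneck sizes are [7, 6].
    # Both exceed min_feature_map_size = 4, so reduction equals
    # must_be_divisible_by = [32, 32]; the axis that is largest relative
    # to the median shape is then reduced by 32.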
@staticmethod
def _get_initial_patch_size(target_spacing_transposed: np.ndarray,
target_median_shape_transposed: Sequence[int]) -> List[int]:
"""
        Generate an initial patch size based on the spacing of the underlying
        images. This relies on the fact that most acquisition protocols are
        optimized to focus on the most important aspects.
Returns:
List[int]: initial patch size
"""
voxels_per_mm = 1 / np.array(target_spacing_transposed)
# normalize voxels per mm
input_patch_size = voxels_per_mm / voxels_per_mm.mean()
# create an isotropic patch of size 512x512x512mm
input_patch_size *= 1 / min(input_patch_size) * 512 # to get a starting value
input_patch_size = np.round(input_patch_size).astype(np.int32)
        # clip it to the median shape of the dataset because patches larger than that make little sense
input_patch_size = [min(i, j) for i, j in zip(
input_patch_size, target_median_shape_transposed)]
return np.round(input_patch_size).astype(np.int32)
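    # Worked example (hypothetical values): for a target spacing of
    # [3.0, 0.75, 0.75] mm, voxels_per_mm = [0.33, 1.33, 1.33]; scaling the
    # normalized values so the coarsest axis gets 512 voxels yields
    # [512, 2048, 2048], and clipping to a median shape of [96, 320, 320]
    # gives an initial patch size of [96, 320, 320].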
def plan_pool_and_conv_pool_late(self,
patch_size: Sequence[int],
spacing: Sequence[float],
) -> Tuple[List[int], List[Tuple[int]], List[Tuple[int]],
Sequence[int], Sequence[int]]:
"""
Plan pooling and convolutions of encoder network
        Axes which do not need pooling in every block are pooled as late as possible
        Uses kernel size 1 for anisotropic axes which are not yet covered by the field of view
Args:
            patch_size: target patch size
spacing: target spacing transposed
Returns:
List[int]: max number of pooling operations per axis
List[Tuple[int]]: kernel sizes of pooling operations
List[Tuple[int]]: kernel sizes of convolution layers
Sequence[int]: patch size
            Sequence[int]: divisor each axis needs to be divisible by
"""
num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, \
patch_size, must_be_divisible_by = get_pool_and_conv_props(
spacing=spacing, patch_size=patch_size,
min_feature_map_size=self.min_feature_map_size,
max_numpool=self.max_num_pool)
return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by
def get_anchors_for_estimation(self):
"""
Adjust anchor plan for varying number of feature maps
Returns:
dict: adjusted anchor plan
"""
num_levels = len(self.architecture_kwargs["decoder_levels"])
anchor_plan = {"stride": 1, "aspect_ratios": (0.5, 1, 2)}
        _sizes = [(16, 32, 64)] * num_levels
        anchor_plan["sizes"] = tuple(_sizes)
        if self.dim == 3:
            anchor_plan["zsizes"] = tuple(_sizes)
return anchor_plan
import os
import copy
from typing import Callable, Sequence, List
import torch
import numpy as np
from loguru import logger
from nndet.planning.estimator import MemoryEstimator, MemoryEstimatorDetection
from nndet.planning.architecture.boxes.base import BoxC001
from nndet.planning.architecture.boxes.utils import (
proxy_num_boxes_in_patch,
scale_with_abs_strides,
)
from nndet.core.boxes import (
get_anchor_generator,
expand_to_boxes,
box_center,
box_size_np,
permute_boxes,
)
class BoxC002(BoxC001):
def __init__(self,
preprocessed_output_dir: os.PathLike,
save_dir: os.PathLike,
network_cls: Callable,
estimator: MemoryEstimator = MemoryEstimatorDetection(),
model_cfg: dict = None,
**kwargs,
):
super().__init__(
preprocessed_output_dir=preprocessed_output_dir,
save_dir=save_dir,
network_cls=network_cls,
estimator=estimator,
model_cfg=model_cfg,
**kwargs
)
def create_default_settings(self):
"""
Generate default settings for the architecture
"""
super().create_default_settings()
self.architecture_kwargs["start_channels"] = 48 if self.dim == 2 else 32
self.architecture_kwargs["fpn_channels"] = \
self.architecture_kwargs["start_channels"] * 4
self.architecture_kwargs["head_channels"] = \
self.architecture_kwargs["fpn_channels"]
self.batch_size = 16 if self.dim == 2 else 4
self.min_feature_map_size = 8 if self.dim == 2 else 4
self.num_decoder_level = 5 if self.dim == 2 else 4
def get_anchor_init(self, boxes: torch.Tensor) -> Sequence[Sequence[int]]:
"""
Initialize anchors sizes for optimization
Args:
boxes: scales and transposed boxes
Returns:
Sequence[Sequence[int]]: anchor initialization
"""
box_dim = int(boxes.shape[1]) // 2
return [(4, 8, 16), ] * box_dim
def process_properties(self, **kwargs):
"""
Load dataset properties and extract information
"""
logger.info("Processing dataset properties")
self.all_boxes = [case["boxes"] for case_id, case
in self.dataset_properties["instance_props_per_patient"].items()]
self.all_spacings = [case["original_spacing"] for case_id, case
in self.dataset_properties["instance_props_per_patient"].items()]
self.num_instances_per_case = {case_id: sum(case["num_instances"].values())
for case_id, case in self.dataset_properties["instance_props_per_patient"].items()}
self.all_ious = self.dataset_properties["all_ious"]
self.class_ious = self.dataset_properties["class_ious"]
self.num_instances = self.dataset_properties["num_instances"]
self.dim = self.dataset_properties["dim"]
self.architecture_kwargs["classifier_classes"] = \
len(self.dataset_properties["class_dct"])
self.architecture_kwargs["seg_classes"] = \
self.architecture_kwargs["classifier_classes"]
self.architecture_kwargs["in_channels"] = \
len(self.dataset_properties["modalities"])
self.architecture_kwargs["dim"] = \
self.dataset_properties["dim"]
def plan(self,
target_spacing_transposed: Sequence[float],
median_shape_transposed: Sequence[float],
transpose_forward: Sequence[int],
mode: str = '3d',
) -> dict:
"""
Plan network architecture, anchors, patch size and batch size
Args:
target_spacing_transposed: spacing after data is transposed and resampled
median_shape_transposed: median shape after data is
transposed and resampled
transpose_forward: new ordering of axes for forward pass
mode: mode to use for planning ('3d' | '2d')
Returns:
dict: training and architecture information
See Also:
            :meth:`_plan_architecture`, :meth:`_plan_anchors`
"""
if mode == "2d":
logger.info("Running 2d mode")
self.process_properties()
kwargs_2d = self.activate_2d_mode(
transpose_forward=transpose_forward,
target_spacing_transposed=target_spacing_transposed,
median_shape_transposed=median_shape_transposed,
)
res = super().plan(**kwargs_2d)
else:
res = super().plan(
transpose_forward=transpose_forward,
target_spacing_transposed=target_spacing_transposed,
median_shape_transposed=median_shape_transposed,
)
return res
def activate_2d_mode(self,
target_spacing_transposed: Sequence[float],
median_shape_transposed: Sequence[float],
transpose_forward: Sequence[int],
) -> dict:
target_spacing_transposed = target_spacing_transposed[1:]
median_shape_transposed = median_shape_transposed[1:]
keep = copy.copy(transpose_forward[1:])
transpose_forward = [t - 1 for t in keep]
keep_box = [0, 0, 0, 0]
for idx, k in enumerate(keep):
if k < 2:
keep_box[idx] = k
keep_box[idx + 2] = k + 2
else:
keep_box[idx] = 2 * k
keep_box[idx + 2] = 2 * k + 1
self.all_boxes = [b[:, keep_box] if (not isinstance(b, list) and b.shape[1] == 6) else b
for b in self.all_boxes]
self.all_spacings = [c[keep] if len(c) == 3 else c for c in self.all_spacings]
self.dim = 2
self.architecture_kwargs["dim"] = self.dim
return {
"target_spacing_transposed": target_spacing_transposed,
"median_shape_transposed": median_shape_transposed,
"transpose_forward": transpose_forward,
}
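    # Worked example of the index remapping above (hypothetical values):
    # for transpose_forward = [0, 1, 2], keep = [1, 2] and the new
    # transpose_forward becomes [0, 1]. With the box layout
    # (x1, y1, x2, y2, z1, z2), keep = [1, 2] yields keep_box = [1, 4, 3, 5],
    # i.e. the y and z extents become the new 2D (x1, y1, x2, y2) boxes.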
def _plan_architecture(self,
target_spacing_transposed: Sequence[float],
target_median_shape_transposed: Sequence[float],
transpose_forward: Sequence[int],
**kwargs,
) -> Sequence[int]:
"""
Plan patch size and main aspects of the architecture
        Fills entries in :attr:`self.architecture_kwargs`:
`conv_kernels`
`strides`
`decoder_levels`
Args:
target_spacing_transposed: spacing after data is transposed and resampled
target_median_shape_transposed: median shape after data is
transposed and resampled
Returns:
Sequence[int]: patch size to use for training
"""
self.estimator.batch_size = self.batch_size
patch_size = np.asarray(self._get_initial_patch_size(
target_spacing_transposed, target_median_shape_transposed))
first_run = True
while True:
            if not first_run:
                patch_size = self._decrease_patch_size(
                    patch_size, target_median_shape_transposed, pooling, must_be_divisible_by)
num_pool_per_axis, pooling, convs, patch_size, must_be_divisible_by = \
self.plan_pool_and_conv_pool_late(patch_size, target_spacing_transposed)
self.architecture_kwargs["conv_kernels"] = convs
self.architecture_kwargs["strides"] = pooling
num_resolutions = len(self.architecture_kwargs["conv_kernels"])
decoder_levels_start = min(max(1, num_resolutions - self.num_decoder_level), self.min_decoder_level)
self.architecture_kwargs["decoder_levels"] = \
tuple([i for i in range(decoder_levels_start, num_resolutions)])
_, fits_in_mem = self.estimator.estimate(
min_shape=must_be_divisible_by,
target_shape=patch_size,
in_channels=self.architecture_kwargs["in_channels"],
network=self.network_cls.from_config_plan(
model_cfg=self.model_cfg,
plan_arch=self.architecture_kwargs,
plan_anchors=self.get_anchors_for_estimation()),
optimizer_cls=torch.optim.Adam,
                num_instances=self._estimate_num_instances_per_patch(
patch_size=patch_size,
target_spacing_transposed=target_spacing_transposed,
transpose_forward=transpose_forward,
),
)
if fits_in_mem:
break
first_run = False
logger.info(f"decoder levels: {self.architecture_kwargs['decoder_levels']}; \n"
f"pooling strides: {self.architecture_kwargs['strides']}; \n"
f"kernel sizes: {self.architecture_kwargs['conv_kernels']}; \n"
f"patch size: {patch_size}; \n")
return patch_size
    def _estimate_num_instances_per_patch(self,
patch_size,
target_spacing_transposed,
transpose_forward,
) -> int:
max_instances_per_image = []
for boxes in self._get_scaled_boxes(
target_spacing_transposed=target_spacing_transposed,
transpose_forward=transpose_forward,
cat=False,
):
max_instances_per_image.append(
max(proxy_num_boxes_in_patch(torch.from_numpy(boxes), patch_size)).item())
return max(max_instances_per_image)
def _plan_anchors(self,
target_spacing_transposed: Sequence[float],
transpose_forward: Sequence[int],
**kwargs,
) -> dict:
"""
Optimize anchors
"""
boxes_np_full = self._get_scaled_boxes(
target_spacing_transposed=target_spacing_transposed,
transpose_forward=transpose_forward,
)
boxes_np = self.filter_boxes(boxes_np_full)
logger.info(f"Filtered {boxes_np_full.shape[0] - boxes_np.shape[0]} "
f"boxes, {boxes_np.shape[0]} boxes remaining for anchor "
"planning.")
boxes_torch = torch.from_numpy(boxes_np).float()
boxes_torch = boxes_torch - expand_to_boxes(box_center(boxes_torch))
anchor_generator = get_anchor_generator(self.dim, s_param=True)
rel_strides = self.architecture_kwargs["strides"]
filt_rel_strides = [[1] * self.dim, *rel_strides]
filt_rel_strides = [filt_rel_strides[i] for i in self.architecture_kwargs["decoder_levels"]]
strides = np.cumprod(filt_rel_strides, axis=0) / np.asarray(rel_strides[0])
params = self.find_anchors(boxes_torch, strides.astype(np.int32), anchor_generator)
scaled_params = {key: scale_with_abs_strides(item, strides, dim_idx) for dim_idx, (key, item) in enumerate(params.items())}
logger.info(f"Determined Anchors: {params}; Results in params: {scaled_params}")
self.anchors = scaled_params
self.anchors["stride"] = 1
return self.anchors
def _get_scaled_boxes(self,
target_spacing_transposed: Sequence[float],
transpose_forward: Sequence[int],
cat: bool = True,
) -> np.ndarray:
"""
        Training is conducted in preprocessed image space and thus
        the extracted boxes need to be scaled to compensate for resampling
boxes_np_list = []
for spacing, boxes in zip(self.all_spacings, self.all_boxes):
if not isinstance(boxes, list) and boxes.size > 0:
spacing_transposed = np.asarray(spacing)[transpose_forward]
scaling_transposed = spacing_transposed / np.asarray(target_spacing_transposed)
boxes_transposed = permute_boxes(np.asarray(boxes), dims=transpose_forward)
boxes_np_list.append(boxes_transposed * expand_to_boxes(scaling_transposed))
if cat:
return np.concatenate(boxes_np_list).astype(np.float32)
else:
return boxes_np_list
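    # Example of the scaling above (hypothetical values, a sketch only):
    # with an original spacing of 4.0 mm along z and a target spacing of
    # 2.0 mm, the scaling factor for z is 4.0 / 2.0 = 2, so a box spanning
    # 10 slices in the original image spans 20 voxels after resampling.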
@staticmethod
def _get_initial_patch_size(target_spacing_transposed: np.ndarray,
target_median_shape_transposed: Sequence[int],
) -> List[int]:
"""
        Generate an initial patch size based on the spacing of the underlying
        images. This relies on the fact that most acquisition protocols are
        optimized to focus on the most important aspects.
Returns:
List[int]: initial patch size
"""
voxels_per_mm = 1 / np.array(target_spacing_transposed)
# normalize voxels per mm
input_patch_size = voxels_per_mm / voxels_per_mm.mean()
# create an isotropic patch of size 512x512x512mm
input_patch_size *= 1 / min(input_patch_size) * 512 # to get a starting value
input_patch_size = np.round(input_patch_size).astype(np.int32)
        # clip it to the median shape of the dataset because patches larger
        # than that make little sense and account for rectangular patches
if len(target_spacing_transposed) > 2:
lowres_axis = np.argmax(target_spacing_transposed)
isotropic_axes = list(range(len(target_median_shape_transposed)))
isotropic_axes.pop(lowres_axis)
min_isotropic_axes_shape = min([target_median_shape_transposed[t] for t in isotropic_axes])
lowres_shape = target_median_shape_transposed[lowres_axis]
else:
lowres_axis = -1
lowres_shape = None
min_isotropic_axes_shape = min(target_median_shape_transposed)
initial_patch_size = []
for i in range(len(target_median_shape_transposed)):
if i == lowres_axis:
assert lowres_shape is not None
initial_patch_size.append(min(input_patch_size[i], lowres_shape))
else:
initial_patch_size.append(min(input_patch_size[i], min_isotropic_axes_shape))
initial_patch_size = np.round(initial_patch_size).astype(np.int32)
logger.info(f"Using initial patch size: {initial_patch_size}")
return initial_patch_size
def plot_box_distribution(self,
target_spacing_transposed: Sequence[float],
transpose_forward: Sequence[int],
**kwargs):
"""
Plot histogram with ground truth bounding box distribution for
        all axes
"""
super().plot_box_distribution()
try:
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import
import matplotlib.pyplot as plt
except ImportError:
logger.error("Failed to import matplotlib continue anyway.")
plt = None
if plt is not None:
if isinstance(self.all_boxes, list):
_boxes = np.concatenate(
[b for b in self.all_boxes if not isinstance(b, list) and b.size > 0], axis=0)
dists = box_size_np(_boxes)
else:
dists = box_size_np(self.all_boxes)
if dists.shape[1] == 3:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(dists[:, 0], dists[:, 1], dists[:, 2])
ax.set_title(f"Transpose forward {transpose_forward}")
plt.savefig(self.save_dir / f'bbox_sizes_3d_orig.png')
plt.close()
dists = box_size_np(self._get_scaled_boxes(
target_spacing_transposed, transpose_forward))
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(dists[:, 0], dists[:, 1], dists[:, 2])
plt.savefig(self.save_dir / f'bbox_sizes_3d.png')
plt.close()
else:
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(dists[:, 0], dists[:, 1])
ax.grid(True)
ax.set_title(f"Transpose forward {transpose_forward}")
plt.savefig(self.save_dir / f'bbox_sizes_2d_orig.png')
plt.close()
dists = box_size_np(self._get_scaled_boxes(
target_spacing_transposed, transpose_forward))
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(dists[:, 0], dists[:, 1])
ax.grid(True)
plt.savefig(self.save_dir / f'bbox_sizes_2d.png')
plt.close()
from typing import Sequence, List, Union, Tuple
import torch
import numpy as np
from torch import Tensor
from nndet.core.boxes import box_center
def scale_with_abs_strides(seq: Sequence[float],
strides: Sequence[Union[Sequence[Union[int, float]], Union[int, float]]],
dim_idx: int,
) -> List[Tuple[float]]:
"""
Scale values with absolute stride between feature maps
Args:
seq: sequence to scale
strides: strides to scale with.
dim_idx: dimension index for stride
"""
scaled = []
for stride in strides:
if not isinstance(stride, (float, int)):
_stride = stride[dim_idx]
else:
_stride = stride
_scaled = [i * _stride for i in seq]
scaled.append(tuple(_scaled))
return scaled
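# Doctest-style sketch (hypothetical values):
# >>> scale_with_abs_strides((2, 4, 8), [[1, 1], [2, 2], [4, 4]], dim_idx=0)
# [(2, 4, 8), (4, 8, 16), (8, 16, 32)]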
def proxy_num_boxes_in_patch(boxes: Tensor, patch_size: Sequence[int]) -> Tensor:
"""
This is just a proxy and not the exact computation
Args:
boxes: boxes
patch_size: patch size
Returns:
        Tensor: count of boxes whose center point lies within patch_size / 2
"""
patch_size = torch.tensor(patch_size, dtype=torch.float)[None, None] / 2 # [1, 1, dims]
center = box_center(boxes) # [N, dims]
center_dists = (center[None] - center[:, None]).abs() # [N, N, dims]
center_in_range = (center_dists <= patch_size).prod(dim=-1) # [N, N]
return center_in_range.sum(dim=1) # [N]
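# Worked example (hypothetical values): for 2D boxes
# [[0, 0, 2, 2], [10, 10, 12, 12], [100, 100, 102, 102]] and
# patch_size = (50, 50), the centers (1, 1), (11, 11) and (101, 101) give
# per-axis center distances of 10, 90 and 100; only the first two centers
# lie within 25 of each other, so the proxy counts are [2, 2, 1].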
def comp_num_pool_per_axis(patch_size: Sequence[int],
max_num_pool: int,
min_feature_map_size: int) -> List[int]:
"""
Computes the maximum number of pooling operations given a minimal feature map size
and the patch size
Args:
patch_size: input patch size
max_num_pool: maximum number of pooling operations.
min_feature_map_size: Minimal size of feature map inside the bottleneck.
Returns:
List[int]: max number of pooling operations per axis
"""
network_numpool_per_axis = np.floor(
[np.log(i / min_feature_map_size) / np.log(2) for i in patch_size]).astype(np.int32)
network_numpool_per_axis = [min(i, max_num_pool) for i in network_numpool_per_axis]
return network_numpool_per_axis
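# Worked example (hypothetical values): for patch_size = (160, 192),
# min_feature_map_size = 4 and max_num_pool = 999,
# floor(log2(160 / 4)) = 5 and floor(log2(192 / 4)) = 5,
# so both axes allow five pooling operations.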
def get_shape_must_be_divisible_by(num_pool_per_axis: Sequence[int]) -> np.ndarray:
"""
    Returns a power of 2 which indicates by which factor an axis needs to
    be divisible to avoid problems with upsampling
Args:
num_pool_per_axis: number of pooling operations per axis
Returns:
np.ndarray: necessary divisor of axis
"""
return 2 ** np.array(num_pool_per_axis)
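# Doctest-style sketch:
# >>> get_shape_must_be_divisible_by([5, 5, 3])
# array([32, 32,  8])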
def pad_shape(shape: Sequence[int], must_be_divisible_by: Sequence[int]) -> np.ndarray:
"""
    Pads shape so that it is divisible by must_be_divisible_by
Args:
shape: shape to pad
must_be_divisible_by: divisor
Returns:
np.ndarray: padded shape
"""
if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):
must_be_divisible_by = [must_be_divisible_by] * len(shape)
else:
assert len(must_be_divisible_by) == len(shape)
new_shp = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i]
for i in range(len(shape))]
for i in range(len(shape)):
if shape[i] % must_be_divisible_by[i] == 0:
new_shp[i] -= must_be_divisible_by[i]
new_shp = np.array(new_shp).astype(np.int32)
return new_shp
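# Worked example: pad_shape((37, 40), (8, 8)) first computes
# [37 + 8 - 5, 40 + 8 - 0] = [40, 48]; since 40 is already divisible by 8,
# the second axis is corrected back to 40, giving array([40, 40]).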
def fixed_anchor_init(dim: int):
"""
Fixed anchors sizes for 2d and 3d
Args:
dim: number of dimensions
Returns:
dict: fixed params
"""
anchor_plan = {"stride": 1, "aspect_ratios": (0.5, 1, 2)}
if dim == 2:
anchor_plan["sizes"] = (32, 64, 128, 256)
else:
anchor_plan["sizes"] = ((4, 8, 16), (8, 16, 32), (16, 32, 64), (32, 64, 128))
anchor_plan["zsizes"] = ((2, 3, 4), (4, 6, 8), (8, 12, 16), (12, 24, 48))
return anchor_plan
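# Example: fixed_anchor_init(3) returns
# {"stride": 1, "aspect_ratios": (0.5, 1, 2),
#  "sizes": ((4, 8, 16), (8, 16, 32), (16, 32, 64), (32, 64, 128)),
#  "zsizes": ((2, 3, 4), (4, 6, 8), (8, 12, 16), (12, 24, 48))}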
def concatenate_property_boxes(all_boxes: Sequence[np.ndarray]) -> np.ndarray:
return np.concatenate([b for b in all_boxes if not isinstance(b, list) and b.size > 0], axis=0)