Commit 131a40e9 authored by mibaumgartner's avatar mibaumgartner
Browse files

fix imports

parent c61a4ce0
......@@ -228,8 +228,8 @@ Eachs of this commands is explained below and more detailt information can be ob
### Planning & Preprocessing
Before training the networks, nnDetection needs to preprocess and analyze the data.
The preprocessing stage normalizes and resamples the data while the analyzed properties are used to create a plan which will be used for configuring the training.
nnDetectionV0 requires a GPU with approximately the same amount of VRAM you are planning to use for training (i.e. we used a completely freed RTX2080TI) to perform live estimation of the VRAM used by the network.
Future releases will improve this process...
nnDetectionV0 requires a GPU with approximately the same amount of VRAM you are planning to use for training (i.e. we used a RTX2080TI; no monitor attached to it) to perform live estimation of the VRAM used by the network.
Future releases aim at improving this process...
```bash
nndet_prep [tasks] [-o / --overwrites]
......@@ -263,12 +263,11 @@ After planning and preprocessing the resulting data folder structure should look
Before starting the training copy the data (Task Folder, dataset info and preprocessed folder are needed) to an SSD (highly recommended) and unpack the image data with
TODO: update name after refactoring planner name
```bash
nndet_unpack [path] [num_processes]
# Example (unpack example with 6 processes)
nndet_unpack ${det_data}/Task000D3_Example/preprocessed/D3C002_3d/imagesTr 6
nndet_unpack ${det_data}/Task000D3_Example/preprocessed/D3V001_3d/imagesTr 6
# Script
# /experiments/utils.py - unpack()
......
from nndet.arch.decoder.fpn import FPN, UFPN, FPN2
from nndet.arch.decoder.base import BaseUFPN, UFPNModular, PAUFPN
......@@ -24,7 +24,6 @@ from abc import abstractmethod
from loguru import logger
from nndet.losses.classification import (
AsymmetricFocalLossWithLogits,
FocalLossWithLogits,
BCEWithLogitsLossOneHot,
CrossEntropyLoss,
......@@ -428,119 +427,4 @@ class FocalClassifier(BaseClassifier):
self.logits_convert_fn = nn.Sigmoid()
class AsymmetricFocalClassifier(FocalClassifier):
    def __init__(self,
                 conv,
                 in_channels: int,
                 internal_channels: int,
                 num_classes: int,
                 anchors_per_pos: int,
                 num_levels: int,
                 num_convs: int = 3,
                 add_norm: bool = True,
                 prior_prob: Optional[float] = None,
                 gamma: float = 2,
                 alpha: float = -1,
                 reduction: str = "sum",
                 loss_weight: float = 1.,
                 **kwargs
                 ):
        """
        Classifier head with sigmoid-based asymmetric focal loss and
        prior-probability weight initialization.

        Layer layout:
        conv(in, internal) -> num_convs x conv(internal, internal) ->
        conv(internal, out)

        Args:
            conv: convolution module which handles a single layer
            in_channels: number of input channels
            internal_channels: number of channels internally used
            num_classes: number of foreground classes
            anchors_per_pos: number of anchors per position
            num_levels: number of decoder levels which are passed through the
                classifier
            num_convs: number of convolutions
                input_conv -> num_convs -> output_convs
            add_norm: en-/disable normalization layers in internal layers
            prior_prob: initialize final conv with given prior probability
            gamma: focal loss gamma
            alpha: focal loss alpha
            reduction: reduction to apply to loss. 'sum' | 'mean' | 'none'
            loss_weight: scalar to balance multiple losses
            kwargs: keyword arguments passed to first and internal convolutions
        """
        # NOTE(review): set before super().__init__ — presumably the base
        # class reads it during final-conv weight init; confirm in base class.
        self.prior_prob = prior_prob
        super().__init__(
            conv=conv,
            in_channels=in_channels,
            num_convs=num_convs,
            add_norm=add_norm,
            internal_channels=internal_channels,
            num_classes=num_classes,
            anchors_per_pos=anchors_per_pos,
            num_levels=num_levels,
            **kwargs,
        )
        # overwrite loss set up by the parent with the asymmetric variant
        self.loss = AsymmetricFocalLossWithLogits(
            gamma=gamma,
            alpha=alpha,
            reduction=reduction,
            loss_weight=loss_weight,
        )
        # logits -> per-class probabilities (multi-label style, no softmax)
        self.logits_convert_fn = nn.Sigmoid()
class FullyConntectedBCECLassifier(BCECLassifier):
    """
    BCE Classifier with 1x1 convs which act as fully connected
    layers with shared weights across spatial locations.

    Layer layout:
    conv3(in, internal) -> num_convs x conv1(internal, internal) -> conv1(internal, out)
    """
    def build_conv_internal(self, conv, **kwargs):
        """
        Build internal convolutions.

        One 3x3(x3) input conv followed by `num_convs` 1x1(x1) convs;
        the 1x1 convs act as per-location fully connected layers.

        Args:
            conv: convolution module which handles a single layer
            kwargs: keyword arguments passed to each convolution

        Returns:
            nn.Sequential: internal convolution stack
        """
        _conv_internal = nn.Sequential()
        _conv_internal.add_module(
            name="c_in",
            module=conv(
                self.in_channels,
                self.internal_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                **kwargs,
            ))
        for i in range(self.num_convs):
            _conv_internal.add_module(
                name=f"c_internal{i}",
                module=conv(
                    self.internal_channels,
                    self.internal_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    **kwargs,
                ))
        return _conv_internal

    def build_conv_out(self, conv):
        """
        Build final convolution.

        Produces `num_classes * anchors_per_pos` output channels; no norm
        or activation so raw logits are emitted, and bias is enabled
        (needed for prior-probability initialization schemes).

        Args:
            conv: convolution module which handles a single layer

        Returns:
            nn.Module: output convolution
        """
        out_channels = self.num_classes * self.anchors_per_pos
        return conv(
            self.internal_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            add_norm=False,
            add_act=False,
            bias=True,
        )
ClassifierType = TypeVar('ClassifierType', bound=Classifier)
......@@ -21,8 +21,8 @@ from torch import Tensor
from typing import Dict, List, Tuple, Optional, TypeVar
from abc import abstractmethod
from nndet.detection.boxes import BoxCoderND
from nndet.detection.boxes.sampler import AbstractSampler
from nndet.core.boxes import BoxCoderND
from nndet.core.boxes.sampler import AbstractSampler
from nndet.arch.heads.classifier import Classifier
from nndet.arch.heads.regressor import Regressor
......@@ -527,287 +527,4 @@ class DetectionHeadHNMRegAll(DetectionHeadHNM):
return losses, sampled_pos_inds, sampled_neg_inds
class BoxHeadNativeEMA(DetectionHeadHNMNative):
    def __init__(self,
                 classifier: Classifier,
                 regressor: Regressor,
                 coder: BoxCoderND,
                 sampler: AbstractSampler,
                 log_num_anchors: Optional[str] = "mllogger",
                 beta: float = 0.9,
                 bias_correction: bool = True,
                 ):
        """
        Detection head with classifier and regression module. Uses hard
        negative example mining to compute the loss; the loss normalizers
        (number of sampled anchors) are smoothed with exponential moving
        averages instead of using the raw per-batch counts.

        Args:
            classifier: classifier module
            regressor: regression module
            coder: box coder to decode relative offsets into boxes
            sampler: sampler to select positive and negative examples
            log_num_anchors: name of logger to use; if None, no logging will be performed
            beta: beta parameter for the exponential moving average
            bias_correction: use bias correction for the exponential moving average

        Warnings:
            classifier loss is normalized in head -> use "sum" aggregation
            for classifier
            # WARNING: the current implementation does not normalize over the number of classes
        """
        super().__init__(classifier=classifier, regressor=regressor,
                         coder=coder, sampler=sampler, log_num_anchors=log_num_anchors)
        # running averages of sampled anchor counts, used as loss denominators
        self.num_inds_cache = EMA(beta=beta, bias_correction=bias_correction)
        self.num_pos_inds_cache = EMA(beta=beta, bias_correction=bias_correction)

    def compute_loss(self,
                     prediction: Dict[str, Tensor],
                     target_labels: List[Tensor],
                     matched_gt_boxes: List[Tensor],
                     anchors: List[Tensor],
                     ) -> Tuple[Dict[str, Tensor], torch.Tensor, torch.Tensor]:
        """
        Compute regression and classification loss.

        N anchors over all images; M anchors per image => sum(M) = N
        This head decodes the relative offsets from the network and computes
        the regression loss directly on the bounding boxes (e.g. for GIoU loss).

        Args:
            prediction: detection predictions for loss computation
                box_logits (Tensor): classification logits for each anchor
                    [N, num_classes]
                box_deltas (Tensor): offsets for each anchor
                    (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
            target_labels (List[Tensor]): target labels for each anchor
                (per image) [M]
            matched_gt_boxes: matched gt box for each anchor
                List[[N, dim * 2]], N=number of anchors per image
            anchors: anchors per image List[[N, dim * 2]]

        Returns:
            Tensor: dict with losses (reg for regression loss, cls
                for classification loss)
            Tensor: sampled positive indices of anchors (after concatenation)
            Tensor: sampled negative indices of anchors (after concatenation)
        """
        box_logits, box_deltas = prediction["box_logits"], prediction["box_deltas"]
        losses = {}
        sampled_pos_inds, sampled_neg_inds = self.select_indices(target_labels, box_logits)
        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        # flatten per-image tensors so the sampled indices address them directly
        target_labels = torch.cat(target_labels, dim=0)
        batch_anchors = torch.cat(anchors, dim=0)

        # decode only the positive anchors; regression loss acts on boxes
        pred_boxes_sampled = self.coder.decode_single(
            box_deltas[sampled_pos_inds], batch_anchors[sampled_pos_inds])
        target_boxes_sampled = torch.cat(matched_gt_boxes, dim=0)[sampled_pos_inds]

        assert len(batch_anchors) == len(box_deltas)
        assert len(batch_anchors) == len(box_logits)
        assert len(batch_anchors) == len(target_labels)

        if sampled_pos_inds.numel() > 0:
            # update EMA of positive counts only when positives exist,
            # then normalize by the smoothed count (floored at 1)
            self.num_pos_inds_cache.add(sampled_pos_inds.numel())
            losses["reg"] = self.regressor.compute_loss(
                pred_boxes_sampled,
                target_boxes_sampled,
            )
            losses["reg"] = losses["reg"] / max(1, self.num_pos_inds_cache.get())

        # classification is normalized by the EMA of all sampled anchors
        self.num_inds_cache.add(sampled_inds.numel())
        losses["cls"] = self.classifier.compute_loss(
            box_logits[sampled_inds], target_labels[sampled_inds])
        losses["cls"] = losses["cls"] / self.num_inds_cache.get()
        return losses, sampled_pos_inds, sampled_neg_inds
class EMA:
    """
    Exponentially weighted moving average.

    new_cache = beta * cache + (1 - beta) * new_val

    Approximately averages the last (1 - beta)^(-1) values.
    """

    def __init__(self, beta: float = 0.9, bias_correction: bool = True):
        """
        Args:
            beta: weight for averaging (closer to 1 -> smoother)
            bias_correction: apply bias correction to counter the
                zero initialization of the cache
        """
        self.beta = beta
        self.bias_correction = bias_correction
        self.cache: float = 0.
        self.t = 0

    def add(self, val):
        """
        Fold a new value into the running average.

        Args:
            val: new value to add
        """
        decay = self.beta
        self.cache = decay * self.cache + (1 - decay) * float(val)
        # cap the step counter so beta**t stays well-defined long-term
        if self.t < 100000:
            self.t += 1

    def get(self) -> float:
        """
        Retrieve the current value.

        Returns:
            float: current (optionally bias-corrected) EMA
        """
        if not (self.bias_correction and self.t > 0):
            return self.cache
        correction = 1 - pow(self.beta, self.t)
        return self.cache / correction
class BoxHNMNativeIoU(DetectionHeadHNMNative):
    def __init__(self,
                 classifier: Classifier,
                 regressor: Regressor,
                 coder: BoxCoderND,
                 sampler: AbstractSampler,
                 log_num_anchors: Optional[str] = "mllogger",
                 iou_cutoff: float = 1.0,
                 ):
        """
        Detection head whose regressor additionally predicts an IoU score
        per anchor; at inference the predicted IoU rescales the class
        probabilities.

        Args:
            classifier: classifier module
            regressor: regression module
            coder: box coder to decode relative offsets into boxes
            sampler (AbstractSampler): sampler for select positive and
                negative examples
            log_num_anchors (str): name of logger to use; if None, no logging
                will be performed
            iou_cutoff: during inference the IoU will be clipped to this
                value and rescaled to one.
        """
        super().__init__(
            classifier=classifier,
            regressor=regressor,
            coder=coder,
            sampler=sampler,
            log_num_anchors=log_num_anchors,
        )
        self.iou_cutoff = iou_cutoff

    def forward(self,
                fmaps: List[torch.Tensor],
                ) -> Dict[str, torch.Tensor]:
        """
        Forward feature maps through head modules.

        Args:
            fmaps: list of feature maps for head module

        Returns:
            Dict[str, torch.Tensor]: predictions
                `box_deltas` (Tensor): bounding box offsets
                    [Num_Anchors_Batch, (dim * 2)]
                `box_logits` (Tensor): classification logits
                    [Num_Anchors_Batch, (num_classes)]
                `box_ious_pred` (Tensor): predicted IoU logits per anchor
                    [Num_Anchors_Batch]
        """
        logits, offsets, ious_pred = [], [], []
        for level, p in enumerate(fmaps):
            logits.append(self.classifier(p, level=level))
            # regressor returns (offsets, iou prediction) per level
            box_offsets, box_iou_predicted = self.regressor(p, level=level)
            offsets.append(box_offsets)
            ious_pred.append(box_iou_predicted)

        # number of spatial dims (subtract batch and channel dims)
        sdim = fmaps[0].ndim - 2
        box_deltas = torch.cat(offsets, dim=1).reshape(-1, sdim * 2)
        box_logits = torch.cat(logits, dim=1).flatten(0, -2)
        box_ious_pred = torch.cat(ious_pred, dim=1).flatten()
        return {"box_deltas": box_deltas,
                "box_logits": box_logits,
                "box_ious_pred": box_ious_pred,
                }

    def compute_loss(self,
                     prediction: Dict[str, Tensor],
                     target_labels: List[Tensor],
                     matched_gt_boxes: List[Tensor],
                     anchors: List[Tensor],
                     ) -> Tuple[Dict[str, Tensor], torch.Tensor, torch.Tensor]:
        """
        Compute regression and classification loss.

        N anchors over all images; M anchors per image => sum(M) = N
        This head decodes the relative offsets from the network and computes
        the regression loss directly on the bounding boxes (e.g. for GIoU loss).

        Args:
            prediction: detection predictions for loss computation
                box_logits (Tensor): classification logits for each anchor
                    [N, num_classes]
                box_deltas (Tensor): offsets for each anchor
                    (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
            target_labels (List[Tensor]): target labels for each anchor
                (per image) [M]
            matched_gt_boxes: matched gt box for each anchor
                List[[N, dim * 2]], N=number of anchors per image
            anchors: anchors per image List[[N, dim * 2]]

        Returns:
            Tensor: dict with losses (reg for regression loss, cls for
                classification loss)
            Tensor: sampled positive indices of anchors (after concatenation)
            Tensor: sampled negative indices of anchors (after concatenation)
        """
        box_logits, box_deltas = prediction["box_logits"], prediction["box_deltas"]
        box_ious_pred = prediction["box_ious_pred"]

        losses = {}
        sampled_pos_inds, sampled_neg_inds = self.select_indices(target_labels, box_logits)
        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        # flatten per-image tensors so the sampled indices address them directly
        target_labels = torch.cat(target_labels, dim=0)
        batch_anchors = torch.cat(anchors, dim=0)

        pred_boxes_sampled = self.coder.decode_single(
            box_deltas[sampled_pos_inds], batch_anchors[sampled_pos_inds])
        target_boxes_sampled = torch.cat(matched_gt_boxes, dim=0)[sampled_pos_inds]

        assert len(batch_anchors) == len(box_deltas)
        assert len(batch_anchors) == len(box_logits)
        assert len(batch_anchors) == len(target_labels)

        if sampled_pos_inds.numel() > 0:
            # regressor also consumes the IoU predictions of the positives
            losses["reg"] = self.regressor.compute_loss(
                pred_boxes_sampled,
                target_boxes_sampled,
                box_ious_pred[sampled_pos_inds],
            ) / max(1, sampled_pos_inds.numel())

        losses["cls"] = self.classifier.compute_loss(
            box_logits[sampled_inds], target_labels[sampled_inds])
        return losses, sampled_pos_inds, sampled_neg_inds

    def postprocess_for_inference(self,
                                  prediction: Dict[str, torch.Tensor],
                                  anchors: List[torch.Tensor],
                                  ) -> Dict[str, torch.Tensor]:
        """
        Postprocess predictions for inference, e.g. convert logits to probs.

        Args:
            prediction (Dict[str, torch.Tensor]): predictions from this head
                `box_logits`: classification logits for each anchor [N]
                `box_deltas`: offsets for each anchor
                    (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
                `box_ious_pred`: predicted IoU logits per anchor [N]
            anchors (List[torch.Tensor]): anchors per image

        Returns:
            Dict[str, torch.Tensor]: decoded boxes and IoU-weighted
                class probabilities
        """
        probs = self.classifier.box_logits_to_probs(prediction["box_logits"])
        # clamp predicted IoU to [0, cutoff] and rescale to [0, 1]
        ious = torch.sigmoid(prediction["box_ious_pred"])
        ious = ious.clamp_(min=0, max=self.iou_cutoff) / self.iou_cutoff

        postprocess_predictions = {
            "pred_boxes": self.coder.decode(prediction["box_deltas"], anchors),
            "pred_probs": probs * ious[:, None],
        }
        return postprocess_predictions
HeadType = TypeVar('HeadType', bound=AbstractHead)
......@@ -22,7 +22,7 @@ from abc import abstractmethod
from loguru import logger
from nndet.detection.boxes import box_iou
from nndet.core.boxes import box_iou
from nndet.arch.layers.scale import Scale
from torch import Tensor
......@@ -310,123 +310,4 @@ class GIoURegressor(BaseRegressor):
)
class IoUBranchGIoURegressor(GIoURegressor):
    def __init__(self,
                 conv,
                 in_channels: int,
                 internal_channels: int,
                 anchors_per_pos: int,
                 num_levels: int,
                 num_convs: int = 3,
                 add_norm: bool = True,
                 learn_scale: bool = False,
                 reduction: Optional[str] = "sum",
                 loss_weight: float = 1.,
                 loss_weight_iou_branch: float = 1.,
                 iou_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,
                 **kwargs,
                 ):
        """
        GIoU box regression head with an additional IoU prediction branch.

        Args:
            conv: Convolution modules which handles a single layer
            in_channels: number of input channels
            internal_channels: number of channels internally used
            anchors_per_pos: number of anchors per position
            num_levels: number of decoder levels which are passed through the
                regressor
            num_convs: number of convolutions
                in conv -> num convs -> final conv
            add_norm: en-/disable normalization layers in internal layers
            learn_scale: learn additional single scalar values per feature
                pyramid level
            reduction: reduction to apply to loss. 'sum' | 'mean' | 'none'
            loss_weight: scalar to balance multiple losses
            loss_weight_iou_branch: weight of loss of IoU branch
            iou_fn: iou function to compute targets for IoU branch
            kwargs: keyword arguments passed to first and internal convolutions
        """
        super().__init__(
            conv=conv,
            in_channels=in_channels,
            internal_channels=internal_channels,
            anchors_per_pos=anchors_per_pos,
            num_levels=num_levels,
            num_convs=num_convs,
            add_norm=add_norm,
            learn_scale=learn_scale,
            reduction=reduction,
            loss_weight=loss_weight,
            **kwargs
        )
        # extra branch shares the internal features with the box branch
        self.conv_iou_branch = self.build_conv_iou_branch(conv)
        self.iou_branch_loss = nn.BCEWithLogitsLoss()
        self.loss_weight_iou_branch = loss_weight_iou_branch
        self.iou_fn = iou_fn

    def build_conv_iou_branch(self, conv) -> nn.Module:
        """
        Build the IoU branch convolution.

        One logit per anchor position; no norm/activation so raw logits
        are emitted for BCEWithLogitsLoss.
        """
        return conv(
            self.internal_channels,
            self.anchors_per_pos,
            kernel_size=3,
            stride=1,
            padding=1,
            add_norm=False,
            add_act=False,
            bias=True,
        )

    def forward(self, x: torch.Tensor, level: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward input.

        Args:
            x (torch.Tensor): input feature map of size [N x C x Y x X x Z]
            level: feature pyramid level (selects the learned scale, if any)

        Returns:
            torch.Tensor: bounding box logits for each anchor [N, n_anchors, dim*2]
            torch.Tensor: IoU logits for each anchor [N, n_anchors]
        """
        intermediate_features = self.conv_internal(x)

        bb_logits = self.conv_out(intermediate_features)
        iou_logits = self.conv_iou_branch(intermediate_features)

        if self.learn_scale:
            bb_logits = self.scales[level](bb_logits)

        # move channels last before flattening anchors into one dimension
        axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
        bb_logits = bb_logits.permute(*axes).contiguous()
        bb_logits = bb_logits.view(x.size()[0], -1, self.dim * 2)

        iou_logits = iou_logits.permute(*axes).contiguous()
        iou_logits = iou_logits.view(x.size()[0], -1)
        return bb_logits, iou_logits

    def compute_loss(self,
                     pred_boxes: Tensor,
                     target_boxes: Tensor,
                     pred_iou: Tensor,
                     ) -> Tensor:
        """
        Compute regression loss and IoU branch loss.

        Args:
            pred_boxes: predicted bounding box deltas [N, dim * 2]
            target_boxes: target bounding box deltas [N, dim * 2]
            pred_iou: predicted IoU logits [N]

        Returns:
            Tensor: combined loss (regression + weighted IoU branch loss)
        """
        reg_loss = self.loss(pred_boxes, target_boxes)
        # IoU target is the actual overlap between prediction and target;
        # diag() picks the pairwise IoU of matching rows
        target_ious = self.iou_fn(pred_boxes, target_boxes).diag(diagonal=0)
        iou_branch_loss = self.loss_weight_iou_branch * self.iou_branch_loss(pred_iou, target_ious)
        return reg_loss + iou_branch_loss
RegressorType = TypeVar('RegressorType', bound=Regressor)
......@@ -73,20 +73,16 @@ model_cfg:
head_classifier_kwargs: # keyword arguments passed to classifier in head
num_convs: 1
norm: "Group"
norm_kwargs:
channels_per_group: 16
affine: True
norm_channels_per_group: 16
norm_affine: True
reduction: "mean"
loss_weight: 1.
prior_prob: 0.01
head_regressor_kwargs: # keyword arguments passed to regressor in head
num_convs: 1
norm: "Group"
norm_kwargs:
channels_per_group: 16
affine: True
norm_channels_per_group: 16
norm_affine: True
reduction: "sum"
loss_weight: 1.
learn_scale: True
......
from nndet.core.boxes.anchors import get_anchor_generator, compute_anchors_for_strides, \
AnchorGenerator2D, AnchorGenerator2DS, AnchorGenerator3D, AnchorGenerator3DS
from nndet.core.boxes.clip import clip_boxes_to_image_, clip_boxes_to_image
from nndet.core.boxes.anchors import (
AnchorGeneratorType,
get_anchor_generator,
compute_anchors_for_strides,
AnchorGenerator2D,
AnchorGenerator2DS,
AnchorGenerator3D,
AnchorGenerator3DS,
)
from nndet.core.boxes.clip import (
clip_boxes_to_image_,
clip_boxes_to_image,
)
from nndet.core.boxes.coder import CoderType, BoxCoderND
from nndet.core.boxes.matcher import MatcherType, Matcher, IoUMatcher, ATSSMatcher
from nndet.core.boxes.nms import nms, batched_nms
......
import torch
from typing import Callable, Sequence, List, Tuple, Union
from typing import Callable, Sequence, List, Tuple, TypeVar, Union
from torchvision.models.detection.rpn import AnchorGenerator
from loguru import logger
from itertools import product
AnchorGeneratorType = TypeVar('AnchorGeneratorType', bound=AnchorGenerator)
def get_anchor_generator(dim: int, s_param: bool = False) -> AnchorGenerator:
"""
Get anchor generator class for corresponding dimension
......
......@@ -10,6 +10,7 @@ from nndet.arch.encoder.abstract import EncoderType
from nndet.arch.decoder.base import DecoderType
from nndet.arch.heads.segmenter import SegmenterType
from nndet.arch.heads.comb import HeadType
from nndet.core.boxes.anchors import AnchorGeneratorType
class BaseRetinaNet(AbstractModel):
......@@ -20,7 +21,7 @@ class BaseRetinaNet(AbstractModel):
decoder: DecoderType,
head: HeadType,
num_classes: int,
anchor_generator: box_utils.AnchorGenerator,
anchor_generator: AnchorGeneratorType,
matcher: box_utils.MatcherType,
decoder_levels: tuple = (2, 3, 4, 5),
# post-processing
......
......@@ -31,7 +31,6 @@ from nndet.inference.detection import batched_nms_model, batched_nms_ensemble, \
from nndet.inference.ensembler.base import BaseEnsembler, OverlapMap
from nndet.inference.restore import restore_detection
from nndet.core.boxes import box_center, clip_boxes_to_image, remove_small_boxes
from nndet.core.boxes.merging import GreedyIoUBoxMerger, VoteLabelGreedyIoUBoxMerger
from nndet.utils.tensor import cat, to_device, to_dtype
......@@ -1164,223 +1163,3 @@ class BoxEnsemblerSelective(BoxEnsembler):
self.model_results[model]["weights"] = weights[idx_sorted]
return super().save_state(target_dir=target_dir, name=name, **kwargs)
class BoxEnsemblerSelective2D(BoxEnsemblerSelective):
    """
    Box ensembler for 2d predictions.
    Can be used to process 2d predictions of a 3d volume: per-slice 2d
    boxes are lifted to one-slice-thick 3d boxes and merged ("tracked")
    across neighboring slices into full 3d boxes.
    """
    @classmethod
    def get_default_parameters(cls):
        """
        Generate default parameters for instantiation.

        Returns:
            Dict:
                `model_iou`: IoU for model nms function
                `model_nms_fn`: function to use for model NMS
                `model_topk`: number of predictions with the highest
                    probability to keep
                `ensemble_iou`: IoU for ensembling the predictions of multiple
                    models
                `ensemble_nms_fn`: ensemble predictions from multiple
                    models
                `ensemble_nms_topk`: number of predictions with the highest
                    probability to keep
                `ensemble_remove_small_boxes`: minimum size of the box
                `ensemble_score_thresh`: minimum probability
                plus 2d tracking parameters (`track_iou`,
                `track_neighbor_slices`, `track_merger_cls`,
                `track_remove_small_boxes`)
        """
        params = super().get_default_parameters()
        params["model_topk"] = 10000
        params["model_detections_per_image"] = 4000
        params["model_score_thresh"] = 0.3
        params["ensemble_topk"] = 10000
        params["track_iou"] = 0.5
        params["track_neighbor_slices"] = 1
        params["track_merger_cls"] = VoteLabelGreedyIoUBoxMerger
        params["track_remove_small_boxes"] = 0
        return params

    @classmethod
    def sweep_parameters(cls) -> Tuple[Dict[str, Any],
                                       Dict[str, Sequence[Any]]]:
        """
        Default parameters plus the grid of values to sweep during
        hyperparameter search.
        """
        iou_threshs = np.linspace(0.0, 0.5, 6)
        iou_threshs[0] = 1e-5  # avoid exactly-zero IoU threshold
        track_ious = np.linspace(0.3, 0.8, 6)
        param_sweep = {
            # single model
            "model_iou": iou_threshs,
            "model_nms_fn": [
                batched_weighted_nms_model,
                batched_nms_model,
            ],
            # ensemble multiple models
            "ensemble_iou": iou_threshs,
            "track_iou": track_ious,
            "track_neighbor_slices": [1, 2, 3, 4],
            "track_merger_cls": [
                GreedyIoUBoxMerger,
                VoteLabelGreedyIoUBoxMerger,
            ],
            "track_remove_small_boxes": [0, 1, 2, 3, 4],
            "model_score_thresh": [0.2, 0.3, 0.4, 0.5, 0.6],
        }
        return cls.get_default_parameters(), param_sweep

    @torch.no_grad()
    def process_batch(self, result: Dict, batch: Dict):
        """
        Process a single batch of bounding box predictions
        (the boxes are clipped to the case size in the ensembling step).

        This already expands the box into the positive z direction.
        Expansion into the negative z dimension is done after the tracking
        step.

        Args:
            result: prediction from detector. Needs to provide boxes, scores
                and class labels
                `self.box_key`: List[Tensor]: predicted boxes (relative
                    to patch coordinates)
                `self.score_key`: List[Tensor]: score for each tensor
                `self.label_key`: List[Tensor]: label prediction for each box
            batch: input batch for detector
                `tile_origin`: origin of crop with respect to actual data (
                    in case of padding)
                `crop`: Sequence[slice] original crop from data
        """
        # z index of each slice: start of the first crop dimension
        slice_idx = [c.start for c in batch["crop"][0]]
        boxes = [r.float().cpu() for r in result[self.box_key]]
        scores = [r.float().cpu() for r in result[self.score_key]]
        labels = [r.float().cpu() for r in result[self.label_key]]

        # process 2d boxes: weight each box by its position inside the tile
        tile_size = batch[self.data_key].shape[2:]
        centers = [box_center(img_boxes) if img_boxes.numel() > 0 else Tensor([]).to(img_boxes)
                   for img_boxes in boxes]
        weights = [self._get_box_in_tile_weight(c, tile_size) for c in centers]
        weights = [w * self.model_weights[self.model_current] for w in weights]

        # drop the z component of the tile origin before shifting 2d boxes
        tile_origins = [to[1:] for to in zip(*batch["tile_origin"])]
        boxes = self._apply_offsets_to_boxes(boxes, tile_origins)

        # convert to 3d boxes: each 2d box becomes a one-slice-thick 3d box
        boxes_3d = []
        for boxes_image, idx in zip(boxes, slice_idx):
            if boxes_image.numel() > 0:
                # (z1, z2) = (idx, idx + 1): expand into positive z only
                idx_tensor = torch.tensor([[float(idx), float(idx) + 1.]])
                idx_tensor_expanded = idx_tensor.to(boxes_image).expand(boxes_image.shape[0], -1)
                _boxes_3d = torch.stack([
                    idx_tensor_expanded[:, 0],
                    boxes_image[:, 0],
                    idx_tensor_expanded[:, 1],
                    boxes_image[:, 2],
                    boxes_image[:, 1],
                    boxes_image[:, 3],
                ], dim=1)
            else:
                _boxes_3d = boxes_image.view(-1, 6)
            boxes_3d.append(_boxes_3d)

        self.model_results[self.model_current]["boxes"].extend(boxes_3d)
        self.model_results[self.model_current]["scores"].extend(scores)
        self.model_results[self.model_current]["labels"].extend(labels)
        self.model_results[self.model_current]["weights"].extend(weights)

    @torch.no_grad()
    def get_case_result(self,
                        restore: bool = False,
                        names: Optional[Sequence[Hashable]] = None,
                        ) -> Dict[str, Tensor]:
        """
        Process all the batches and models and create the final prediction.

        Args:
            restore: restore prediction in the original image space
            names: name of the models to use. By default all models are used.

        Returns:
            Dict: final result
                `pred_boxes`: predicted box locations
                    [N, dims * 2] (x1, y1, x2, y2, (z1, z2))
                `pred_scores`: predicted probability per box [N]
                `pred_labels`: predicted label per box [N]
                `restore`: indicate whether predictions were restored in
                    original image space
                `original_size_of_raw_data`: image shape before preprocessing
                `itk_origin`: itk origin of image before preprocessing
                `itk_spacing`: itk spacing of image before preprocessing
                `itk_direction`: itk direction of image before preprocessing
        """
        if names is None:
            names = list(self.model_results.keys())

        # collect per-model results, then ensemble across models
        boxes, probs, labels, weights = [], [], [], []
        for name in names:
            _boxes, _probs, _labels, _weights = self.process_model(name)
            boxes.append(_boxes)
            probs.append(_probs)
            labels.append(_labels)
            weights.append(_weights)

        boxes, probs, labels = self.process_ensemble(
            boxes=boxes, probs=probs, labels=labels,
            weights=weights,
        )
        # merge slice-wise pseudo-3d boxes into real 3d boxes
        boxes, probs, labels = self.track_2d_to_3d(boxes, probs, labels)

        if restore:
            boxes = self.restore_prediction(boxes)
        return {
            "pred_boxes": boxes,
            "pred_scores": probs,
            "pred_labels": labels,
            "restore": restore,
            "original_size_of_raw_data": self.properties["original_size_of_raw_data"],
            "itk_origin": self.properties["itk_origin"],
            "itk_spacing": self.properties["itk_spacing"],
            "itk_direction": self.properties["itk_direction"],
        }

    def track_2d_to_3d(self, boxes: Tensor, probs: Tensor, labels: Tensor):
        """
        Convert the 2d tubes to 3d boxes.

        Args:
            boxes: pseudo 3d boxes (each 2d box is projected into 3d space)
            probs: probability of each box
            labels: label of each box

        Returns:
            boxes: 3d boxes
            probs: probabilities
            labels: labels
        """
        # only track when the case actually has more than one slice
        if self.properties["shape"][0] > 1:
            boxes_2d = boxes[:, [1, 4, 3, 5]]  # [N, 4] in-plane coordinates
            slices = torch.round(boxes[:, 0]).int()  # [N] slice index per box

            merger = self.parameters["track_merger_cls"](
                boxes=boxes_2d,
                slices=slices,
                scores=probs,
                labels=labels,
                iou_th=self.parameters["track_iou"],
                neighbor_slices=self.parameters["track_neighbor_slices"],
            )
            boxes, probs, labels = merger.merge()

            # the upper bound is already correct (expanded in process_batch),
            # we need to fix the lower bound here
            boxes[:, 0] = boxes[:, 0] - 1

            keep = remove_small_boxes(
                boxes, min_size=self.parameters["track_remove_small_boxes"])
            boxes, probs, labels = boxes[keep], probs[keep], labels[keep]
        return boxes, probs, labels
......@@ -5,9 +5,5 @@ DATALOADER_REGISTRY: Mapping[str, Iterable] = Registry()
from nndet.io.datamodule.bg_loader import (
DataLoader3DFast,
DataLoader3DBalanced,
DataLoader3DOffset,
DataLoader2DOffset,
DataLoader2DFast,
DataLoader2DDeeplesion,
)
from nndet.planning.analyzer import DatasetAnalyzer
from nndet.planning.plan_architecture import (
ArchitecturePlanner,
FixedArchitecturePlanner,
DetectionPlanner,
FixedDetectionPlanner,
)
from nndet.planning.plan_experiment import (
AbstractPlanner,
D3V001,
)
from nndet.planning.experiment import PLANNER_REGISTRY
......@@ -219,7 +219,7 @@ class BaseBoxesPlanner(ArchitecturePlanner):
self.plot_instance_distribution(**kwargs)
return {}
def create_default_settings():
def create_default_settings(self):
pass
def compute_class_weights(self) -> List[float]:
......@@ -258,7 +258,7 @@ class BaseBoxesPlanner(ArchitecturePlanner):
return base
class DetectionPlanner(BaseBoxesPlanner):
class BoxC001(BaseBoxesPlanner):
def __init__(self,
preprocessed_output_dir: os.PathLike,
save_dir: os.PathLike,
......
......@@ -7,7 +7,7 @@ import numpy as np
from loguru import logger
from nndet.planning.estimator import MemoryEstimator, MemoryEstimatorDetection
from nndet.planning.architecture.boxes.base import BaseBoxesPlanner
from nndet.planning.architecture.boxes.base import BoxC001
from nndet.planning.architecture.boxes.utils import (
proxy_num_boxes_in_patch,
scale_with_abs_strides,
......@@ -21,7 +21,7 @@ from nndet.core.boxes import (
)
class BoxC002(BaseBoxesPlanner):
class BoxC002(BoxC001):
def __init__(self,
preprocessed_output_dir: os.PathLike,
save_dir: os.PathLike,
......
......@@ -26,7 +26,7 @@ from typing import Sequence, Union, Callable, Tuple
from contextlib import contextmanager
from loguru import logger
from nndet.models.abstract import AbstractModel
from nndet.arch.abstract import AbstractModel
"""
This is just a first prototype to estimate VRAM consumption for different GPUs
......
......@@ -26,7 +26,7 @@ from typing import Dict, Sequence, List, Tuple
from nndet.io.load import load_case_cropped
from nndet.planning import DatasetAnalyzer
from nndet.detection.boxes import box_iou_np
from nndet.core.boxes import box_iou_np
def analyze_instances(analyzer: DatasetAnalyzer) -> dict:
......
from nndet.preprocessing.crop import ImageCropper
from nndet.preprocessing.preprocessor import (
PreprocessorType,
AbstractPreprocessor,
GenericPreprocessor,
)
......@@ -60,7 +60,7 @@ from nndet.inference.transforms import get_tta_transforms, Inference2D
from nndet.inference.loading import load_final_model
from nndet.inference.helper import predict_dir
from nndet.inference.ensembler.segmentation import SegmentationEnsembler
from nndet.inference.ensembler.detection import BoxEnsemblerSelective, BoxEnsemblerSelective2D
from nndet.inference.ensembler.detection import BoxEnsemblerSelective
from rising.transforms import Compose
from nndet.io.transforms import Instances2Boxes, Instances2Segmentation, FindInstances
......@@ -650,14 +650,16 @@ class RetinaUNetModule(LightningBaseModuleSWA):
"""
_lookup = {
2: {
"boxes": BoxEnsemblerSelective2D,
"seg": SegmentationEnsembler,
"boxes": None,
"seg": None,
},
3: {
"boxes": BoxEnsemblerSelective,
"seg": SegmentationEnsembler,
}
}
if dim == 2:
raise NotImplementedError
return _lookup[dim][key]
@classmethod
......@@ -700,6 +702,7 @@ class RetinaUNetModule(LightningBaseModuleSWA):
**kwargs,
)
if plan["network_dim"] == 2:
raise NotImplementedError
predictor.pre_transform = Inference2D(["data"])
return predictor
......
......@@ -22,6 +22,7 @@ from hydra.experimental import initialize_config_module
from nndet.utils.config import compose
if __name__ == '__main__':
"""
Automatically deletes files generated by seg2det and restores
......
......@@ -24,7 +24,7 @@ from loguru import logger
from pathlib import Path
from nndet.utils.info import env_guard
from nndet.planning.plan_experiment import PLANNER_REGISTRY
from nndet.planning import PLANNER_REGISTRY
from nndet.io import get_task, get_training_dir
from nndet.io.load import load_pickle
from nndet.inference.loading import load_all_models
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment