Commit 94d6ac20 authored by mibaumgartner's avatar mibaumgartner
Browse files

core

parent 3e94607a
from nndet.core.boxes.anchors import get_anchor_generator, compute_anchors_for_strides, \
AnchorGenerator2D, AnchorGenerator2DS, AnchorGenerator3D, AnchorGenerator3DS
from nndet.core.boxes.clip import clip_boxes_to_image_, clip_boxes_to_image
from nndet.core.boxes.coder import BoxCoderND
from nndet.core.boxes.matcher import MatcherType, Matcher, IoUMatcher, ATSSMatcher
from nndet.core.boxes.nms import nms, batched_nms
from nndet.core.boxes.sampler import AbstractSampler, NegativeSampler, HardNegativeSampler, \
BalancedHardNegativeSampler, HardNegativeSamplerFgAll, HardNegativeSamplerBatched
from nndet.core.boxes.utils import box_area, box_iou, remove_small_boxes, box_center, permute_boxes, \
expand_to_boxes, box_size, generalized_box_iou, box_center_dist, center_in_boxes
from nndet.core.boxes.utils_np import box_iou_np, box_size_np, box_area_np
This diff is collapsed.
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Tuple
import torch
def clip_boxes_to_image_(boxes: torch.Tensor, img_shape: Tuple[int]):
    """
    Clip boxes to image dimensions inplace

    Args:
        boxes (Tensor): tensor with boxes [N x (2*dim)]
            (x_min, y_min, x_max, y_max(, z_min, z_max))
        img_shape (Tuple[height, width(, depth)]): size of image

    Returns:
        Tensor: clipped boxes as tensor

    Raises:
        ValueError: boxes need to have 4 (2D) or 6 (3D) components
    """
    # dispatch on the number of box components: 4 -> 2D, 6 -> 3D
    if boxes.shape[-1] == 4:
        return clip_boxes_to_image_2d_(boxes, img_shape)
    elif boxes.shape[-1] == 6:
        return clip_boxes_to_image_3d_(boxes, img_shape)
    else:
        # fixed message: previously read "Boxes with 6 are not supported."
        raise ValueError(
            f"Boxes with {boxes.shape[-1]} components are not supported.")
def clip_boxes_to_image(boxes: torch.Tensor, img_shape: Tuple[int]):
    """
    Clip boxes to image dimensions

    Args:
        boxes (Tensor): tensor with boxes [N x (2*dim)]
            (x_min, y_min, x_max, y_max(, z_min, z_max))
        img_shape (Tuple[height, width(, depth)]): size of image

    Returns:
        Tensor: clipped boxes as tensor

    Raises:
        ValueError: boxes need to have 4 (2D) or 6 (3D) components
    """
    # dispatch on the number of box components: 4 -> 2D, 6 -> 3D
    if boxes.shape[-1] == 4:
        return clip_boxes_to_image_2d(boxes, img_shape)
    elif boxes.shape[-1] == 6:
        return clip_boxes_to_image_3d(boxes, img_shape)
    else:
        # fixed message: previously read "Boxes with 6 are not supported."
        raise ValueError(
            f"Boxes with {boxes.shape[-1]} components are not supported.")
def clip_boxes_to_image_2d_(boxes: torch.Tensor, img_shape: Tuple[int, int]):
    """
    Clip 2D boxes to the image extent, modifying ``boxes`` in place.

    Args:
        boxes (Tensor): tensor with boxes [N x 4] (x_min, y_min, x_max, y_max)
        img_shape (Tuple[x_max, y_max]): size of image

    Returns:
        Tensor: the input tensor with clipped coordinates
    """
    # even columns (0, 2) hold the first axis, odd columns (1, 3) the second
    for axis, bound in enumerate(img_shape):
        boxes[..., axis::2].clamp_(min=0, max=bound)
    return boxes
def clip_boxes_to_image_3d_(boxes: torch.Tensor, img_shape: Tuple[int, int, int]):
    """
    Clip 3D boxes to the image extent, modifying ``boxes`` in place.

    Args:
        boxes (Tensor): tensor with boxes [N x 6]
            (x_min, y_min, x_max, y_max, z_min, z_max)
        img_shape (Tuple[height, width, depth]): size of image

    Returns:
        Tensor: the input tensor with clipped coordinates
    """
    d0, d1, d2 = img_shape
    # per-column upper bound in box order (x1, y1, x2, y2, z1, z2)
    for col, bound in enumerate((d0, d1, d0, d1, d2, d2)):
        boxes[..., col::6].clamp_(min=0, max=bound)
    return boxes
def clip_boxes_to_image_2d(boxes: torch.Tensor, img_shape: Tuple[int, int]):
    """
    Clip 2D boxes to the image extent using out-of-place ``clamp``.

    Args:
        boxes (Tensor): tensor with boxes [N x 4] (x_min, y_min, x_max, y_max)
        img_shape (Tuple[x_max, y_max]): size of image

    Returns:
        Tensor: clipped boxes as tensor

    Notes:
        NOTE(review): although ``clamp`` is out-of-place, the clamped values
        are assigned back into ``boxes``, so the input tensor is modified as
        well — confirm callers expect this.
    """
    # even columns (0, 2) hold the first axis, odd columns (1, 3) the second
    for axis, bound in enumerate(img_shape):
        boxes[..., axis::2] = boxes[..., axis::2].clamp(min=0, max=bound)
    return boxes
def clip_boxes_to_image_3d(boxes: torch.Tensor, img_shape: Tuple[int, int, int]):
    """
    Clip 3D boxes to the image extent using out-of-place ``clamp``.

    Args:
        boxes (Tensor): tensor with boxes [N x 6]
            (x_min, y_min, x_max, y_max, z_min, z_max)
        img_shape (Tuple[height, width, depth]): size of image

    Returns:
        Tensor: clipped boxes as tensor

    Notes:
        NOTE(review): although ``clamp`` is out-of-place, the clamped values
        are assigned back into ``boxes``, so the input tensor is modified as
        well — confirm callers expect this.
    """
    d0, d1, d2 = img_shape
    # per-column upper bound in box order (x1, y1, x2, y2, z1, z2)
    for col, bound in enumerate((d0, d1, d0, d1, d2, d2)):
        boxes[..., col::6] = boxes[..., col::6].clamp(min=0, max=bound)
    return boxes
from __future__ import division
import math
from typing import Sequence
import torch
from torch.jit.annotations import List, Tuple
from torch import Tensor
from torchvision.models.detection._utils import BoxCoder
@torch.jit.script
def encode_boxes(reference_boxes: torch.Tensor,
                 proposals: torch.Tensor,
                 weights: torch.Tensor,
                 ) -> torch.Tensor:
    """
    Encode a set of proposals with respect to some reference boxes.

    The targets are the weighted center offsets (relative to the proposal
    size) and log size ratios between reference boxes and proposals.

    Args:
        reference_boxes: reference boxes (x1, y1, x2, y2, (z1, z2))
        proposals: boxes to be encoded (x1, y1, x2, y2, (z1, z2))
        weights: weights for dimensions (wx, wy, ww, wh(, wz, wd))

    Returns:
        torch.Tensor: regression targets (dx, dy, dw, dh(, dz, dd))
    """
    # perform some unpacking to make it JIT-fusion friendly
    wx = weights[0]
    wy = weights[1]
    ww = weights[2]
    wh = weights[3]

    proposals_x1 = proposals[:, 0].unsqueeze(1)
    proposals_y1 = proposals[:, 1].unsqueeze(1)
    proposals_x2 = proposals[:, 2].unsqueeze(1)
    proposals_y2 = proposals[:, 3].unsqueeze(1)

    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)

    # implementation starts here
    ex_widths = proposals_x2 - proposals_x1
    ex_heights = proposals_y2 - proposals_y1
    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
    ex_ctr_y = proposals_y1 + 0.5 * ex_heights

    gt_widths = reference_boxes_x2 - reference_boxes_x1
    gt_heights = reference_boxes_y2 - reference_boxes_y1
    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights

    # center offsets are normalized by the proposal size, sizes via log ratio
    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
    targets_dw = ww * torch.log(gt_widths / ex_widths)
    targets_dh = wh * torch.log(gt_heights / ex_heights)

    if proposals.shape[1] == 6:
        # 3d boxes carry two additional coordinates (z1, z2)
        wz = weights[4]
        wd = weights[5]
        proposals_z1 = proposals[:, 4].unsqueeze(1)
        proposals_z2 = proposals[:, 5].unsqueeze(1)
        ex_depth = proposals_z2 - proposals_z1
        ex_ctr_z = proposals_z1 + 0.5 * ex_depth

        reference_boxes_z1 = reference_boxes[:, 4].unsqueeze(1)
        reference_boxes_z2 = reference_boxes[:, 5].unsqueeze(1)
        gt_depth = reference_boxes_z2 - reference_boxes_z1
        gt_ctr_z = reference_boxes_z1 + 0.5 * gt_depth

        targets_dz = wz * (gt_ctr_z - ex_ctr_z) / ex_depth
        targets_dd = wd * torch.log(gt_depth / ex_depth)
        targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh,
                             targets_dz, targets_dd), dim=1)
    else:
        targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
    return targets
def decode_single(rel_codes: Tensor, boxes: Tensor,
                  weights: Sequence[float],
                  bbox_xform_clip: float) -> Tensor:
    """
    Invert the box regression encoding: reconstruct absolute box
    coordinates from reference boxes and relative offsets.

    Args:
        rel_codes: encoded boxes [Num_boxes x (dim * 2)]
            (dx, dy, dw, dh(, dz, dd))
        boxes: reference boxes (x1, y1, x2, y2, (z1, z2))
        weights: per-coordinate weights (wx, wy, ww, wh(, wz, wd))
        bbox_xform_clip: upper bound for size offsets before ``exp``

    Returns:
        Tensor: decoded boxes
    """
    # number of coordinates per box: 4 in 2d, 6 in 3d
    dim2 = boxes.shape[1]
    boxes = boxes.to(rel_codes.dtype)
    half = torch.tensor(0.5, dtype=rel_codes.dtype)

    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    dx = rel_codes[:, 0::dim2] / weights[0]
    dy = rel_codes[:, 1::dim2] / weights[1]
    # clamp size offsets so torch.exp cannot overflow
    dw = (rel_codes[:, 2::dim2] / weights[2]).clamp(max=bbox_xform_clip)
    dh = (rel_codes[:, 3::dim2] / weights[3]).clamp(max=bbox_xform_clip)

    pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
    pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
    half_w = half * (torch.exp(dw) * widths[:, None])
    half_h = half * (torch.exp(dh) * heights[:, None])

    # coordinates in box order: x1, y1, x2, y2 (, z1, z2)
    coords = [pred_ctr_x - half_w, pred_ctr_y - half_h,
              pred_ctr_x + half_w, pred_ctr_y + half_h]

    if dim2 == 6:
        depths = boxes[:, 5] - boxes[:, 4]
        ctr_z = boxes[:, 4] + 0.5 * depths
        dz = rel_codes[:, 4::dim2] / weights[4]
        dd = (rel_codes[:, 5::dim2] / weights[5]).clamp(max=bbox_xform_clip)

        pred_ctr_z = dz * depths[:, None] + ctr_z[:, None]
        half_d = half * (torch.exp(dd) * depths[:, None])
        coords.extend([pred_ctr_z - half_d, pred_ctr_z + half_d])

    return torch.stack(coords, dim=2).flatten(1)
class BoxCoderND(BoxCoder):
    """
    This class encodes and decodes a set of bounding boxes into
    the representation used for training the regressors.
    Compatible with 2d and 3d boxes.
    """

    def encode(self,
               reference_boxes: List[Tensor],
               proposals: List[Tensor],
               ) -> Tuple[Tensor]:
        """
        Encode a set of proposals with respect to some reference boxes.

        Args:
            reference_boxes: reference boxes for each image.
                (x1, y1, x2, y2, (z1, z2))
            proposals: proposals for each image
                (x1, y1, x2, y2, (z1, z2))

        Returns:
            Tuple[Tensor]: regression targets for each image; images without
                any ground truth receive all-zero targets
        """
        # filter for images which have a foreground class
        filter_min_one_gt = [rb.numel() > 0 for rb in reference_boxes]
        filtered_ref_boxes = [
            rb for idx, rb in enumerate(reference_boxes) if filter_min_one_gt[idx]]
        filtered_proposals = [
            pr for idx, pr in enumerate(proposals) if filter_min_one_gt[idx]]

        # only run the base-class encoding on images that have ground truth
        if any(filter_min_one_gt):
            filtered_encoded = super().encode(filtered_ref_boxes, filtered_proposals)

        # fill image with no ground truth
        idx_enc = 0  # running index into the filtered (encoded) results
        encoded = []
        for img_idx, gt_present in enumerate(filter_min_one_gt):
            if gt_present:
                encoded.append(filtered_encoded[idx_enc])
                idx_enc += 1
            else:
                # fill with zeros because they do not contribute to the
                # regression loss anyway (all anchors are labeled as background)
                encoded.append(torch.zeros_like(proposals[img_idx]))
        return encoded

    def encode_single(self,
                      reference_boxes: Tensor,
                      proposals: Tensor,
                      ) -> Tensor:
        """
        Encode a set of proposals with respect to some reference boxes.

        Arguments:
            reference_boxes: reference boxes (x1, y1, x2, y2, (z1, z2))
            proposals: boxes to be encoded (x1, y1, x2, y2, (z1, z2))
        """
        dtype, device = reference_boxes.dtype, reference_boxes.device
        # weights are broadcast to the same dtype/device as the boxes
        weights = torch.tensor(self.weights, dtype=dtype, device=device)
        targets = encode_boxes(reference_boxes, proposals, weights)
        return targets

    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
        """
        Decode boxes.

        Args:
            rel_codes: relative offsets to reference boxes
                (dx, dy, dw, dh, (dz, dd))[N, dim * 2]
            boxes: list of reference boxes per image
                (x1, y1, x2, y2, (z1, z2))

        Returns:
            Tensor: decoded boxes
        """
        assert isinstance(boxes, (list, tuple))
        assert isinstance(rel_codes, torch.Tensor)
        boxes_per_image = [b.size(0) for b in boxes]
        concat_boxes = torch.cat(boxes, dim=0)
        spatial_dims = concat_boxes.shape[1]  # 4 (2d) or 6 (3d)
        # total number of boxes across all images in the batch
        box_sum = 0
        for val in boxes_per_image:
            box_sum += val
        pred_boxes = self.decode_single(rel_codes.reshape(box_sum, -1), concat_boxes)
        return pred_boxes.reshape(box_sum, spatial_dims)

    def decode_single(self, rel_codes: torch.Tensor, boxes: torch.Tensor):
        """
        Decode relative codes with respect to the given reference boxes.
        Delegates to the module-level :func:`decode_single`.
        """
        # NOTE(review): `dtype` and `device` are computed but unused here
        dtype, device = rel_codes.dtype, rel_codes.device
        return decode_single(rel_codes, boxes, self.weights, self.bbox_xform_clip)
from typing import Sequence, Callable, Tuple, TypeVar
from abc import ABC
import torch
from torch import Tensor
from loguru import logger
from nndet.core.boxes.utils import box_iou, box_center_dist, center_in_boxes
INF = 100  # not a true infinity, but large enough to act as a sentinel here
class Matcher(ABC):
    """
    Base class for matching ground truth boxes to anchors.

    Subclasses implement :meth:`compute_matches`; this class handles the
    degenerate case of an image without any ground truth boxes.
    """
    BELOW_LOW_THRESHOLD: int = -1
    BETWEEN_THRESHOLDS: int = -2

    def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou):
        """
        Matches boxes and anchors to each other.

        Args:
            similarity_fn: function for similarity computation between
                boxes and anchors
        """
        self.similarity_fn = similarity_fn

    def __call__(self,
                 boxes: torch.Tensor,
                 anchors: torch.Tensor,
                 num_anchors_per_level: Sequence[int],
                 num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches for a single image.

        Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            Tensor: similarity from each box to each anchor [N, M]
            Tensor: matched box index for every anchor [M]; background
                anchors carry `BELOW_LOW_THRESHOLD`, anchors to be ignored
                carry `BETWEEN_THRESHOLDS`
        """
        if boxes.numel() > 0:
            # at least one ground truth box is present -> delegate
            return self.compute_matches(
                boxes=boxes, anchors=anchors,
                num_anchors_per_level=num_anchors_per_level,
                num_anchors_per_loc=num_anchors_per_loc,
            )
        # no ground truth: empty similarity matrix, every anchor is background
        matches = torch.full((anchors.shape[0],), self.BELOW_LOW_THRESHOLD,
                             dtype=torch.int64)
        return torch.tensor([]).to(anchors), matches

    def compute_matches(self,
                        boxes: torch.Tensor,
                        anchors: torch.Tensor,
                        num_anchors_per_level: Sequence[int],
                        num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches (to be implemented by subclasses).

        Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            Tensor: similarity from each box to each anchor [N, M]
            Tensor: matched box index for every anchor [M]; background
                anchors carry `BELOW_LOW_THRESHOLD`, anchors to be ignored
                carry `BETWEEN_THRESHOLDS`
        """
        raise NotImplementedError
class IoUMatcher(Matcher):
    def __init__(self,
                 low_threshold: float,
                 high_threshold: float,
                 allow_low_quality_matches: bool,
                 similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou):
        """
        Compute IoU based matching for a single image.

        Args:
            low_threshold: threshold used to assign background values
            high_threshold: threshold used to assign foreground values
            allow_low_quality_matches: if enabled, anchors with no
                match get the box with highest IoU assigned
            similarity_fn: function for similarity computation between
                boxes and anchors
        """
        super().__init__(similarity_fn=similarity_fn)
        assert low_threshold <= high_threshold
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def compute_matches(self,
                        boxes: torch.Tensor,
                        anchors: torch.Tensor,
                        **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches according to given iou thresholds.

        Adapted from
        (https://github.com/pytorch/vision/blob/c7c2085ec686ccc55e1df85736b240b24
        05d1179/torchvision/models/detection/_utils.py)

        Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            **kwargs: unused; accepted for interface compatibility with
                :meth:`Matcher.compute_matches`

        Returns:
            Tensor: matrix which contains the similarity from each boxes
                to each anchor [N, M]
            Tensor: vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used)
                [M]
        """
        match_quality_matrix = self.similarity_fn(boxes, anchors)
        # match_quality_matrix is M (gt) x N (anchors)
        # Max over gt elements (dim 0) to find best gt candidate for each anchor
        matched_vals, matches = match_quality_matrix.max(dim=0)
        # _v, _i = matched_vals.topk(5)
        # print(boxes, _v, anchors[_i])

        if self.allow_low_quality_matches:
            # keep the raw assignment before thresholding so low quality
            # matches can be restored afterwards
            all_matches = matches.clone()

        # Assign candidate matches with low quality to negative (unassigned) values
        below_low_threshold = matched_vals < self.low_threshold
        between_thresholds = (matched_vals >= self.low_threshold) & (
            matched_vals < self.high_threshold
        )
        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
        matches[between_thresholds] = self.BETWEEN_THRESHOLDS

        if self.allow_low_quality_matches:
            matches = self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        # self._debug_logging(match_quality_matrix, matches, matched_vals,
        #                     below_low_threshold, between_thresholds)
        return match_quality_matrix, matches

    def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
        """
        Find the best matching prediction for each bounding box
        regardless of its IoU (this implementation excludes ties!).

        Args:
            matches: matched anchors to background and in between
            all_matches: all matches regardless of IoU
                (NOTE(review): currently unused by this implementation)
            match_quality_matrix: [M,N] tensor of IoUs (GroundTruth x NumAnchors)
        """
        # For each gt, find the prediction with has highest quality
        _, best_pred_idx = match_quality_matrix.max(dim=1)  # [M]
        # force-assign each gt to its best anchor, overriding thresholds
        matches[best_pred_idx] = torch.arange(len(best_pred_idx)).to(matches)
        return matches

    @staticmethod
    def _debug_logging(match_quality_matrix, matches, matched_vals,
                       below_low_threshold, between_thresholds):
        # Diagnostic summary of the matching result (not called by default,
        # see the commented invocation in `compute_matches`)
        logger.info("########## Matcher ##############")
        logger.info(f"Max IoU: {match_quality_matrix.max()}")
        logger.info(f"Foreground IoUs: {matched_vals[matches > -1]}")
        logger.info(f"Num GT: {match_quality_matrix.shape[0]}")
        match_bet_min = matched_vals[between_thresholds].min() if \
            matched_vals[between_thresholds].nelement() > 0 else None
        match_bet_max = matched_vals[between_thresholds].max() if \
            matched_vals[between_thresholds].nelement() > 0 else None
        logger.info(f"Inbetween IoU ranging from {match_bet_min} to {match_bet_max}")
        logger.info(f"Max background IoU: {matched_vals[below_low_threshold].max()}")
        logger.info("#################################")
class ATSSMatcher(Matcher):
    def __init__(self,
                 num_candidates: int,
                 similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,
                 center_in_gt: bool = True,
                 ):
        """
        Compute matching based on ATSS
        https://arxiv.org/abs/1912.02424
        `Bridging the Gap Between Anchor-based and Anchor-free Detection
        via Adaptive Training Sample Selection`

        Args:
            num_candidates: number of positions to select candidates from
            similarity_fn: function for similarity computation between
                boxes and anchors
            center_in_gt: if disabled, matched anchor center points do not
                need to lie within the ground truth box
        """
        super().__init__(similarity_fn=similarity_fn)
        self.num_candidates = num_candidates
        # eps passed to the center-in-box check below
        self.min_dist = 0.01
        self.center_in_gt = center_in_gt
        logger.info(f"Running ATSS Matching with num_candidates={self.num_candidates} "
                    f"and center_in_gt {self.center_in_gt}.")

    def compute_matches(self,
                        boxes: torch.Tensor,
                        anchors: torch.Tensor,
                        num_anchors_per_level: Sequence[int],
                        num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches according to ATSS for a single image.

        Adapted from
        (https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss
        /loss.py#L180-L184)

        Args:
            boxes: anchors are matched to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            Tensor: matrix which contains the similarity from each boxes
                to each anchor [N, M]
            Tensor: vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used)
                [M]
        """
        num_gt = boxes.shape[0]
        num_anchors = anchors.shape[0]

        distances, boxes_center, anchors_center = box_center_dist(boxes, anchors)  # num_boxes x anchors

        # select candidates based on center distance, per pyramid level
        candidate_idx = []
        start_idx = 0
        for level, apl in enumerate(num_anchors_per_level):
            end_idx = start_idx + apl
            # never request more candidates than anchors available on this level
            topk = min(self.num_candidates * num_anchors_per_loc, apl)
            _, idx = distances[:, start_idx: end_idx].topk(topk, dim=1, largest=False)
            # idx shape [num_boxes x topk]
            candidate_idx.append(idx + start_idx)
            start_idx = end_idx
        # [num_boxes x num_candidates] (index of candidate anchors)
        candidate_idx = torch.cat(candidate_idx, dim=1)

        match_quality_matrix = self.similarity_fn(boxes, anchors)  # [num_boxes x anchors]
        candidate_ious = match_quality_matrix.gather(1, candidate_idx)  # [num_boxes, n_candidates]

        # compute adaptive iou threshold (mean + std of candidate IoUs per gt)
        iou_mean_per_gt = candidate_ious.mean(dim=1)  # [num_boxes]
        iou_std_per_gt = candidate_ious.std(dim=1)  # [num_boxes]
        iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt  # [num_boxes]
        is_pos = candidate_ious >= iou_thresh_per_gt[:, None]  # [num_boxes x n_candidates]

        if self.center_in_gt:  # can discard all candidates in case of very small objects :/
            # center point of selected anchors needs to lie within the ground truth
            boxes_idx = torch.arange(num_gt, device=boxes.device, dtype=torch.long)[:, None]\
                .expand_as(candidate_idx).contiguous()  # [num_boxes x n_candidates]
            is_in_gt = center_in_boxes(
                anchors_center[candidate_idx.view(-1)], boxes[boxes_idx.view(-1)], eps=self.min_dist)
            is_pos = is_pos & is_in_gt.view_as(is_pos)  # [num_boxes x n_candidates]

        # if an anchor is assigned to multiple boxes, use the box with highest IoU;
        # the per-gt offset turns candidate_idx into indices of the flattened matrix
        # TODO: think about a better way to do this
        for ng in range(num_gt):
            candidate_idx[ng, :] += ng * num_anchors

        # scatter the IoU of positive candidates into a -INF-initialized
        # matrix; everything left at -INF becomes background below
        ious_inf = torch.full_like(match_quality_matrix, -INF).view(-1)
        index = candidate_idx.view(-1)[is_pos.view(-1)]
        ious_inf[index] = match_quality_matrix.view(-1)[index]
        ious_inf = ious_inf.view_as(match_quality_matrix)

        matched_vals, matches = ious_inf.max(dim=0)
        matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD
        # print(f"Num matches {(matches >= 0).sum()}, Adapt IoU {iou_thresh_per_gt}")
        return match_quality_matrix, matches
MatcherType = TypeVar('MatcherType', bound=Matcher)  # type var for APIs accepting any Matcher subclass
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from torch import Tensor
from torch.cuda.amp import autocast
from torchvision.ops.boxes import nms as nms_2d
from nndet._C import nms as nms_gpu
from nndet.core.boxes.utils import box_iou
def nms_cpu(boxes, scores, thresh):
    """
    Performs non-maximum suppression for boxes on cpu.

    Args:
        boxes (Tensor): tensor with boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        scores (Tensor): score for each box [N]
        thresh (float): IoU threshold above which overlapping boxes
            are discarded

    Returns:
        Tensor: int64 tensor with the indices of the elements that have
            been kept by NMS, sorted in decreasing order of scores
    """
    ious = box_iou(boxes, boxes)
    _, order = torch.sort(scores, descending=True)
    keep = []
    while order.nelement() > 0:
        best = order[0]
        keep.append(int(best))
        # drop every remaining box overlapping `best` by more than `thresh`;
        # this also removes `best` itself (self-IoU of 1 > thresh)
        non_matches = torch.where(ious[best][order] <= thresh)[0]
        order = order[non_matches]
    # Build the index tensor directly as int64 on the boxes' device.
    # The previous implementation (`torch.tensor(keep).to(boxes).long()`)
    # round-tripped the indices through the dtype of `boxes`, which corrupts
    # indices above 2048 for half precision inputs.
    return torch.tensor(keep, dtype=torch.long, device=boxes.device)
@autocast(enabled=False)
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float):
    """
    Performs non-maximum suppression.

    Dispatches to the best available backend: torchvision for 2d boxes,
    the custom CUDA kernel for 3d boxes on GPU, and the Python CPU
    implementation otherwise.

    Args:
        boxes (Tensor): tensor with boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        scores (Tensor): score for each box [N]
        iou_threshold (float): IoU threshold above which overlapping boxes
            are discarded

    Returns:
        Tensor: int64 tensor with the indices of the elements that have
            been kept by NMS, sorted in decreasing order of scores
    """
    if boxes.shape[1] == 4:
        # prefer torchvision in 2d because they have c++ cpu version
        nms_fn = nms_2d
    else:
        if boxes.is_cuda:
            nms_fn = nms_gpu
        else:
            nms_fn = nms_cpu
    # run in full precision (autocast disabled above, inputs cast to float)
    return nms_fn(boxes.float(), scores.float(), iou_threshold)
def batched_nms(boxes: Tensor, scores: Tensor, idxs: Tensor, iou_threshold: float):
    """
    Performs non-maximum suppression in a batched fashion.

    Each index value corresponds to a category, and NMS
    will not be applied between elements of different categories.

    Args:
        boxes (Tensor): boxes where NMS will be performed
            (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        scores (Tensor): scores for each one of the boxes [N]
        idxs (Tensor): indices of the categories for each one of the boxes [N]
        iou_threshold (float): discards all overlapping boxes with
            IoU > iou_threshold

    Returns:
        Tensor: int64 tensor with the indices of the elements that have
            been kept by NMS, sorted in decreasing order of scores
    """
    if boxes.numel() == 0:
        return torch.empty((0,), dtype=torch.int64, device=boxes.device)
    # Offset every box by a class-dependent amount larger than any
    # coordinate; boxes of different classes can then never overlap,
    # so one NMS pass handles all classes independently.
    shift = idxs.to(boxes) * (boxes.max() + 1)
    shifted_boxes = boxes + shift[:, None]
    return nms(shifted_boxes, scores, iou_threshold)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from loguru import logger
from abc import ABC
from typing import List
from torch import Tensor
from torchvision.models.detection._utils import BalancedPositiveNegativeSampler
class AbstractSampler(ABC):
    """
    Interface for anchor samplers which select the positive and negative
    anchors used for loss computation.
    """

    def __call__(self, target_labels: List[Tensor], fg_probs: Tensor):
        """
        Select positive and negative anchors.

        Args:
            target_labels: labels for each anchor per image, List[[A]],
                where A is the number of anchors in one image
            fg_probs: maximum foreground probability per anchor, [R],
                where R is the sum of all anchors inside one batch

        Returns:
            List[Tensor]: binary mask for positive anchors, List[[A]]
            List[Tensor]: binary mask for negative anchors, List[[A]]
        """
        raise NotImplementedError
class NegativeSampler(BalancedPositiveNegativeSampler, AbstractSampler):
    # Adapts torchvision's BalancedPositiveNegativeSampler to the
    # AbstractSampler interface.
    def __call__(self, target_labels: List[Tensor], fg_probs: Tensor):
        """
        Randomly sample negatives and positives until batch_size_per_img
        is reached.

        If not enough positive samples are found, it will be padded with
        negative samples.

        Args:
            target_labels: labels for each anchor per image, List[[A]]
            fg_probs: maximum foreground probability per anchor [R];
                accepted for interface compatibility but unused here

        Returns:
            List[Tensor]: binary mask for positive anchors, List[[A]]
            List[Tensor]: binary mask for negative anchors, List[[A]]
        """
        return super(NegativeSampler, self).__call__(target_labels)
class HardNegativeSamplerMixin(ABC):
    """
    Provides hard negative sampling: negatives are drawn randomly from a
    pool made up of the highest scoring false positives.
    """

    def __init__(self, pool_size: float = 10):
        """
        Args:
            pool_size: hard negatives are sampled from a pool of size
                ``batch_size_per_image * (1 - positive_fraction) * pool_size``
        """
        self.pool_size = pool_size

    def select_negatives(self, negative: Tensor, num_neg: int,
                         img_labels: Tensor, img_fg_probs: Tensor):
        """
        Sample ``num_neg`` negative anchors for one image.

        Args:
            negative: indices of negative anchors [P],
                where P is the number of negative anchors
            num_neg: number of negative anchors to sample
            img_labels: labels for all anchors in an image [A],
                where A is the number of anchors in one image
            img_fg_probs: maximum foreground probability per anchor [A]

        Returns:
            Tensor: binary (uint8) mask over all anchors in the image [A];
                1 marks a selected negative
        """
        # pool of the highest scoring false positives, bounded by the
        # number of available negatives
        pool_size = min(int(num_neg * self.pool_size), negative.numel())
        _, hardest = img_fg_probs[negative].topk(pool_size, sorted=True)
        hardest_negatives = negative[hardest]

        # draw the requested number of negatives randomly from the pool
        selection = torch.randperm(hardest_negatives.numel(),
                                   device=hardest_negatives.device)[:num_neg]
        chosen = hardest_negatives[selection]

        mask = torch.zeros_like(img_labels, dtype=torch.uint8)
        mask[chosen] = 1
        return mask
class HardNegativeSampler(HardNegativeSamplerMixin):
    """
    Per-image sampler which draws positives randomly and negatives from a
    pool of the highest scoring false positives.
    """

    def __init__(self, batch_size_per_image: int, positive_fraction: float,
                 min_neg: int = 0, pool_size: float = 10):
        """
        Args:
            batch_size_per_image: number of elements to be selected per image
            positive_fraction: percentage of positive elements per batch
            min_neg: minimum number of negatives to sample (if available)
            pool_size: hard negatives are sampled from a pool of size
                ``batch_size_per_image * (1 - positive_fraction) * pool_size``
        """
        super().__init__(pool_size=pool_size)
        self.min_neg = min_neg
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction

    def __call__(self, target_labels: List[Tensor], fg_probs: Tensor):
        """
        Select hard negatives and random positives per image.

        Args:
            target_labels: labels for each anchor per image, List[[A]],
                where A is the number of anchors in one image
            fg_probs: maximum foreground probability per anchor, [R],
                where R is the sum of all anchors inside one batch

        Returns:
            List[Tensor]: binary mask for positive anchors, List[[A]]
            List[Tensor]: binary mask for negative anchors, List[[A]]
        """
        # split batch-wide probabilities back into per-image chunks
        num_anchors_per_image = [labels.shape[0] for labels in target_labels]
        fg_probs_per_image = fg_probs.split(num_anchors_per_image, 0)

        pos_masks, neg_masks = [], []
        for img_labels, img_fg_probs in zip(target_labels, fg_probs_per_image):
            positive = torch.where(img_labels >= 1)[0]
            negative = torch.where(img_labels == 0)[0]

            num_pos = self.get_num_pos(positive)
            pos_masks.append(self.select_positives(
                positive, num_pos, img_labels, img_fg_probs))

            num_neg = self.get_num_neg(negative, num_pos)
            neg_masks.append(self.select_negatives(
                negative, num_neg, img_labels, img_fg_probs))
        return pos_masks, neg_masks

    def get_num_pos(self, positive: torch.Tensor) -> int:
        """
        Number of positive samples to draw.

        Args:
            positive: indices of positive anchors

        Returns:
            int: number of positive samples
        """
        # target count from the configured fraction, capped by availability
        target = int(self.batch_size_per_image * self.positive_fraction)
        return min(positive.numel(), target)

    def get_num_neg(self, negative: torch.Tensor, num_pos: int) -> int:
        """
        Number of negatives needed to reach the configured positive fraction.

        Args:
            negative: indices of negative anchors
            num_pos: number of positive samples to draw

        Returns:
            int: number of negative samples
        """
        # always assume at least one positive anchor was sampled
        wanted = int(max(1, num_pos) * abs(1 - 1. / float(self.positive_fraction)))
        # respect the configured minimum, capped by availability
        return min(negative.numel(), max(wanted, self.min_neg))

    def select_positives(self, positive: Tensor, num_pos: int,
                         img_labels: Tensor, img_fg_probs: Tensor):
        """
        Randomly sample ``num_pos`` positive anchors.

        Args:
            positive: indices of positive anchors [P],
                where P is the number of positive anchors
            num_pos: number of positive anchors to sample
            img_labels: labels for all anchors in an image [A],
                where A is the number of anchors in one image
            img_fg_probs: maximum foreground probability per anchor [A]

        Returns:
            Tensor: binary (uint8) mask over all anchors in the image [A];
                1 marks a selected positive
        """
        selection = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
        chosen = positive[selection]
        mask = torch.zeros_like(img_labels, dtype=torch.uint8)
        mask[chosen] = 1
        return mask
class HardNegativeSamplerBatched(HardNegativeSampler):
    """
    Samples negatives and positives on a per batch basis
    (default sampler only does this on a per image basis)

    Note:
        :attr:`batch_size_per_image` is rescaled on every call so the
        inherited per-image logic samples the correct amount for the whole
        batch; :attr:`_batch_size_per_image` keeps the per-image value.
    """
    def __init__(self, batch_size_per_image: int, positive_fraction: float,
                 min_neg: int = 0, pool_size: float = 10):
        """
        Args:
            batch_size_per_image (int): number of elements to be selected per image
            positive_fraction (float): percentage of positive elements per batch
            pool_size (float): hard negatives are sampled from a pool of size:
                batch_size_per_image * (1 - positive_fraction) * pool_size
        """
        super().__init__(min_neg=min_neg, batch_size_per_image=batch_size_per_image,
                         positive_fraction=positive_fraction, pool_size=pool_size)
        self._batch_size_per_image = batch_size_per_image
        logger.info("Sampling hard negatives on a per batch basis")

    def __call__(self, target_labels: List[Tensor], fg_probs: Tensor):
        """
        Select hard negatives over the concatenated batch

        Args:
            target_labels (List[Tensor]): labels for each anchor per image, List[[A]],
                where A is the number of anchors in one image
            fg_probs (Tensor): maximum foreground probability per anchor, [R]
                where R is the sum of all anchors inside one batch

        Returns:
            List[Tensor]: binary mask for positive anchors, List[[A]]
            List[Tensor]: binary mask for negative anchors, List[[A]]
        """
        num_images = len(target_labels)
        # scale so the inherited per-image budget covers the whole batch
        self.batch_size_per_image = self._batch_size_per_image * num_images

        all_labels = torch.cat(target_labels, dim=0)
        fg_candidates = torch.where(all_labels >= 1)[0]
        bg_candidates = torch.where(all_labels == 0)[0]

        n_pos = self.get_num_pos(fg_candidates)
        pos_mask = self.select_positives(fg_candidates, n_pos, all_labels, fg_probs)

        n_neg = self.get_num_neg(bg_candidates, n_pos)
        neg_mask = self.select_negatives(bg_candidates, n_neg, all_labels, fg_probs)

        # Comb Head with sampling concatenates masks after sampling so do not split them here
        # anchors_per_image = [anchors_in_image.shape[0] for anchors_in_image in target_labels]
        # return pos_idx.split(anchors_per_image, 0), neg_idx.split(anchors_per_image, 0)
        return [pos_mask], [neg_mask]
class BalancedHardNegativeSampler(HardNegativeSampler):
    """
    Hard negative sampler which balances negatives 1:1 with positives.
    """
    def get_num_neg(self, negative: torch.Tensor, num_pos: int) -> int:
        """
        Sample same number of negatives as positives but at least one

        Args:
            negative: indices of negative anchors
            num_pos: number of positive samples to draw

        Returns:
            int: number of negative samples
        """
        # at least one negative (if any exist), never more than available
        return min(negative.numel(), max(num_pos, 1))
class HardNegativeSamplerFgAll(HardNegativeSamplerMixin):
    """
    Keep every positive anchor for the loss and sample a corresponding
    number of hard negatives.
    """
    def __init__(self, negative_ratio: float = 1, pool_size: float = 10):
        """
        Args:
            negative_ratio (float): ratio of negative to positive sample;
                (samples negative_ratio * positive_anchors examples)
            pool_size (float): hard negatives are sampled from a pool of size:
                batch_size_per_image * (1 - positive_fraction) * pool_size
        """
        super().__init__(pool_size=pool_size)
        self.negative_ratio = negative_ratio

    def __call__(self, target_labels: List[Tensor], fg_probs: Tensor):
        """
        Select hard negatives from list anchors per image

        Args:
            target_labels (List[Tensor]): labels for each anchor per image, List[[A]],
                where A is the number of anchors in one image
            fg_probs (Tensor): maximum foreground probability per anchor, [R]
                where R is the sum of all anchors inside one batch

        Returns:
            List[Tensor]: binary mask for positive anchors, List[[A]]
            List[Tensor]: binary mask for negative anchors, List[[A]]
        """
        split_sizes = [labels.shape[0] for labels in target_labels]
        probs_per_image = fg_probs.split(split_sizes, 0)

        pos_masks, neg_masks = [], []
        for labels, probs in zip(target_labels, probs_per_image):
            bg_candidates = torch.where(labels == 0)[0]

            # every foreground anchor is kept as a positive
            fg_mask = (labels >= 1).to(dtype=torch.uint8)
            pos_masks.append(fg_mask)

            wanted = int(self.negative_ratio * fg_mask.sum())
            # clip to the available negatives but sample at least one if possible
            wanted = min(bg_candidates.numel(), max(wanted, 1))
            neg_masks.append(
                self.select_negatives(bg_candidates, wanted, labels, probs))
        return pos_masks, neg_masks
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from torch import Tensor
from numpy import ndarray
from typing import Union, Sequence, Tuple
from torch.cuda.amp import autocast
def box_area_3d(boxes: Tensor) -> Tensor:
    """
    Computes the volume of a set of 3D bounding boxes given in
    (x1, y1, x2, y2, z1, z2) format.

    Arguments:
        boxes (Tensor): boxes for which the volume will be computed,
            (x1, y1, x2, y2, z1, z2) format [N, 6]

    Returns:
        Tensor: volume for each box [N]
    """
    dx = boxes[:, 2] - boxes[:, 0]
    dy = boxes[:, 3] - boxes[:, 1]
    dz = boxes[:, 5] - boxes[:, 4]
    return dx * dy * dz
def box_area_2d(boxes: Tensor) -> Tensor:
    """
    Computes the area of a set of 2D bounding boxes given in
    (x1, y1, x2, y2) format.

    Arguments:
        boxes (Tensor): boxes for which the area will be computed,
            (x1, y1, x2, y2) format [N, 4]

    Returns:
        Tensor: area for each box [N]
    """
    width = boxes[:, 2] - boxes[:, 0]
    height = boxes[:, 3] - boxes[:, 1]
    return width * height
def box_area(boxes: Union[Tensor, ndarray]) -> Union[Tensor, ndarray]:
    """
    Computes the area (2D) or volume (3D) of a set of bounding boxes

    Args:
        boxes (Union[Tensor, ndarray]): boxes of shape;
            (x1, y1, x2, y2, (z1, z2))[N, dim * 2]

    Returns:
        Union[Tensor, ndarray]: area of boxes

    See Also:
        :func:`box_area_3d`, :func:`torchvision.ops.boxes.box_area`
    """
    # 4 columns -> 2D boxes, otherwise 3D boxes
    area_fn = box_area_2d if boxes.shape[-1] == 4 else box_area_3d
    return area_fn(boxes)
@autocast(enabled=False)
def box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tensor:
    """
    Return intersection-over-union (Jaccard index) of boxes.

    Arguments:
        boxes1: boxes; (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        boxes2: boxes; (x1, y1, x2, y2, (z1, z2))[M, dim * 2]
        eps: optional small constant for numerical stability

    Returns:
        iou (Tensor): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2; [N, M]

    See Also:
        :func:`box_iou_3d`, :func:`torchvision.ops.boxes.box_iou`

    Notes:
        Computed in float32 (autocast disabled) because the area/volume
        can become too large for half precision.
    """
    # TODO: think about adding additional assert statements to check coordinates x1 <= x2, y1 <= y2, z1 <= z2
    if boxes1.numel() == 0 or boxes2.numel() == 0:
        # no boxes -> empty result on the same device/dtype
        return torch.tensor([]).to(boxes1)

    iou_fn = box_iou_union_2d if boxes1.shape[-1] == 4 else box_iou_union_3d
    return iou_fn(boxes1.float(), boxes2.float(), eps=eps)[0]
@autocast(enabled=False)
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tensor:
    """
    Generalized box iou for 2D and 3D boxes.

    Arguments:
        boxes1: boxes; (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        boxes2: boxes; (x1, y1, x2, y2, (z1, z2))[M, dim * 2]
        eps: optional small constant for numerical stability

    Returns:
        Tensor: the NxM matrix containing the pairwise
            generalized IoU values for every element in boxes1 and boxes2; [N, M]

    Notes:
        Computed in float32 (autocast disabled) because the area/volume
        can become too large for half precision.
    """
    if boxes1.nelement() == 0 or boxes2.nelement() == 0:
        # no boxes -> empty result on the same device/dtype
        return torch.tensor([]).to(boxes1)

    giou_fn = generalized_box_iou_2d if boxes1.shape[-1] == 4 else generalized_box_iou_3d
    return giou_fn(boxes1.float(), boxes2.float(), eps=eps)
def box_iou_union_3d(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tuple[Tensor, Tensor]:
    """
    Return pairwise intersection-over-union and union of two sets of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2, z1, z2) format.

    Args:
        boxes1: set of boxes (x1, y1, x2, y2, z1, z2)[N, 6]
        boxes2: set of boxes (x1, y1, x2, y2, z1, z2)[M, 6]
        eps: optional small constant for numerical stability

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
        Tensor[N, M]: the NxM matrix containing the pairwise union
            values
    """
    vol1 = box_area_3d(boxes1)  # [N]
    vol2 = box_area_3d(boxes2)  # [M]

    # pairwise overlap extents per axis; clamped at zero for disjoint boxes
    dx = (torch.min(boxes1[:, None, 2], boxes2[:, 2]) -
          torch.max(boxes1[:, None, 0], boxes2[:, 0])).clamp(min=0)  # [N, M]
    dy = (torch.min(boxes1[:, None, 3], boxes2[:, 3]) -
          torch.max(boxes1[:, None, 1], boxes2[:, 1])).clamp(min=0)  # [N, M]
    dz = (torch.min(boxes1[:, None, 5], boxes2[:, 5]) -
          torch.max(boxes1[:, None, 4], boxes2[:, 4])).clamp(min=0)  # [N, M]

    inter = dx * dy * dz + eps  # [N, M]
    union = vol1[:, None] + vol2 - inter
    return inter / union, union
def generalized_box_iou_3d(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tensor:
    """
    Computes the generalized box iou between given bounding boxes

    Args:
        boxes1: set of boxes (x1, y1, x2, y2, z1, z2)[N, 6]
        boxes2: set of boxes (x1, y1, x2, y2, z1, z2)[M, 6]
        eps: optional small constant for numerical stability

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise
            generalized IoU values for every element in boxes1 and boxes2
    """
    # forward eps so the IoU/union division is stabilized as well
    # (previously eps was only applied to the enclosing volume below)
    iou, union = box_iou_union_3d(boxes1, boxes2, eps=eps)
    # smallest enclosing box of every pair
    x1 = torch.min(boxes1[:, None, 0], boxes2[:, 0])  # [N, M]
    y1 = torch.min(boxes1[:, None, 1], boxes2[:, 1])  # [N, M]
    x2 = torch.max(boxes1[:, None, 2], boxes2[:, 2])  # [N, M]
    y2 = torch.max(boxes1[:, None, 3], boxes2[:, 3])  # [N, M]
    z1 = torch.min(boxes1[:, None, 4], boxes2[:, 4])  # [N, M]
    z2 = torch.max(boxes1[:, None, 5], boxes2[:, 5])  # [N, M]
    vol = ((x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0) * (z2 - z1).clamp(min=0)) + eps  # [N, M]
    return iou - (vol - union) / vol
def box_iou_union_2d(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tuple[Tensor, Tensor]:
    """
    Return pairwise intersection-over-union and union of two sets of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        boxes1: set of boxes (x1, y1, x2, y2)[N, 4]
        boxes2: set of boxes (x1, y1, x2, y2)[M, 4]
        eps: optional small constant for numerical stability

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
        union (Tensor[N, M]): the NxM matrix containing the pairwise union
            values
    """
    # areas computed inline (same formula as box_area_2d)
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])

    # intersection: max of the lower corners, min of the upper corners.
    # NOTE: the previous min/max were swapped, which computed the enclosing
    # box instead of the intersection (cf. box_iou_union_3d / box_iou_2d_np).
    x1 = torch.max(boxes1[:, None, 0], boxes2[:, 0])  # [N, M]
    y1 = torch.max(boxes1[:, None, 1], boxes2[:, 1])  # [N, M]
    x2 = torch.min(boxes1[:, None, 2], boxes2[:, 2])  # [N, M]
    y2 = torch.min(boxes1[:, None, 3], boxes2[:, 3])  # [N, M]

    inter = ((x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)) + eps  # [N, M]
    union = area1[:, None] + area2 - inter
    return inter / union, union
def generalized_box_iou_2d(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tensor:
    """
    Computes the generalized box iou between given bounding boxes

    Args:
        boxes1: set of boxes (x1, y1, x2, y2)[N, 4]
        boxes2: set of boxes (x1, y1, x2, y2)[M, 4]
        eps: optional small constant for numerical stability

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise
            generalized IoU values for every element in boxes1 and boxes2
    """
    # forward eps so the IoU/union division is stabilized as well
    # (previously eps was only applied to the enclosing area below)
    iou, union = box_iou_union_2d(boxes1, boxes2, eps=eps)
    # smallest enclosing box of every pair
    x1 = torch.min(boxes1[:, None, 0], boxes2[:, 0])  # [N, M]
    y1 = torch.min(boxes1[:, None, 1], boxes2[:, 1])  # [N, M]
    x2 = torch.max(boxes1[:, None, 2], boxes2[:, 2])  # [N, M]
    y2 = torch.max(boxes1[:, None, 3], boxes2[:, 3])  # [N, M]
    area = ((x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)) + eps  # [N, M]
    return iou - (area - union) / area
def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
    """
    Remove boxes which have at least one side smaller than min_size.

    Arguments:
        boxes (Tensor): boxes (x1, y1, x2, y2, (z1, z2)) [N, dim * 2]
        min_size (float): minimum size

    Returns:
        keep (Tensor): indices of the boxes that have all sides
            larger than min_size [N]
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    keep_mask = (widths >= min_size) & (heights >= min_size)
    if boxes.shape[1] != 4:
        # 3D boxes additionally carry a depth extent
        depths = boxes[:, 5] - boxes[:, 4]
        keep_mask = keep_mask & (depths >= min_size)
    return torch.where(keep_mask)[0]
def box_center_dist(boxes1: Tensor, boxes2: Tensor, euclidean: bool = True) -> \
        Tuple[Tensor, Tensor, Tensor]:
    """
    Distance of center points between two sets of boxes

    Arguments:
        boxes1: boxes; (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        boxes2: boxes; (x1, y1, x2, y2, (z1, z2))[M, dim * 2]
        euclidean: computed the euclidean distance otherwise it uses the l1
            distance

    Returns:
        Tensor: the NxM matrix containing the pairwise
            distances for every element in boxes1 and boxes2; [N, M]
        Tensor: center points of boxes1
        Tensor: center points of boxes2
    """
    center1 = box_center(boxes1)  # [N, dims]
    center2 = box_center(boxes2)  # [M, dims]

    if euclidean:
        dists = (center1[:, None] - center2[None]).pow(2).sum(-1).sqrt()
    else:
        # l1 distance requires absolute differences; without .abs() the
        # signed per-axis offsets cancel each other out
        # before sum: [N, M, dims]
        dists = (center1[:, None] - center2[None]).abs().sum(-1)
    return dists, center1, center2
def center_in_boxes(center: Tensor, boxes: Tensor, eps: float = 0.01) -> Tensor:
    """
    Checks which center points are within their corresponding boxes

    Args:
        center: center points [N, dims]
        boxes: boxes [N, dims * 2]
        eps: minimum distance to border of boxes

    Returns:
        Tensor: boolean array indicating which center points are within
            the boxes [N]
    """
    # signed distance of each center to every face of its box
    face_dists = [
        center[:, 0] - boxes[:, 0],
        center[:, 1] - boxes[:, 1],
        boxes[:, 2] - center[:, 0],
        boxes[:, 3] - center[:, 1],
    ]
    if center.shape[1] == 3:
        face_dists += [center[:, 2] - boxes[:, 4],
                       boxes[:, 5] - center[:, 2]]
    # a center is inside iff even its closest face is further away than eps
    return torch.stack(face_dists, dim=1).min(dim=1)[0] > eps
def box_center(boxes: Tensor) -> Tensor:
    """
    Compute center point of boxes

    Args:
        boxes: bounding boxes (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]

    Returns:
        Tensor: center points [N, dims]
    """
    mids = [(boxes[:, 0] + boxes[:, 2]) / 2.,
            (boxes[:, 1] + boxes[:, 3]) / 2.]
    if boxes.shape[1] == 6:
        mids.append((boxes[:, 4] + boxes[:, 5]) / 2.)
    return torch.stack(mids, dim=1)
def permute_boxes(boxes: Union[Tensor, ndarray],
                  dims: Sequence[int] = None) -> Union[Tensor, ndarray]:
    """
    Change ordering of spatial axes of boxes

    Args:
        boxes: boxes [N, dims * 2](x1, y1, x2, y2(, z1, z2))
        dims: the desired ordering of dimensions; By default the dimensions
            are reversed

    Returns:
        Tensor: boxes with permuted axes [N, dims * 2]

    Raises:
        TypeError: if ``dims`` does not match the number of box dimensions
    """
    if dims is None:
        dims = list(range(boxes.shape[1] // 2))[::-1]
    if 2 * len(dims) != boxes.shape[1]:
        raise TypeError(f"Need same number of dimensions, found dims {dims} "
                        f"but boxes with shape {boxes.shape}")

    # column pairs (min, max) belonging to each spatial axis
    axis_cols = [[0, 2], [1, 3]]
    if boxes.shape[1] == 6:
        axis_cols.append([4, 5])

    # interleave the first two axes (x/y layout), then append remaining axes
    order = [axis_cols[dims[0]][0], axis_cols[dims[1]][0],
             axis_cols[dims[0]][1], axis_cols[dims[1]][1]]
    for extra in dims[2:]:
        order.extend(axis_cols[extra])
    return boxes[:, order]
def expand_to_boxes(data: Union[Tensor, ndarray]) -> Union[Tensor, ndarray]:
    """
    Expand x,y(,z) coordinate data to box format by repeating each axis

    Args:
        data (Tensor): data to expand (N, dim)[:, (x, y, [z])]

    Returns:
        Tensor: expanded tensors [N, dim * 2]
    """
    three_coords = (data.ndim == 1 and data.shape[0] == 3) or \
                   (data.ndim == 2 and data.shape[1] == 3)
    index = [0, 1, 0, 1]
    if three_coords:
        index += [2, 2]
    if data.ndim == 1:
        # promote a single coordinate vector to a batch of one
        data = data[None]
    return data[:, index]
def box_size(boxes: Tensor) -> Tensor:
    """
    Compute length of boxes along all dimensions

    Args:
        boxes (Tensor): boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]

    Returns:
        Tensor: size along axis (x, y, (z))[N, dim]
    """
    extents = [boxes[:, 2] - boxes[:, 0],
               boxes[:, 3] - boxes[:, 1]]
    if boxes.shape[1] // 2 == 3:
        extents.append(boxes[:, 5] - boxes[:, 4])
    return torch.stack(extents, dim=1)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from numpy import ndarray
def box_area_np(boxes: ndarray) -> ndarray:
    """
    Computes the area (2D) or volume (3D) of a set of bounding boxes.

    See Also:
        :func:`nndet.core.boxes.utils.box_area`
    """
    # 4 columns -> 2D boxes, otherwise 3D boxes
    area_fn = box_area_2d_np if boxes.shape[-1] == 4 else box_area_3d_np
    return area_fn(boxes)
def box_area_3d_np(boxes: np.ndarray) -> np.ndarray:
    """
    Volume of 3D boxes given in (x1, y1, x2, y2, z1, z2) format [N, 6].

    See Also:
        `nndet.core.boxes.utils.box_area_3d`
    """
    dx = boxes[:, 2] - boxes[:, 0]
    dy = boxes[:, 3] - boxes[:, 1]
    dz = boxes[:, 5] - boxes[:, 4]
    return dx * dy * dz
def box_area_2d_np(boxes: np.ndarray) -> np.ndarray:
    """
    Area of 2D boxes given in (x1, y1, x2, y2) format [N, 4].

    See Also:
        `nndet.core.boxes.utils.box_area_2d`
    """
    width = boxes[:, 2] - boxes[:, 0]
    height = boxes[:, 3] - boxes[:, 1]
    return width * height
def box_iou_np(boxes1: ndarray, boxes2: ndarray) -> ndarray:
    """
    Return intersection-over-union (Jaccard index) of boxes
    (numpy counterpart of the tensor implementation).

    Arguments:
        boxes1 (ndarray): boxes; (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        boxes2 (ndarray): boxes; (x1, y1, x2, y2, (z1, z2))[M, dim * 2]

    Returns:
        iou (ndarray): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2; [N, M]

    See Also:
        :func:`box_iou_3d`, :func:`torchvision.ops.boxes.box_iou`
    """
    # TODO: think about adding additional assert statements to check coordinates x1 <= x2, y1 <= y2, z1 <= z2
    iou_fn = box_iou_2d_np if boxes1.shape[-1] == 4 else box_iou_3d_np
    return iou_fn(boxes1, boxes2)
def box_iou_2d_np(boxes1: ndarray, boxes2: ndarray) -> ndarray:
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

    Arguments:
        boxes1 (ndarray): set of boxes (x1, y1, x2, y2)[N, 4]
        boxes2 (ndarray): set of boxes (x1, y1, x2, y2)[M, 4]

    Returns:
        iou (ndarray[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """
    area1 = box_area_2d_np(boxes1)
    area2 = box_area_2d_np(boxes2)

    # pairwise overlap extents per axis; clipped at zero for disjoint boxes
    dx = np.clip(np.minimum(boxes1[:, None, 2], boxes2[:, 2]) -
                 np.maximum(boxes1[:, None, 0], boxes2[:, 0]), a_min=0, a_max=None)  # [N, M]
    dy = np.clip(np.minimum(boxes1[:, None, 3], boxes2[:, 3]) -
                 np.maximum(boxes1[:, None, 1], boxes2[:, 1]), a_min=0, a_max=None)  # [N, M]

    inter = dx * dy  # [N, M]
    return inter / (area1[:, None] + area2 - inter)
def box_iou_3d_np(boxes1: ndarray, boxes2: ndarray) -> ndarray:
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2, z1, z2) format.

    Arguments:
        boxes1 (ndarray): set of boxes (x1, y1, x2, y2, z1, z2)[N, 6]
        boxes2 (ndarray): set of boxes (x1, y1, x2, y2, z1, z2)[M, 6]

    Returns:
        iou (ndarray[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """
    vol1 = box_area_3d_np(boxes1)
    vol2 = box_area_3d_np(boxes2)

    # pairwise overlap extents per axis; clipped at zero for disjoint boxes
    dx = np.clip(np.minimum(boxes1[:, None, 2], boxes2[:, 2]) -
                 np.maximum(boxes1[:, None, 0], boxes2[:, 0]), a_min=0, a_max=None)  # [N, M]
    dy = np.clip(np.minimum(boxes1[:, None, 3], boxes2[:, 3]) -
                 np.maximum(boxes1[:, None, 1], boxes2[:, 1]), a_min=0, a_max=None)  # [N, M]
    dz = np.clip(np.minimum(boxes1[:, None, 5], boxes2[:, 5]) -
                 np.maximum(boxes1[:, None, 4], boxes2[:, 4]), a_min=0, a_max=None)  # [N, M]

    inter = dx * dy * dz  # [N, M]
    return inter / (vol1[:, None] + vol2 - inter)
def box_size_np(boxes: ndarray) -> ndarray:
    """
    Compute length of boxes along all dimensions

    Args:
        boxes (ndarray): boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]

    Returns:
        ndarray: size along axis (x, y, (z))[N, dim]
    """
    extents = [boxes[:, 2] - boxes[:, 0],
               boxes[:, 3] - boxes[:, 1]]
    if boxes.shape[1] // 2 == 3:
        extents.append(boxes[:, 5] - boxes[:, 4])
    return np.stack(extents, axis=-1)
def box_center_np(boxes: np.ndarray) -> np.ndarray:
    """
    Compute center point of boxes

    Args:
        boxes: bounding boxes (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]

    Returns:
        ndarray: center points [N, dims]
    """
    mids = [(boxes[:, 0] + boxes[:, 2]) / 2.,
            (boxes[:, 1] + boxes[:, 3]) / 2.]
    if boxes.shape[1] == 6:
        mids.append((boxes[:, 4] + boxes[:, 5]) / 2.)
    return np.stack(mids, axis=1)
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple, Dict, Any, Optional, Union
from nndet.models.abstract import AbstractModel
from nndet.core import boxes as box_utils
from nndet.models.heads.segmenter import SegmenterType
from nndet.models.heads.comb import HeadType
class BaseRetinaNet(AbstractModel):
def __init__(self,
dim: int,
# modules
encoder: EncoderType,
decoder: DecoderType,
head: HeadType,
num_classes: int,
anchor_generator: AnchorType,
matcher: box_utils.MatcherType,
decoder_levels: tuple = (2, 3, 4, 5),
# post-processing
score_thresh: float = None,
detections_per_img: int = 100,
topk_candidates: int = 10000,
remove_small_boxes: float = 1e-2,
nms_thresh: float = 0.9,
# optional
segmenter: Optional[SegmenterType] = None,
):
"""
Base Retina(U)Net
Can be subclasses to add specific configurations to it
Args:
dim: number of spatial dimensions
encoder: encoder module
decoder: decoder module
head: head module
num_classes: number of foreground classes
anchor_generator: generate anchors
matcher: match ground truth boxes and anchors
decoder_levels: decoder levels to use for detection prediciton
score_thresh: minimum output probability
detections_per_img: max detections per image
topk_candidates: select only topk candidates for nms computation
remove_small_boxes: remove small bounding boxes
nms_thresh: non maximum suppression threshold
segmenter: segmentation module
"""
super().__init__()
assert dim in [2, 3]
self.dim = dim
self.decoder_levels = decoder_levels
self.encoder = encoder
self.decoder = decoder
self.head = head
self.num_foreground_classes = num_classes
self.anchor_generator = anchor_generator
self.proposal_matcher = matcher
self.score_thresh = score_thresh
self.topk_candidates = topk_candidates
self.detections_per_img = detections_per_img
self.remove_small_boxes = remove_small_boxes
self.nms_thresh = nms_thresh
self.segmenter = segmenter
    def train_step(self,
                   images: Tensor,
                   targets: dict,
                   evaluation: bool,
                   batch_num: int,
                   ) -> Tuple[
        Dict[str, torch.Tensor], Optional[Dict]]:
        """
        Perform a single training step (forward pass + loss computation)

        Args:
            images: batch of images
            targets: labels for training
                `target_boxes` (List[Tensor]): ground truth bounding boxes
                    (x1, y1, x2, y2, (z1, z2))[X, dim * 2], X = number of
                    ground truth boxes in image
                `target_classes` (List[Tensor]): ground truth class per box
                    (classes start from 0) [X], X = number of ground truth
                    boxes in image
                `target_seg` (Tensor): segmentation ground truth (only
                    needed if :param:`segmenter` was provided in init)
                    (classes start from 1, 0 background)
            evaluation (bool): compute final predictions (includes detection
                postprocessing)
            batch_num (int): batch index inside epoch (not used in this
                implementation)

        Returns:
            Dict[str, torch.Tensor]: individual loss components for back
                propagation / logging (head losses plus, if a segmenter is
                configured, the segmentation loss)
            Optional[Dict]: predictions for metric calculation; None unless
                :param:`evaluation` is True
                `pred_boxes`: List[Tensor]: predicted bounding boxes for
                    each image List[[R, dim * 2]]
                `pred_scores`: List[Tensor]: predicted probability for the
                    class List[[R]]
                `pred_labels`: List[Tensor]: predicted class List[[R]]
                `pred_seg`: Tensor: predicted segmentation [N, dims]
        """
        # import napari
        # with napari.gui_qt():
        #     viewer = napari.view_image(images.detach().cpu().numpy())
        #     viewer.add_labels(seg_targets[:, None].detach().cpu().numpy())
        target_boxes: List[Tensor] = targets["target_boxes"]
        target_classes: List[Tensor] = targets["target_classes"]
        target_seg: Tensor = targets["target_seg"]

        # forward pass and anchor/target matching
        pred_detection, anchors, pred_seg = self(images)
        labels, matched_gt_boxes = self.assign_targets_to_anchors(
            anchors, target_boxes, target_classes)

        # collect all loss components into one dict
        losses = {}
        head_losses, pos_idx, neg_idx = self.head.compute_loss(
            pred_detection, labels, matched_gt_boxes, anchors)
        losses.update(head_losses)
        if self.segmenter is not None:
            losses.update(self.segmenter.compute_loss(pred_seg, target_seg))

        # only run the (expensive) detection postprocessing when requested
        if evaluation:
            prediction = self.postprocess_for_inference(
                images=images,
                pred_detection=pred_detection,
                pred_seg=pred_seg,
                anchors=anchors,
            )
        else:
            prediction = None
        # self.save_matched_anchors(images=images, target_boxes=target_boxes,
        #                           anchors=anchors, pos_idx=pos_idx,
        #                           neg_idx=neg_idx, seg=seg_targets)
        return losses, prediction
@torch.no_grad()
def postprocess_for_inference(self,
images: torch.Tensor,
pred_detection: Dict[str, torch.Tensor],
pred_seg: Dict[str, torch.Tensor],
anchors: List[torch.Tensor],
) -> Dict[str, Union[List[Tensor], Tensor]]:
"""
Postprocess predictions for inference
Args:
images: input images
pred_detection: detection predictions
pred_seg: segmentation predictions
anchors: anchors
Returns:
Dict: post processed predictions
'pred_boxes': List[Tensor]: predicted bounding boxes for each
image List[[R, dim * 2]]
'pred_scores': List[Tensor]: predicted probability for
the class List[[R]]
'pred_labels': List[Tensor]: predicted class List[[R]]
'pred_seg': Tensor: predicted segmentation [N, C, dims]
"""
image_shapes = [images.shape[2:]] * images.shape[0]
boxes, probs, labels = self.postprocess_detections(
pred_detection=pred_detection,
anchors=anchors,
image_shapes=image_shapes,
)
prediction = {"pred_boxes": boxes, "pred_scores": probs, "pred_labels": labels}
if self.segmenter is not None:
prediction["pred_seg"] = self.segmenter.postprocess_for_inference(pred_seg)["pred_seg"]
return prediction
def forward(self,
inp: torch.Tensor,
) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor], Dict[str, torch.Tensor]]:
"""
Compute predicted bounding boxes, scores and segmentations
Args:
inp (torch.Tensor): batch of input images
Returns:
dict: predictions from head. Typically includes:
`box_deltas`(Tensor): bounding box offsets
[Num_Anchors_Batch, (dim * 2)]
`box_logits`(Tensor): classification logits
[Num_Anchors_Batch, (num_classes)]
List[torch.Tensor]: list of anchors (for each image inside the
batch)
dict: segmentation prediction. None if retina net is configured.
Typically includes:
`seg_logits`: segmentation logits
"""
features_maps_all = self.decoder(self.encoder(inp))
feature_maps_head = [features_maps_all[i] for i in self.decoder_levels]
pred_detection = self.head(feature_maps_head)
anchors = self.anchor_generator(inp, feature_maps_head)
pred_seg = self.segmenter(features_maps_all) if self.segmenter is not None else None
return pred_detection, anchors, pred_seg
    @torch.no_grad()
    def assign_targets_to_anchors(self,
                                  anchors: List[torch.Tensor],
                                  target_boxes: List[torch.Tensor],
                                  target_classes: List[torch.Tensor]) -> Tuple[
        List[torch.Tensor], List[torch.Tensor]]:
        """
        Compute labels and matched ground truth for each anchor
        Adapted from torchvision https://github.com/pytorch/vision

        Args:
            anchors (List[torch.Tensor[float]]): anchors (!)per image(!)
                List[[N, dim * 2]], N=number of anchors per image
            target_boxes (List[torch.Tensor[float]]): ground truth boxes
                (!)per image(!)
                List[[X, dim * 2]], X=number of gt per image
            target_classes (List[torch.Tensor]): ground truth classes
                (!)per image(!) (classes start from 0)
                List[[X]], X=number of gt per image

        Returns:
            List[torch.Tensor]: labels ([1, K]: foreground classes, 0: background,
                -1: between) List[[N]], N=number of anchors per image
            List[torch.Tensor]: matched gt box List[[N, dim * 2]],
                N=number of anchors per image
        """
        labels = []
        matched_gt_boxes = []
        for anchors_per_image, gt_boxes, gt_classes in zip(anchors, target_boxes, target_classes):
            # indices of ground truth box for each proposal
            # NOTE(review): "get_num_acnhors_per_level" looks misspelled but
            # presumably matches the anchor generator's method name — confirm
            # there before renaming
            match_quality_matrix, matched_idxs = self.proposal_matcher(
                gt_boxes, anchors_per_image,
                num_anchors_per_level=self.anchor_generator.get_num_acnhors_per_level(),
                num_anchors_per_loc=self.anchor_generator.num_anchors_per_location()[0])

            # get the targets corresponding GT for each proposal
            # NB: need to clamp the indices because we can have a single
            # GT in the image, and matched_idxs can be -2, which goes
            # out of bounds
            if match_quality_matrix.numel() > 0:
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]
                # Positive (negative indices can be ignored because they are overwritten in the next step)
                # this influences how background class is handled in the input!!!! (here +1 for background)
                labels_per_image = gt_classes[matched_idxs.clamp(min=0)].to(dtype=anchors_per_image.dtype)
                labels_per_image = labels_per_image + 1
            else:
                num_anchors_per_image = anchors_per_image.shape[0]
                # no ground truth => no matches, all background
                matched_gt_boxes_per_image = torch.zeros_like(anchors_per_image)
                labels_per_image = torch.zeros(num_anchors_per_image).to(anchors_per_image)

            # Background (negative examples): label 0
            bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
            labels_per_image[bg_indices] = 0.0

            # discard indices that are between thresholds: label -1
            inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
            labels_per_image[inds_to_discard] = -1.0

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes
def postprocess_detections(self,
                           pred_detection: Dict[str, Tensor],
                           anchors: List[Tensor],
                           image_shapes: List[Tuple[int]],
                           ) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
    """
    Turn batched box regression output and classification output into
    final per-image boxes, scores and labels.

    Adapted from torchvision https://github.com/pytorch/vision

    Args:
        pred_detection: detection predictions for loss computation
            `box_logits`: classification logits for each anchor [N]
            `box_deltas`: offsets for each anchor
                (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
        anchors: proposals for each image
        image_shapes: shape of each image

    Returns:
        List[Tensor]: final boxes [R, dim * 2]
        List[Tensor]: final scores (for final class) [R]
        List[Tensor]: final class label [R]
    """
    # number of anchors per image; used to undo the batch concatenation
    num_anchors_per_image = [len(a) for a in anchors]

    pred_detection = self.head.postprocess_for_inference(pred_detection, anchors)
    batched_boxes = pred_detection["pred_boxes"]
    batched_probs = pred_detection["pred_probs"]

    # chunk the batched predictions back into per-image pieces
    boxes_per_image = batched_boxes.split(num_anchors_per_image, 0)
    probs_per_image = batched_probs.split(num_anchors_per_image, 0)

    all_boxes: List[Tensor] = []
    all_probs: List[Tensor] = []
    all_labels: List[Tensor] = []
    for img_boxes, img_probs, img_shape in zip(boxes_per_image, probs_per_image, image_shapes):
        final_boxes, final_probs, final_labels = \
            self.postprocess_detections_single_image(img_boxes, img_probs, img_shape)
        all_boxes.append(final_boxes)
        all_probs.append(final_probs)
        all_labels.append(final_labels)
    return all_boxes, all_probs, all_labels
def postprocess_detections_single_image(
        self,
        boxes: Tensor,
        probs: Tensor,
        image_shape: Tuple[int],
        ) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Postprocess decoded boxes and per-class probabilities for one image:
    clip to the image, keep the top-k candidates, apply the score
    threshold, drop tiny boxes, run class-aware NMS and cap the number
    of detections.

    Adapted from torchvision https://github.com/pytorch/vision

    Args:
        boxes: decoded boxes (already offset by their deltas, not raw
            deltas) for all anchors of this image [N, dim * 2]
        probs: predicted probability per anchor and class [N, C]
        image_shape: shape of image

    Returns:
        Tensor: final boxes [R, dim * 2]
        Tensor: final scores (for final class) [R]
        Tensor: final class label [R]
    """
    assert boxes.shape[0] == probs.shape[0]
    boxes = box_utils.clip_boxes_to_image_(boxes, image_shape)

    # flatten [N, C] -> [N * C]; each flat index encodes (anchor, class)
    probs = probs.flatten()

    if self.topk_candidates is not None:
        num_topk = min(self.topk_candidates, boxes.size(0))
        probs, idx = probs.sort(descending=True)
        probs, idx = probs[:num_topk], idx[:num_topk]
    else:
        # fix: torch.arange defaults to CPU; create the index tensor on
        # the same device as probs so indexing works for CUDA tensors
        idx = torch.arange(probs.numel(), device=probs.device)

    if self.score_thresh is not None:
        keep_idxs = probs > self.score_thresh
        probs, idx = probs[keep_idxs], idx[keep_idxs]

    # recover anchor index and foreground class label from the flat index
    anchor_idxs = idx // self.num_foreground_classes
    labels = idx % self.num_foreground_classes
    boxes = boxes[anchor_idxs]

    if self.remove_small_boxes is not None:
        keep = box_utils.remove_small_boxes(boxes, min_size=self.remove_small_boxes)
        boxes, probs, labels = boxes[keep], probs[keep], labels[keep]

    # class-aware NMS: boxes with different labels never suppress each other
    keep = box_utils.batched_nms(boxes, probs, labels, self.nms_thresh)
    if self.detections_per_img is not None:
        keep = keep[:self.detections_per_img]
    return boxes[keep], probs[keep], labels[keep]
# @torch.no_grad()
# def save_matched_anchors(self, **kwargs):
# logger = get_logger("mllogger")
# logger.save_pickle("anchor_matching",
# to_device(kwargs, device="cpu", detach=True))
@torch.no_grad()
def inference_step(self,
                   images: Tensor,
                   **kwargs,
                   ) -> Dict[str, Any]:
    """
    Run the model on one batch of images and postprocess the raw output
    into final detections (gradients disabled).

    Args:
        images: batch of input images [N, C, W, H, (D)]

    Returns:
        Dict:
            'pred_boxes': List[Tensor]: predicted bounding boxes for each
                image List[[R, dim * 2]]
            'pred_scores': List[Tensor]: predicted probability for
                the class List[[R]]
            'pred_labels': List[Tensor]: predicted class List[[R]]
            'pred_seg': Tensor: predicted segmentation [N, C, dims]
    """
    # forward pass yields raw detections, the anchors used and segmentation
    raw_detection, raw_anchors, raw_seg = self(images)
    return self.postprocess_for_inference(
        images=images,
        pred_detection=raw_detection,
        pred_seg=raw_seg,
        anchors=raw_anchors,
    )
......@@ -22,7 +22,7 @@ import numpy as np
from nndet.evaluator.abstract import AbstractEvaluator, DetectionMetric
from nndet.evaluator.detection.matching import matching_batch
from nndet.detection.boxes import box_iou_np
from nndet.core.boxes import box_iou_np
from nndet.evaluator.detection.coco import COCOMetric
from nndet.evaluator.detection.froc import FROCMetric
from nndet.evaluator.detection.hist import PredictionHistogram
......
......@@ -18,7 +18,7 @@ from typing import Tuple
from torch import Tensor
from nndet.detection.boxes import batched_nms, nms
from nndet.core.boxes import batched_nms, nms
from nndet.inference.detection import batched_wbc
......
......@@ -19,7 +19,7 @@ from typing import Tuple
from torch import Tensor
import torch
from nndet.detection.boxes import batched_nms
from nndet.core.boxes import batched_nms
def batched_nms_model(
......
......@@ -19,7 +19,7 @@ from typing import Tuple
import torch
from torch import Tensor
from nndet.detection.boxes import batched_nms, nms
from nndet.core.boxes import batched_nms, nms
from nndet.inference.detection import batched_wbc
......
......@@ -21,7 +21,7 @@ from typing import Tuple
from torch._C import device
from nndet.detection.boxes import box_iou, box_area
from nndet.core.boxes import box_iou, box_area
__all__ = ["batched_wbc", "wbc"]
......
......@@ -30,8 +30,8 @@ from nndet.inference.detection import batched_nms_model, batched_nms_ensemble, \
batched_wbc_ensemble, wbc_nms_no_label_ensemble
from nndet.inference.ensembler.base import BaseEnsembler, OverlapMap
from nndet.inference.restore import restore_detection
from nndet.detection.boxes import box_center, clip_boxes_to_image, remove_small_boxes
from nndet.detection.boxes.merging import GreedyIoUBoxMerger, VoteLabelGreedyIoUBoxMerger
from nndet.core.boxes import box_center, clip_boxes_to_image, remove_small_boxes
from nndet.core.boxes.merging import GreedyIoUBoxMerger, VoteLabelGreedyIoUBoxMerger
from nndet.utils.tensor import cat, to_device, to_dtype
......
......@@ -19,7 +19,7 @@ from typing import Sequence, Tuple, Union, Optional
import numpy as np
from loguru import logger
from nndet.detection.boxes.utils import permute_boxes, expand_to_boxes
from nndet.core.boxes.utils import permute_boxes, expand_to_boxes
from nndet.preprocessing.resampling import resample_data_or_seg, get_do_separate_z, get_lowres_axis
......
......@@ -27,7 +27,7 @@ from nndet.io.datamodule import DATALOADER_REGISTRY
from nndet.io.load import load_pickle
from nndet.inference.patching import save_get_crop
from nndet.utils.info import maybe_verbose_iterable
from nndet.detection.boxes.utils_np import box_size_np
from nndet.core.boxes.utils_np import box_size_np
class FixedSlimDataLoaderBase(SlimDataLoaderBase):
......
......@@ -5,7 +5,7 @@ import torch
__all__ = ["SmoothL1Loss", "smooth_l1_loss"]
from nndet.detection.boxes.utils import generalized_box_iou
from nndet.core.boxes.utils import generalized_box_iou
from nndet.losses.base import reduction_helper
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment