Commit 7246044d authored by mibaumgartner

Merge remote-tracking branch 'origin/master' into main

parents fcec502f 6f4c3333
# @package __global__
defaults:
- augmentation: mirror_only
model: "RetinaUNetC009"
trainer: "DetectionTrainerPolyLR_SGD090"
predictor: "BoxPredictorSelective"
plan: D2C002_2d
planners:
2d: [D2C002]
3d: [D2C002, D3C002] # [D3C002LR15, D3C002LR20] [D3C002NR, D3C002RibFrac] [D2C002, D3C002]
augment_cfg:
oversample_foreground_percent: 0.5 # ratio of fg and bg in batches
augmentation: ${augmentation}
dataloader: "DataLoader2DDeeplesion"
dataloader_kwargs: {}
trainer_cfg:
# By default, training is deterministic; non-deterministic training allows
# cudnn.benchmark, which can give up to 20% more performance. Set this to False
# to perform non-deterministic training.
deterministic: True
fp16: True # enable fp16 training. Makes sense for supported hardware only!
eval_score_key: "mAP_IoU_0.10_0.50_0.05_MaxDet_100" # metric to optimize
num_batches_per_epoch: 2500 # number of train batches per epoch
num_val_batches_per_epoch: 100 # number of val batches per epoch
max_num_epochs: 200 # max number of epochs
overwrites: {}
initial_lr: 0.01 # initial learning rate to start with
weight_decay: 3.e-5 # weight decay for optimizer
warmup: 4000 # number of iterations with warmup
warmup_lr: 1.e-6 # learning rate to start warmup from
model_cfg:
matching:
# IoU Matcher Parameters
fg_iou_thresh: 0.4 # IoU threshold for anchors to be matched positive
bg_iou_thresh: 0.3 # IoU threshold for anchors to be matched negative
# If ground truth has no matched anchors, use the best anchor which was found
allow_low_quality_matches: True
# ATSS matching
num_candidates: 4
center_in_gt: False
hnm: # parameters for hard negative mining
batch_size_per_image: 32 # number of anchors sampled per image
positive_fraction: 0.33 # defines ratio between positive and negative anchors
# hard negatives are sampled from a pool of size:
# batch_size_per_image * (1 - positive_fraction) * pool_size
pool_size: 20
min_neg: 1 # minimum number of negative anchors sampled per image
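# Worked example with the values above (for illustration only): the hard-negative
# pool holds roughly 32 * (1 - 0.33) * 20 ≈ 429 anchors per image, from which the
# sampled negatives are drawn.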
plan_arch_overwrites: # overwrite arguments of architecture
in_channels: 4
plan_anchors_overwrites: {} # overwrite arguments of anchors
# @package __global__
defaults:
- augmentation: base_more
model: "RetinaUNetC009LH1"
trainer: "BoxTrainer"
predictor: "BoxPredictorSelective"
plan: D3C002_3d
planners:
2d: [D2C002]
3d: [D3C003FDR] # [D3C002LR15, D3C002LR20] [D3C002NR, D3C002RibFrac] [D2C002, D3C002]
augment_cfg:
oversample_foreground_percent: 0.5 # ratio of fg and bg in batches
augmentation: ${augmentation}
dataloader: "DataLoader{}DOffset"
dataloader_kwargs: {}
trainer_cfg:
# By default, training is deterministic; non-deterministic training allows
# cudnn.benchmark, which can give up to 20% more performance. Set this to False
# to perform non-deterministic training.
deterministic: True
fp16: True # enable fp16 training. Makes sense for supported hardware only!
eval_score_key: "mAP_IoU_0.10_0.50_0.05_MaxDet_100" # metric to optimize
num_batches_per_epoch: 2500 # number of train batches per epoch
num_val_batches_per_epoch: 100 # number of val batches per epoch
max_num_epochs: 50 # max number of epochs
overwrites: {}
initial_lr: 0.01 # initial learning rate to start with
weight_decay: 3.e-5 # weight decay for optimizer
warmup: 4000 # number of iterations with warmup
warmup_lr: 1.e-6 # learning rate to start warmup from
model_cfg:
matching:
# IoU Matcher Parameters
fg_iou_thresh: 0.4 # IoU threshold for anchors to be matched positive
bg_iou_thresh: 0.3 # IoU threshold for anchors to be matched negative
# If ground truth has no matched anchors, use the best anchor which was found
allow_low_quality_matches: True
# ATSS matching
num_candidates: 4
center_in_gt: False
hnm: # parameters for hard negative mining
batch_size_per_image: 32 # number of anchors sampled per image
positive_fraction: 0.33 # defines ratio between positive and negative anchors
# hard negatives are sampled from a pool of size:
# batch_size_per_image * (1 - positive_fraction) * pool_size
pool_size: 20
min_neg: 1 # minimum number of negative anchors sampled per image
plan_arch_overwrites: {} # overwrite arguments of architecture
plan_anchors_overwrites: {} # overwrite arguments of anchors
# @package __global__
defaults:
- augmentation: base_more
model: "RetinaUNetC009LH1"
trainer: "BoxTrainerSWA"
predictor: "BoxPredictorSelective"
plan: D3C002_3d
planners:
2d: [D2C002]
3d: [D3C003FD] # [D3C002LR15, D3C002LR20] [D3C002NR, D3C002RibFrac] [D2C002, D3C002]
augment_cfg:
oversample_foreground_percent: 0.5 # ratio of fg and bg in batches
augmentation: ${augmentation}
dataloader: "DataLoader{}DOffset"
dataloader_kwargs: {}
trainer_cfg:
# By default, training is deterministic; non-deterministic training allows
# cudnn.benchmark, which can give up to 20% more performance. Set this to False
# to perform non-deterministic training.
deterministic: True
fp16: True # enable fp16 training. Makes sense for supported hardware only!
eval_score_key: "mAP_IoU_0.10_0.50_0.05_MaxDet_100" # metric to optimize
num_batches_per_epoch: 2500 # number of train batches per epoch
num_val_batches_per_epoch: 100 # number of val batches per epoch
max_num_epochs: 60 # max number of epochs
overwrites: {}
initial_lr: 0.01 # initial learning rate to start with
weight_decay: 3.e-5 # weight decay for optimizer
warmup: 4000 # number of iterations with warmup
warmup_lr: 1.e-6 # learning rate to start warmup from
swa_epochs: 10 # number of epochs to run swa with cyclic learning rate
swa_snapshots: 10 # number of swa snapshots
model_cfg:
matching:
# IoU Matcher Parameters
fg_iou_thresh: 0.4 # IoU threshold for anchors to be matched positive
bg_iou_thresh: 0.3 # IoU threshold for anchors to be matched negative
# If ground truth has no matched anchors, use the best anchor which was found
allow_low_quality_matches: True
# ATSS matching
num_candidates: 4
center_in_gt: False
hnm: # parameters for hard negative mining
batch_size_per_image: 32 # number of anchors sampled per image
positive_fraction: 0.33 # defines ratio between positive and negative anchors
# hard negatives are sampled from a pool of size:
# batch_size_per_image * (1 - positive_fraction) * pool_size
pool_size: 20
min_neg: 1 # minimum number of negative anchors sampled per image
plan_arch_overwrites: {} # overwrite arguments of architecture
plan_anchors_overwrites: {} # overwrite arguments of anchors
# @package __global__
defaults:
- augmentation: base_more
module: RetinaUNetV001
predictor: BoxPredictorSelective
plan: D3V001_3d # plan used for training
planner: D3V001 # planner used for preprocessing
augment_cfg:
augmentation: ${augmentation}
num_train_batches_per_epoch: ${trainer_cfg.num_train_batches_per_epoch}
num_val_batches_per_epoch: ${trainer_cfg.num_val_batches_per_epoch}
dataloader: "DataLoader{}DOffset"
oversample_foreground_percent: 0.5 # ratio of fg and bg in batches
dataloader_kwargs: {}
num_threads: ${oc.env:det_num_threads, "12"}
num_cached_per_thread: 2
multiprocessing: True # only deactivate this if debugging
trainer_cfg:
gpus: 1 # number of gpus
accelerator: ddp # distributed backend
precision: 16 # mixed precision
amp_backend: native # mixed precision backend
amp_level: O1 # when mixed precision backend is APEX use O1
# By default, training is deterministic; non-deterministic training allows
# cudnn.benchmark, which can give up to 20% more performance. Set this to False
# to perform non-deterministic training.
deterministic: False
benchmark: False
# fp16: True # enable fp16 training. Makes sense for supported hardware only!
monitor_key: "mAP_IoU_0.10_0.50_0.05_MaxDet_100" # used to determine the best model
monitor_mode: "max" # metric operation mode "min" or "max"
max_num_epochs: 2 # max number of epochs
num_train_batches_per_epoch: 20 # number of train batches per epoch
num_val_batches_per_epoch: 10 # number of val batches per epoch
initial_lr: 0.01 # initial learning rate to start with
sgd_momentum: 0.9 # momentum term
sgd_nesterov: True # nesterov momentum
weight_decay: 3.e-5 # weight decay for optimizer
momentum: 0.9 # momentum term
warm_iterations: 4000 # number of iterations with warmup
warm_lr: 1.e-6 # learning rate to start warmup from
poly_gamma: 0.9
swa_epochs: 2 # number of epochs to run swa with cyclic learning rate
model_cfg:
encoder_kwargs: {} # keyword arguments passed to encoder
decoder_kwargs: # keyword arguments passed to decoder
min_out_channels: 8
upsampling_mode: "transpose"
num_lateral: 1
norm_lateral: False
activation_lateral: False
num_out: 1
norm_out: False
activation_out: False
head_kwargs: {} # keyword arguments passed to head
head_classifier_kwargs: # keyword arguments passed to classifier in head
num_convs: 2
norm_channels_per_group: 16
norm_affine: True
reduction: "mean"
loss_weight: 1.
# gamma: 1.
# alpha: 0.75
# reduction: "sum"
# loss_weight: 0.3
prior_prob: 0.01
head_regressor_kwargs: # keyword arguments passed to regressor in head
num_convs: 2
norm_channels_per_group: 16
norm_affine: True
reduction: "sum"
loss_weight: 1.
learn_scale: True
head_sampler_kwargs: # keyword arguments passed to sampler
batch_size_per_image: 32 # number of anchors sampled per image
positive_fraction: 0.33 # defines ratio between positive and negative anchors
# hard negatives are sampled from a pool of size:
# batch_size_per_image * (1 - positive_fraction) * pool_size
pool_size: 20
min_neg: 1 # minimum number of negative anchors sampled per image
segmenter_kwargs:
dice_kwargs:
batch_dice: True
matcher_kwargs: # keyword arguments passed to matcher
num_candidates: 4
center_in_gt: False
plan_arch_overwrites: {} # overwrite arguments of architecture
plan_anchors_overwrites: {} # overwrite arguments of anchors
debug:
num_cases_val: 2 # only predict two cases for validation results
# @package __global__
defaults:
- augmentation: base_more
module: RetinaUNetV001
predictor: BoxPredictorSelective
plan: D3V001_3d # plan used for training
planner: D3V001 # planner used for preprocessing
augment_cfg:
augmentation: ${augmentation}
num_train_batches_per_epoch: ${trainer_cfg.num_train_batches_per_epoch}
num_val_batches_per_epoch: ${trainer_cfg.num_val_batches_per_epoch}
dataloader: "DataLoader{}DOffset"
oversample_foreground_percent: 0.5 # ratio of fg and bg in batches
dataloader_kwargs: {}
num_threads: ${oc.env:det_num_threads, "12"}
num_cached_per_thread: 2
multiprocessing: True # only deactivate this if debugging
# Additional overwrites
# patch_size; Default: plan
# batch_size; Default: plan
# splits; Default: splits_final
trainer_cfg:
gpus: 1 # number of gpus
accelerator: ddp # distributed backend
precision: 16 # mixed precision
amp_backend: native # mixed precision backend
amp_level: O1 # when mixed precision backend is APEX use O1
# By default, training is deterministic; non-deterministic training allows
# cudnn.benchmark, which can give up to 20% more performance. Set this to False
# to perform non-deterministic training.
deterministic: False
benchmark: False
# fp16: True # enable fp16 training. Makes sense for supported hardware only!
monitor_key: "mAP_IoU_0.10_0.50_0.05_MaxDet_100" # used to determine the best model
monitor_mode: "max" # metric operation mode "min" or "max"
max_num_epochs: 50 # max number of epochs
num_train_batches_per_epoch: 2500 # number of train batches per epoch
num_val_batches_per_epoch: 100 # number of val batches per epoch
initial_lr: 0.01 # initial learning rate to start with
sgd_momentum: 0.9 # momentum term
sgd_nesterov: True # nesterov momentum
weight_decay: 3.e-5 # weight decay for optimizer
warm_iterations: 4000 # number of iterations with warmup
warm_lr: 1.e-6 # learning rate to start warmup from
poly_gamma: 0.9
swa_epochs: 10 # number of epochs to run swa with cyclic learning rate
# sweep_ckpt: Select checkpoint identifier for sweeping. Default "last".
model_cfg:
encoder_kwargs: {} # keyword arguments passed to encoder
decoder_kwargs: # keyword arguments passed to decoder
min_out_channels: 8
upsampling_mode: "transpose"
num_lateral: 1
norm_lateral: False
activation_lateral: False
num_out: 1
norm_out: False
activation_out: False
head_kwargs: {} # keyword arguments passed to head
head_classifier_kwargs: # keyword arguments passed to classifier in head
num_convs: 1
norm_channels_per_group: 16
norm_affine: True
reduction: "mean"
loss_weight: 1.
prior_prob: 0.01
head_regressor_kwargs: # keyword arguments passed to regressor in head
num_convs: 1
norm_channels_per_group: 16
norm_affine: True
reduction: "sum"
loss_weight: 1.
learn_scale: True
head_sampler_kwargs: # keyword arguments passed to sampler
batch_size_per_image: 32 # number of anchors sampled per image
positive_fraction: 0.33 # defines ratio between positive and negative anchors
# hard negatives are sampled from a pool of size:
# batch_size_per_image * (1 - positive_fraction) * pool_size
pool_size: 20
min_neg: 1 # minimum number of negative anchors sampled per image
segmenter_kwargs:
dice_kwargs:
batch_dice: True
matcher_kwargs: # keyword arguments passed to matcher
num_candidates: 4
center_in_gt: False
plan_arch_overwrites: {} # overwrite arguments of architecture
plan_anchors_overwrites: {} # overwrite arguments of anchors
from nndet.core.boxes.anchors import (
AnchorGeneratorType,
get_anchor_generator,
compute_anchors_for_strides,
AnchorGenerator2D,
AnchorGenerator2DS,
AnchorGenerator3D,
AnchorGenerator3DS,
)
from nndet.core.boxes.clip import (
clip_boxes_to_image_,
clip_boxes_to_image,
)
from nndet.core.boxes.coder import CoderType, BoxCoderND
from nndet.core.boxes.matcher import MatcherType, Matcher, IoUMatcher, ATSSMatcher
from nndet.core.boxes.nms import nms, batched_nms
from nndet.core.boxes.sampler import AbstractSampler, NegativeSampler, HardNegativeSampler, \
BalancedHardNegativeSampler, HardNegativeSamplerFgAll, HardNegativeSamplerBatched
from nndet.core.boxes.ops import box_area, box_iou, remove_small_boxes, box_center, permute_boxes, \
expand_to_boxes, box_size, generalized_box_iou, box_center_dist, center_in_boxes
from nndet.core.boxes.ops_np import box_iou_np, box_size_np, box_area_np
"""
Parts of this code are from torchvision and thus licensed under
BSD 3-Clause License
Copyright (c) Soumith Chintala 2016,
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import torch
from typing import Callable, Sequence, List, Tuple, TypeVar, Union
from torchvision.models.detection.rpn import AnchorGenerator
from loguru import logger
from itertools import product
AnchorGeneratorType = TypeVar('AnchorGeneratorType', bound=AnchorGenerator)
def get_anchor_generator(dim: int, s_param: bool = False) -> AnchorGenerator:
"""
Get anchor generator class for corresponding dimension
Args:
dim: number of spatial dimensions
s_param: enable size parametrization
Returns:
Callable: class of anchor generator
"""
normal = {2: AnchorGenerator2D, 3: AnchorGenerator3D}
sparam = {2: AnchorGenerator2DS, 3: AnchorGenerator3DS}
if s_param:
return sparam[dim]
else:
return normal[dim]
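# Illustrative usage (example values only, based on the constructors defined below):
#   generator_cls = get_anchor_generator(dim=3, s_param=True)   # -> AnchorGenerator3DS
#   generator = generator_cls(width=(4, 8), height=(6, 12), depth=(2, 4))
#   generator_cls = get_anchor_generator(dim=2, s_param=False)  # -> AnchorGenerator2D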
def compute_anchors_for_strides(anchors: torch.Tensor,
strides: Sequence[Union[Sequence[Union[int, float]], Union[int, float]]],
cat: bool) -> Union[List[torch.Tensor], torch.Tensor]:
"""
Compute anchors sizes which follow a given sequence of strides
Args:
anchors: anchors for stride 0
strides: sequence of strides to adjust anchors for
cat: concatenate resulting anchors, if false a Sequence of Anchors
is returned
Returns:
Union[List[torch.Tensor], torch.Tensor]: new anchors
"""
anchors_with_stride = [anchors]
dim = anchors.shape[1] // 2
for stride in strides:
if isinstance(stride, (int, float)):
stride = [stride] * dim
stride_formatted = [stride[0], stride[1], stride[0], stride[1]]
if dim == 3:
stride_formatted.extend([stride[2], stride[2]])
anchors_with_stride.append(
anchors * torch.tensor(stride_formatted)[None].float())
if cat:
anchors_with_stride = torch.cat(anchors_with_stride, dim=0)
return anchors_with_stride
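# Worked example (illustrative): with anchors = [[-16., -8., 16., 8.]] and
# strides = (2,), the anchors appended for the second level are
# [[-32., -16., 32., 16.]]; with cat=True both levels are concatenated into a
# single [2, 4] tensor, otherwise a list with one tensor per stride is returned.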
class AnchorGenerator2D(AnchorGenerator):
def __init__(self, sizes: Sequence[Union[int, Sequence[int]]] = (128, 256, 512),
aspect_ratios: Sequence[Union[float, Sequence[float]]] = (0.5, 1.0, 2.0),
**kwargs):
"""
Generator for anchors
Modified from https://github.com/pytorch/vision/blob/master/torchvision/models/detection/rpn.py
Args:
sizes (Sequence[Union[int, Sequence[int]]]): anchor sizes for each feature map
(length should match the number of feature maps)
aspect_ratios (Sequence[Union[float, Sequence[float]]]): anchor aspect ratios:
height/width, e.g. (0.5, 1, 2). if Seq[Seq] is provided, it should have
the same length as sizes
"""
super().__init__(sizes=sizes, aspect_ratios=aspect_ratios)
self.num_anchors_per_level: List[int] = None
if kwargs:
logger.info(f"Discarding anchor generator kwargs {kwargs}")
def cached_grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[int]]) -> List[torch.Tensor]:
"""
Check if combination was already generated before and return that if possible
Args:
grid_sizes (Sequence[Sequence[int]]): spatial sizes of feature maps
strides (Sequence[Sequence[int]]): stride of each feature map
Returns:
List[torch.Tensor]: Anchors for each feature maps
"""
key = str(grid_sizes + strides)
if key not in self._cache:
self._cache[key] = self.grid_anchors(grid_sizes, strides)
self.num_anchors_per_level = self._cache[key][1]
return self._cache[key][0]
def grid_anchors(self, grid_sizes, strides) -> Tuple[List[torch.Tensor], List[int]]:
"""
Distribute anchors over feature maps
Args:
grid_sizes (Sequence[Sequence[int]]): spatial sizes of feature maps
strides (Sequence[Sequence[int]]): stride of each feature map
Returns:
List[torch.Tensor]: Anchors for each feature maps
List[int]: number of anchors per level
"""
assert len(grid_sizes) == len(strides), "Every fm size needs strides"
assert len(grid_sizes) == len(self.cell_anchors), "Every fm size needs cell anchors"
anchors = []
cell_anchors = self.cell_anchors
assert cell_anchors is not None
_i = 0
# modified from torchvision (ordering of axis differs)
anchor_per_level = []
for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
size0, size1 = size
stride0, stride1 = stride
device = base_anchors.device
shifts_x = torch.arange(0, size0, dtype=torch.float, device=device) * stride0
shifts_y = torch.arange(0, size1, dtype=torch.float, device=device) * stride1
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
_anchors = (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
anchors.append(_anchors)
anchor_per_level.append(_anchors.shape[0])
logger.debug(f"Generated {anchors[_i].shape[0]} anchors and expected "
f"{size0 * size1 * self.num_anchors_per_location()[_i]} "
f"anchors on level {_i}.")
_i += 1
return anchors, anchor_per_level
@staticmethod
def generate_anchors(scales: Tuple[int],
aspect_ratios: Tuple[float],
dtype: torch.dtype = torch.float,
device: Union[torch.device, str] = "cpu",
) -> torch.Tensor:
"""
Generate anchors for a pair of scales and ratios
Args:
scales (Tuple[int]): scales of anchors, e.g. (32, 64, 128)
aspect_ratios (Tuple[float]): aspect ratios of height/width, e.g. (0.5, 1, 2)
dtype (torch.dtype): data type of anchors
device (Union[torch.device, str]): target device of anchors
Returns:
Tensor: anchors of shape [n(scales) * n(ratios), dim * 2]
"""
scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios)
w_ratios = 1 / h_ratios
ws = (w_ratios[:, None] * scales[None, :]).view(-1)
hs = (h_ratios[:, None] * scales[None, :]).view(-1)
base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
return base_anchors.round()
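# Worked example (illustrative): scales=(32,) and aspect_ratios=(0.5,) give
# h = sqrt(0.5) * 32 ≈ 22.6 and w = 32 / sqrt(0.5) ≈ 45.3, so the single base
# anchor is [-23., -11., 23., 11.] after halving and rounding.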
def set_cell_anchors(self, dtype: torch.dtype, device: Union[torch.device, str] = "cpu") -> None:
"""
Set :param:`self.cell_anchors` if it was not already set
Args:
dtype (torch.dtype): data type of anchors
device (Union[torch.device, str]): target device of anchors
Returns:
None
result is saved into attribute
"""
if self.cell_anchors is not None:
return
cell_anchors = [self.generate_anchors(sizes, aspect_ratios, dtype, device)
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)]
self.cell_anchors = cell_anchors
def forward(self, image_list: torch.Tensor, feature_maps: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Generate anchors for given feature maps
# TODO: update docstring and type
Args:
image_list (torch.Tensor): data structure which contains images and their original shapes
feature_maps (Sequence[torch.Tensor]): feature maps for which anchors need to be generated
Returns:
List[Tensor]: list of anchors (for each image inside the batch)
"""
device = image_list.device
grid_sizes = list([feature_map.shape[2:] for feature_map in feature_maps])
image_size = image_list.shape[2:]
strides = [list((int(i / s) for i, s in zip(image_size, fm_size))) for fm_size in grid_sizes]
self.set_cell_anchors(dtype=feature_maps[0].dtype, device=feature_maps[0].device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = []
images_shapes = [img.shape for img in image_list.split(1)]
for i, x in enumerate(images_shapes):
anchors_in_image = []
for anchors_per_feature_map in anchors_over_all_feature_maps:
anchors_in_image.append(anchors_per_feature_map)
anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image).to(device) for anchors_per_image in anchors]
# TODO: check with torchvision if this makes sense (if enabled, anchors are newly generated for each run)
# # Clear the cache in case that memory leaks.
# self._cache.clear()
return anchors
def get_num_acnhors_per_level(self) -> List[int]:
"""
Number of anchors per resolution
Returns:
List[int]: number of anchors per positions for each resolution
"""
if self.num_anchors_per_level is None:
raise RuntimeError("Need to forward features maps before "
"get_num_acnhors_per_level can be called")
return self.num_anchors_per_level
class AnchorGenerator3D(AnchorGenerator2D):
def __init__(self,
sizes: Sequence[Union[int, Sequence[int]]] = (128, 256, 512),
aspect_ratios: Sequence[Union[float, Sequence[float]]] = (0.5, 1.0, 2.0),
zsizes: Sequence[Union[int, Sequence[int]]] = (4, 4, 4),
**kwargs):
"""
Helper to generate anchors for different input sizes
Args:
sizes (Sequence[Union[int, Sequence[int]]]): anchor sizes for each feature map
(length should match the number of feature maps)
aspect_ratios (Sequence[Union[float, Sequence[float]]]): anchor aspect ratios:
height/width, e.g. (0.5, 1, 2). if Seq[Seq] is provided, it should have
the same length as sizes
zsizes (Sequence[Union[int, Sequence[int]]]): sizes along z dimension
"""
super().__init__(sizes, aspect_ratios)
if not isinstance(zsizes[0], (Sequence, list, tuple)):
zsizes = (zsizes,) * len(sizes)
self.zsizes = zsizes
if kwargs:
logger.info(f"Discarding anchor generator kwargs {kwargs}")
def set_cell_anchors(self, dtype: torch.dtype, device: Union[torch.device, str] = "cpu") -> None:
"""
Compute anchors for all pairs of scales and ratios and save them inside :param:`cell_anchors`
if they were not computed before
Args:
dtype (torch.dtype): data type of anchors
device (Union[torch.device, str]): target device of anchors
Returns:
None (result is saved into :param:`self.cell_anchors`)
"""
if self.cell_anchors is not None:
return
cell_anchors = [
self.generate_anchors(sizes, aspect_ratios, zsizes, dtype, device)
for sizes, aspect_ratios, zsizes in zip(self.sizes, self.aspect_ratios, self.zsizes)
]
self.cell_anchors = cell_anchors
@staticmethod
def generate_anchors(scales: Tuple[int], aspect_ratios: Tuple[float], zsizes: Tuple[int],
dtype: torch.dtype = torch.float,
device: Union[torch.device, str] = "cpu") -> torch.Tensor:
"""
Generate anchors for a pair of scales and ratios
Args:
scales (Tuple[int]): scales of anchors, e.g. (32, 64, 128)
aspect_ratios (Tuple[float]): aspect ratios of height/width, e.g. (0.5, 1, 2)
zsizes (Tuple[int]): scale along z dimension
dtype (torch.dtype): data type of anchors
device (Union[torch.device, str]): target device of anchors
Returns:
Tensor: anchors of shape [n(scales) * n(ratios) * n(zscales) , dim * 2]
"""
base_anchors_2d = AnchorGenerator2D.generate_anchors(
scales, aspect_ratios, dtype=dtype, device=device)
zanchors = torch.cat(
[torch.as_tensor([-z, z], dtype=dtype, device=device).repeat(
base_anchors_2d.shape[0], 1) for z in zsizes], dim=0)
base_anchors_3d = torch.cat(
[base_anchors_2d.repeat(len(zsizes), 1), (zanchors / 2.).round()], dim=1)
return base_anchors_3d
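# Worked example (illustrative): for base_anchors_2d of shape [n2d, 4] and
# zsizes=(4, 8), the 2D anchors are repeated once per z size and each copy is
# paired with the rounded pair (-z/2, z/2), yielding base_anchors_3d of shape
# [n2d * 2, 6].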
def grid_anchors(self, grid_sizes: Sequence[Sequence[int]],
strides: Sequence[Sequence[int]]) -> Tuple[List[torch.Tensor], List[int]]:
"""
Distribute anchors over feature maps
Args:
grid_sizes (Sequence[Sequence[int]]): spatial sizes of feature maps
strides (Sequence[Sequence[int]]): stride of each feature map
Returns:
List[torch.Tensor]: Anchors for each feature maps
List[int]: number of anchors per level
"""
assert len(grid_sizes) == len(strides)
assert len(grid_sizes) == len(self.cell_anchors)
anchors = []
_i = 0
anchor_per_level = []
for size, stride, base_anchors in zip(grid_sizes, strides, self.cell_anchors):
size0, size1, size2 = size
stride0, stride1, stride2 = stride
dtype, device = base_anchors.dtype, base_anchors.device
shifts_x = torch.arange(0, size0, dtype=dtype, device=device) * stride0
shifts_y = torch.arange(0, size1, dtype=dtype, device=device) * stride1
shifts_z = torch.arange(0, size2, dtype=dtype, device=device) * stride2
shift_x, shift_y, shift_z = torch.meshgrid(shifts_x, shifts_y, shifts_z)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shift_z = shift_z.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y, shift_z, shift_z), dim=1)
_anchors = (shifts.view(-1, 1, 6) + base_anchors.view(1, -1, 6)).reshape(-1, 6)
anchors.append(_anchors)
anchor_per_level.append(_anchors.shape[0])
logger.debug(f"Generated {_anchors.shape[0]} anchors and expected "
f"{size0 * size1 * size2 * self.num_anchors_per_location()[_i]} "
f"anchors on level {_i}.")
_i += 1
return anchors, anchor_per_level
def num_anchors_per_location(self) -> List[int]:
"""
Number of anchors per resolution
Returns:
List[int]: number of anchors per positions for each resolution
"""
return [len(s) * len(a) * len(z) for s, a, z in zip(self.sizes, self.aspect_ratios, self.zsizes)]
class AnchorGenerator2DS(AnchorGenerator2D):
def __init__(self,
width: Sequence[Union[int, Sequence[int]]],
height: Sequence[Union[int, Sequence[int]]],
**kwargs,
):
"""
Helper to generate anchors for different input sizes
Uses a different parametrization of anchors
(if Sequence[int] is provided it is interpreted as one
value per feature map size)
Args:
width: sizes along width dimension
height: sizes along height dimension
"""
# TODO: check width and height statements
super().__init__()
if not isinstance(width[0], Sequence):
width = [(w,) for w in width]
if not isinstance(height[0], Sequence):
height = [(h,) for h in height]
self.width = width
self.height = height
assert len(self.width) == len(self.height)
if kwargs:
logger.info(f"Discarding anchor generator kwargs {kwargs}")
def set_cell_anchors(self, dtype: torch.dtype,
device: Union[torch.device, str] = "cpu") -> None:
"""
Compute anchors for all pairs of scales and ratios and
save them inside :param:`cell_anchors`
if they were not computed before
Args:
dtype (torch.dtype): data type of anchors
device (Union[torch.device, str]): target device of anchors
Returns:
None (result is saved into :param:`self.cell_anchors`)
"""
if self.cell_anchors is not None:
return
cell_anchors = [
self.generate_anchors(w, h, dtype, device)
for w, h in zip(self.width, self.height)
]
self.cell_anchors = cell_anchors
@staticmethod
def generate_anchors(width: Tuple[int],
height: Tuple[int],
dtype: torch.dtype = torch.float,
device: Union[torch.device, str] = "cpu",
) -> torch.Tensor:
"""
Generate anchors for given width and height sizes
Args:
width: sizes along width dimension
height: sizes along height dimension
Returns:
Tensor: anchors of shape [n(width) * n(height), dim * 2]
"""
all_sizes = torch.tensor(list(product(width, height)),
dtype=dtype, device=device) / 2
anchors = torch.stack([-all_sizes[:, 0], -all_sizes[:, 1],
all_sizes[:, 0], all_sizes[:, 1]], dim=1)
return anchors
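# Worked example (illustrative): width=(4, 8) and height=(6,) produce the
# cartesian product (4, 6) and (8, 6), i.e. the anchors
# [[-2., -3., 2., 3.], [-4., -3., 4., 3.]].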
def num_anchors_per_location(self) -> List[int]:
"""
Number of anchors per resolution
Returns:
List[int]: number of anchors per positions for each resolution
"""
return [len(w) * len(h) for w, h in zip(self.width, self.height)]
class AnchorGenerator3DS(AnchorGenerator3D):
def __init__(self,
width: Sequence[Union[int, Sequence[int]]],
height: Sequence[Union[int, Sequence[int]]],
depth: Sequence[Union[int, Sequence[int]]],
**kwargs,
):
"""
Helper to generate anchors for different input sizes
Uses a different parametrization of anchors
(if Sequence[int] is provided it is interpreted as one
value per feature map size)
Args:
width: sizes along width dimension
height: sizes along height dimension
depth: sizes along depth dimension
"""
# TODO: check width and height statements
super().__init__()
if not isinstance(width[0], Sequence):
width = [(w,) for w in width]
if not isinstance(height[0], Sequence):
height = [(h,) for h in height]
if not isinstance(depth[0], Sequence):
depth = [(d,) for d in depth]
self.width = width
self.height = height
self.depth = depth
assert len(self.width) == len(self.height) == len(self.depth)
if kwargs:
logger.info(f"Discarding anchor generator kwargs {kwargs}")
def set_cell_anchors(self, dtype: torch.dtype, device: Union[torch.device, str] = "cpu") -> None:
"""
Compute anchors for all pairs of scales and ratios and save them inside :param:`cell_anchors`
if they were not computed before
Args:
dtype (torch.dtype): data type of anchors
device (Union[torch.device, str]): target device of anchors
Returns:
None (result is saved into :param:`self.cell_anchors`)
"""
if self.cell_anchors is not None:
return
cell_anchors = [
self.generate_anchors(w, h, d, dtype, device)
for w, h, d in zip(self.width, self.height, self.depth)
]
self.cell_anchors = cell_anchors
@staticmethod
def generate_anchors(width: Tuple[int],
height: Tuple[int],
depth: Tuple[int],
dtype: torch.dtype = torch.float,
device: Union[torch.device, str] = "cpu") -> torch.Tensor:
"""
Generate anchors for given width, height and depth sizes
Args:
width: sizes along width dimension
height: sizes along height dimension
depth: sizes along depth dimension
Returns:
Tensor: anchors of shape [n(width) * n(height) * n(depth) , dim * 2]
"""
all_sizes = torch.tensor(list(product(width, height, depth)),
dtype=dtype, device=device) / 2
anchors = torch.stack(
[-all_sizes[:, 0], -all_sizes[:, 1], all_sizes[:, 0], all_sizes[:, 1],
-all_sizes[:, 2], all_sizes[:, 2]], dim=1
)
return anchors
def num_anchors_per_location(self) -> List[int]:
"""
Number of anchors per resolution
Returns:
List[int]: number of anchors per positions for each resolution
"""
return [len(w) * len(h) * len(d)
for w, h, d in zip(self.width, self.height, self.depth)]
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Tuple
import torch
def clip_boxes_to_image_(boxes: torch.Tensor, img_shape: Tuple[int]):
"""
Clip boxes to image dimensions inplace
Args:
boxes (Tensor): tensor with boxes [N x (2*dim)] (x_min, y_min, x_max, y_max(, z_min, z_max))
img_shape (Tuple[height, width(, depth)]): size of image
Returns:
Tensor: clipped boxes as tensor
Raises:
ValueError: boxes need to have 4(2D) or 6(3D) components
"""
if boxes.shape[-1] == 4:
return clip_boxes_to_image_2d_(boxes, img_shape)
elif boxes.shape[-1] == 6:
return clip_boxes_to_image_3d_(boxes, img_shape)
else:
raise ValueError(f"Boxes with {boxes.shape[-1]} are not supported.")
def clip_boxes_to_image(boxes: torch.Tensor, img_shape: Tuple[int]):
"""
Clip boxes to image dimensions
Args:
boxes (Tensor): tensor with boxes [N x (2*dim)] (x_min, y_min, x_max, y_max(, z_min, z_max))
img_shape (Tuple[height, width(, depth)]): size of image
Returns:
Tensor: clipped boxes as tensor
Raises:
ValueError: boxes need to have 4(2D) or 6(3D) components
"""
if boxes.shape[-1] == 4:
return clip_boxes_to_image_2d(boxes, img_shape)
elif boxes.shape[-1] == 6:
return clip_boxes_to_image_3d(boxes, img_shape)
else:
raise ValueError(f"Boxes with {boxes.shape[-1]} are not supported.")
def clip_boxes_to_image_2d_(boxes: torch.Tensor, img_shape: Tuple[int, int]):
"""
Clip boxes to image dimensions
Args:
boxes (Tensor): tensor with boxes [N x 4] (x_min, y_min, x_max, y_max)
img_shape (Tuple[x_max, y_max]): size of image
Returns:
Tensor: clipped boxes as tensor
"""
s0, s1 = img_shape
boxes[..., 0::2].clamp_(min=0, max=s0)
boxes[..., 1::2].clamp_(min=0, max=s1)
return boxes
def clip_boxes_to_image_3d_(boxes: torch.Tensor, img_shape: Tuple[int, int, int]):
"""
Clip boxes to image dimensions
Args:
boxes (Tensor): tensor with boxes [N x 6] (x_min, y_min, x_max, y_max, z_min, z_max)
img_shape (Tuple[height, width, depth]): size of image
Returns:
Tensor: clipped boxes as tensor
"""
s0, s1, s2 = img_shape
boxes[..., 0::6].clamp_(min=0, max=s0)
boxes[..., 1::6].clamp_(min=0, max=s1)
boxes[..., 2::6].clamp_(min=0, max=s0)
boxes[..., 3::6].clamp_(min=0, max=s1)
boxes[..., 4::6].clamp_(min=0, max=s2)
boxes[..., 5::6].clamp_(min=0, max=s2)
return boxes
def clip_boxes_to_image_2d(boxes: torch.Tensor, img_shape: Tuple[int, int]):
"""
Clip boxes to image dimensions
Args:
boxes (Tensor): tensor with boxes [N x 4] (x_min, y_min, x_max, y_max)
img_shape (Tuple[x_max, y_max]): size of image
Returns:
Tensor: clipped boxes as tensor
Notes:
Uses float32 internally because clipping of half cpu tensors is not
supported
"""
s0, s1 = img_shape
boxes[..., 0::2] = boxes[..., 0::2].clamp(min=0, max=s0)
boxes[..., 1::2] = boxes[..., 1::2].clamp(min=0, max=s1)
return boxes
def clip_boxes_to_image_3d(boxes: torch.Tensor, img_shape: Tuple[int, int, int]):
"""
Clip boxes to image dimensions
Args:
boxes (Tensor): tensor with boxes [N x 6] (x_min, y_min, x_max, y_max, z_min, z_max)
img_shape (Tuple[height, width, depth]): size of image
Returns:
Tensor: clipped boxes as tensor
Notes:
Uses float32 internally because clipping of half cpu tensors is not
supported
"""
s0, s1, s2 = img_shape
boxes[..., 0::6] = boxes[..., 0::6].clamp(min=0, max=s0)
boxes[..., 1::6] = boxes[..., 1::6].clamp(min=0, max=s1)
boxes[..., 2::6] = boxes[..., 2::6].clamp(min=0, max=s0)
boxes[..., 3::6] = boxes[..., 3::6].clamp(min=0, max=s1)
boxes[..., 4::6] = boxes[..., 4::6].clamp(min=0, max=s2)
boxes[..., 5::6] = boxes[..., 5::6].clamp(min=0, max=s2)
return boxes
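# Worked example (illustrative): a 2D box [-5., 10., 300., 40.] clipped to
# img_shape = (256, 128) becomes [0., 10., 256., 40.], since the x coordinates
# are clamped to [0, 256] and the y coordinates to [0, 128].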
"""
Parts of this code are from torchvision and thus licensed under
BSD 3-Clause License
Copyright (c) Soumith Chintala 2016,
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from __future__ import division
import math
from typing import Sequence, TypeVar
import torch
from torch.jit.annotations import List, Tuple
from torch import Tensor
from torchvision.models.detection._utils import BoxCoder
@torch.jit.script
def encode_boxes(reference_boxes: torch.Tensor,
proposals: torch.Tensor,
weights: torch.Tensor,
) -> torch.Tensor:
"""
Encode a set of proposals with respect to some reference boxes
Args:
reference_boxes: reference boxes (x1, y1, x2, y2, (z1, z2))
proposals: boxes to be encoded (x1, y1, x2, y2, (z1, z2))
weights: weights for dimensions (wx, wy, ww, wh, wz, wd)
"""
# perform some unpacking to make it JIT-fusion friendly
wx = weights[0]
wy = weights[1]
ww = weights[2]
wh = weights[3]
proposals_x1 = proposals[:, 0].unsqueeze(1)
proposals_y1 = proposals[:, 1].unsqueeze(1)
proposals_x2 = proposals[:, 2].unsqueeze(1)
proposals_y2 = proposals[:, 3].unsqueeze(1)
reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)
# implementation starts here
ex_widths = proposals_x2 - proposals_x1
ex_heights = proposals_y2 - proposals_y1
ex_ctr_x = proposals_x1 + 0.5 * ex_widths
ex_ctr_y = proposals_y1 + 0.5 * ex_heights
gt_widths = reference_boxes_x2 - reference_boxes_x1
gt_heights = reference_boxes_y2 - reference_boxes_y1
gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights
targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
targets_dw = ww * torch.log(gt_widths / ex_widths)
targets_dh = wh * torch.log(gt_heights / ex_heights)
if proposals.shape[1] == 6:
wz = weights[4]
wd = weights[5]
proposals_z1 = proposals[:, 4].unsqueeze(1)
proposals_z2 = proposals[:, 5].unsqueeze(1)
ex_depth = proposals_z2 - proposals_z1
ex_ctr_z = proposals_z1 + 0.5 * ex_depth
reference_boxes_z1 = reference_boxes[:, 4].unsqueeze(1)
reference_boxes_z2 = reference_boxes[:, 5].unsqueeze(1)
gt_depth = reference_boxes_z2 - reference_boxes_z1
gt_ctr_z = reference_boxes_z1 + 0.5 * gt_depth
targets_dz = wz * (gt_ctr_z - ex_ctr_z) / ex_depth
targets_dd = wd * torch.log(gt_depth / ex_depth)
targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh,
targets_dz, targets_dd), dim=1)
else:
targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
return targets
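# Worked example (illustrative, 2D, all weights 1): for a proposal
# (0, 0, 10, 10) and a reference box (5, 5, 25, 25), the encoding is
# dx = (15 - 5) / 10 = 1.0, dy = 1.0, dw = log(20 / 10) ≈ 0.693 and
# dh ≈ 0.693, i.e. the regression target (1.0, 1.0, 0.693, 0.693).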
def decode_single(rel_codes: Tensor, boxes: Tensor,
weights: Sequence[float],
bbox_xform_clip: float) -> Tensor:
"""
From a set of original boxes and encoded relative box offsets,
get the decoded boxes.
Args:
rel_codes: encoded boxes [Num_boxes x (dim * 2)] (dx, dy, dw, dh, dz, dd)
boxes: reference boxes (x1, y1, x2, y2, (z1, z2))
"""
# offset is 4 in case of 2d data and 6 in case of 3d
offset = boxes.shape[1]
boxes = boxes.to(rel_codes.dtype)
widths = boxes[:, 2] - boxes[:, 0]
heights = boxes[:, 3] - boxes[:, 1]
ctr_x = boxes[:, 0] + 0.5 * widths
ctr_y = boxes[:, 1] + 0.5 * heights
wx = weights[0]
wy = weights[1]
ww = weights[2]
wh = weights[3]
dx = rel_codes[:, 0::offset] / wx
dy = rel_codes[:, 1::offset] / wy
dw = rel_codes[:, 2::offset] / ww
dh = rel_codes[:, 3::offset] / wh
# Prevent sending too large values into torch.exp()
dw = torch.clamp(dw, max=bbox_xform_clip)
dh = torch.clamp(dh, max=bbox_xform_clip)
pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
pred_w = torch.exp(dw) * widths[:, None]
pred_h = torch.exp(dh) * heights[:, None]
pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w
pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h
pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype) * pred_w
pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype) * pred_h
if offset == 6:
depths = boxes[:, 5] - boxes[:, 4]
ctr_z = boxes[:, 4] + 0.5 * depths
wz = weights[4]
wd = weights[5]
dz = rel_codes[:, 4::offset] / wz
dd = rel_codes[:, 5::offset] / wd
dd = torch.clamp(dd, max=bbox_xform_clip)
pred_ctr_z = dz * depths[:, None] + ctr_z[:, None]
pred_z = torch.exp(dd) * depths[:, None]
pred_boxes5 = pred_ctr_z - torch.tensor(0.5, dtype=pred_ctr_z.dtype) * pred_z
pred_boxes6 = pred_ctr_z + torch.tensor(0.5, dtype=pred_ctr_z.dtype) * pred_z
pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4,
pred_boxes5, pred_boxes6), dim=2).flatten(1)
else:
pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4),
dim=2).flatten(1)
return pred_boxes
class BoxCoderND(BoxCoder):
"""
This class encodes and decodes a set of bounding boxes into
the representation used for training the regressors.
Compatible with 2d and 3d
"""
def encode(self,
reference_boxes: List[Tensor],
proposals: List[Tensor],
) -> Tuple[Tensor]:
"""
Encode a set of proposals with respect to some reference boxes
Args:
reference_boxes: reference boxes for each image.
(x1, y1, x2, y2, (z1, z2))
proposals: proposals for each image
(x1, y1, x2, y2, (z1, z2))
Returns:
Tuple[Tensor]: regression targets for each image
"""
# filter for images which have a foreground class
filter_min_one_gt = [rb.numel() > 0 for rb in reference_boxes]
filtered_ref_boxes = [
rb for idx, rb in enumerate(reference_boxes) if filter_min_one_gt[idx]]
filtered_proposals = [
pr for idx, pr in enumerate(proposals) if filter_min_one_gt[idx]]
if any(filter_min_one_gt):
filtered_encoded = super().encode(filtered_ref_boxes, filtered_proposals)
# fill image with no ground truth
idx_enc = 0
encoded = []
for img_idx, gt_present in enumerate(filter_min_one_gt):
if gt_present:
encoded.append(filtered_encoded[idx_enc])
idx_enc += 1
else:
# fill with zeros because they do not contribute to the
# regression loss anyway (all anchors are labeled as background)
encoded.append(torch.zeros_like(proposals[img_idx]))
return encoded
def encode_single(self,
reference_boxes: Tensor,
proposals: Tensor,
) -> Tensor:
"""
Encode a set of proposals with respect to some reference boxes
Arguments:
reference_boxes: reference boxes (x1, y1, x2, y2, (z1, z2))
proposals: boxes to be encoded (x1, y1, x2, y2, (z1, z2))
"""
dtype, device = reference_boxes.dtype, reference_boxes.device
weights = torch.tensor(self.weights, dtype=dtype, device=device)
targets = encode_boxes(reference_boxes, proposals, weights)
return targets
def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
"""
Decode boxes
Args:
rel_codes: relative offsets to reference boxes
(dx, dy, dw, dh, (dz, dd))[N, dim * 2]
boxes: list of reference boxes per image
(x1, y1, x2, y2, (z1, z2))
Returns:
Tensor: decoded boxes
"""
assert isinstance(boxes, (list, tuple))
assert isinstance(rel_codes, torch.Tensor)
boxes_per_image = [b.size(0) for b in boxes]
concat_boxes = torch.cat(boxes, dim=0)
spatial_dims = concat_boxes.shape[1]
box_sum = 0
for val in boxes_per_image:
box_sum += val
pred_boxes = self.decode_single(rel_codes.reshape(box_sum, -1), concat_boxes)
return pred_boxes.reshape(box_sum, spatial_dims)
def decode_single(self, rel_codes: torch.Tensor, boxes: torch.Tensor):
dtype, device = rel_codes.dtype, rel_codes.device
return decode_single(rel_codes, boxes, self.weights, self.bbox_xform_clip)
CoderType = TypeVar('CoderType', bound=BoxCoderND)
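# Minimal usage sketch (illustrative; BoxCoderND reuses torchvision's BoxCoder
# constructor, i.e. it is instantiated with per-dimension weights):
#   coder = BoxCoderND(weights=(1., 1., 1., 1., 1., 1.))
#   targets = coder.encode(reference_boxes, proposals)  # list of per-image tensors
#   boxes = coder.decode(deltas, anchors_per_image)     # names here are placeholders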
"""
Parts of this code are from torchvision and thus licensed under
BSD 3-Clause License
Copyright (c) Soumith Chintala 2016,
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from typing import Sequence, Callable, Tuple, TypeVar
from abc import ABC
import torch
from torch import Tensor
from loguru import logger
from nndet.core.boxes.ops import box_iou, box_center_dist, center_in_boxes
INF = 100 # not really inf, but here it is sufficient
class Matcher(ABC):
BELOW_LOW_THRESHOLD: int = -1
BETWEEN_THRESHOLDS: int = -2
def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou):
"""
Matches boxes and anchors to each other
Args:
similarity_fn: function for similarity computation between
boxes and anchors
"""
self.similarity_fn = similarity_fn
def __call__(self,
boxes: torch.Tensor,
anchors: torch.Tensor,
num_anchors_per_level: Sequence[int],
num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute matches for a single image
Args:
boxes: anchors are matched to these boxes (e.g. ground truth)
[N, dims * 2](x1, y1, x2, y2, (z1, z2))
anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
num_anchors_per_level: number of anchors per feature pyramid level
num_anchors_per_loc: number of anchors per position
Returns:
Tensor: matrix which contains the similarity from each box
to each anchor [N, M]
Tensor: vector which contains the matched box index for all
anchors (if background `BELOW_LOW_THRESHOLD` is used
and if it should be ignored `BETWEEN_THRESHOLDS` is used)
[M]
"""
if boxes.numel() == 0:
# no ground truth
num_anchors = anchors.shape[0]
match_quality_matrix = torch.tensor([]).to(anchors)
matches = torch.empty(num_anchors, dtype=torch.int64).fill_(self.BELOW_LOW_THRESHOLD)
return match_quality_matrix, matches
else:
# at least one ground truth
return self.compute_matches(
boxes=boxes, anchors=anchors,
num_anchors_per_level=num_anchors_per_level,
num_anchors_per_loc=num_anchors_per_loc,
)
def compute_matches(self,
boxes: torch.Tensor,
anchors: torch.Tensor,
num_anchors_per_level: Sequence[int],
num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute matches
Args:
boxes: anchors are matched to these boxes (e.g. ground truth)
[N, dims * 2](x1, y1, x2, y2, (z1, z2))
anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
num_anchors_per_level: number of anchors per feature pyramid level
num_anchors_per_loc: number of anchors per position
Returns:
Tensor: matrix which contains the similarity from each box
to each anchor [N, M]
Tensor: vector which contains the matched box index for all
anchors (if background `BELOW_LOW_THRESHOLD` is used
and if it should be ignored `BETWEEN_THRESHOLDS` is used)
[M]
"""
raise NotImplementedError
class IoUMatcher(Matcher):
def __init__(self,
low_threshold: float,
high_threshold: float,
allow_low_quality_matches: bool,
similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou):
"""
Compute IoU based matching for a single image
Args:
low_threshold: threshold used to assign background values
high_threshold: threshold used to assign foreground values
allow_low_quality_matches: if enabled, ground-truth boxes without a
matched anchor are assigned the anchor with the highest IoU
similarity_fn: function for similarity computation between
boxes and anchors
"""
super().__init__(similarity_fn=similarity_fn)
assert low_threshold <= high_threshold
self.high_threshold = high_threshold
self.low_threshold = low_threshold
self.allow_low_quality_matches = allow_low_quality_matches
def compute_matches(self,
boxes: torch.Tensor,
anchors: torch.Tensor,
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute matches according to given iou thresholds
Adapted from
(https://github.com/pytorch/vision/blob/c7c2085ec686ccc55e1df85736b240b2405d1179/torchvision/models/detection/_utils.py)
Args:
boxes: anchors are matched to these boxes (e.g. ground truth)
[N, dims * 2](x1, y1, x2, y2, (z1, z2))
anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
anchors_per_level: number of anchors per feature pyramid level
anchors_per_loc: number of anchors per position
Returns:
Tensor: matrix which contains the similarity from each box
to each anchor [N, M]
Tensor: vector which contains the matched box index for all
anchors (if background `BELOW_LOW_THRESHOLD` is used
and if it should be ignored `BETWEEN_THRESHOLDS` is used)
[M]
"""
match_quality_matrix = self.similarity_fn(boxes, anchors)
# match_quality_matrix is M (gt) x N (anchors)
# Max over gt elements (dim 0) to find best gt candidate for each anchor
matched_vals, matches = match_quality_matrix.max(dim=0)
# _v, _i = matched_vals.topk(5)
# print(boxes, _v, anchors[_i])
if self.allow_low_quality_matches:
all_matches = matches.clone()
# Assign candidate matches with low quality to negative (unassigned) values
below_low_threshold = matched_vals < self.low_threshold
between_thresholds = (matched_vals >= self.low_threshold) & (
matched_vals < self.high_threshold
)
matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
matches[between_thresholds] = self.BETWEEN_THRESHOLDS
if self.allow_low_quality_matches:
matches = self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
# self._debug_logging(match_quality_matrix, matches, matched_vals,
# below_low_threshold, between_thresholds)
return match_quality_matrix, matches
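# Worked example (illustrative): with low_threshold=0.3 and high_threshold=0.4,
# an anchor whose best IoU is 0.25 is labeled BELOW_LOW_THRESHOLD (background),
# one with 0.35 is labeled BETWEEN_THRESHOLDS (ignored), and one with 0.45 keeps
# the index of its best-matching ground-truth box.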
def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
"""
Find the best matching prediction for each bounding box
regardless of its IoU (this implementation excludes ties!)
Args:
matches: matched anchors to background and in between
all_matches: all matches regardless of IoU
match_quality_matrix: [M,N] tensor of IoUs (GroundTruth x NumAnchors)
"""
# For each gt, find the prediction which has the highest quality
_, best_pred_idx = match_quality_matrix.max(dim=1) # [M]
matches[best_pred_idx] = torch.arange(len(best_pred_idx)).to(matches)
return matches
@staticmethod
def _debug_logging(match_quality_matrix, matches, matched_vals,
below_low_threshold, between_thresholds):
logger.info("########## Matcher ##############")
logger.info(f"Max IoU: {match_quality_matrix.max()}")
logger.info(f"Foreground IoUs: {matched_vals[matches > -1]}")
logger.info(f"Num GT: {match_quality_matrix.shape[0]}")
match_bet_min = matched_vals[between_thresholds].min() if \
matched_vals[between_thresholds].nelement() > 0 else None
match_bet_max = matched_vals[between_thresholds].max() if \
matched_vals[between_thresholds].nelement() > 0 else None
logger.info(f"Inbetween IoU ranging from {match_bet_min} to {match_bet_max}")
logger.info(f"Max background IoU: {matched_vals[below_low_threshold].max()}")
logger.info("#################################")
class ATSSMatcher(Matcher):
def __init__(self,
num_candidates: int,
similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,
center_in_gt: bool = True,
):
"""
Compute matching based on ATSS
https://arxiv.org/abs/1912.02424
`Bridging the Gap Between Anchor-based and Anchor-free Detection
via Adaptive Training Sample Selection`
Args:
num_candidates: number of positions to select candidates from
similarity_fn: function for similarity computation between
boxes and anchors
center_in_gt: If disabled, matched anchor center points do not need
to lie within the ground truth box.
"""
super().__init__(similarity_fn=similarity_fn)
self.num_candidates = num_candidates
self.min_dist = 0.01
self.center_in_gt = center_in_gt
logger.info(f"Running ATSS Matching with num_candidates={self.num_candidates} "
f"and center_in_gt {self.center_in_gt}.")
def compute_matches(self,
boxes: torch.Tensor,
anchors: torch.Tensor,
num_anchors_per_level: Sequence[int],
num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute matches according to ATTS for a single image
Adapted from
(https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss/loss.py#L180-L184)
Args:
boxes: anchors are matched to these boxes (e.g. ground truth)
[N, dims * 2](x1, y1, x2, y2, (z1, z2))
anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
num_anchors_per_level: number of anchors per feature pyramid level
num_anchors_per_loc: number of anchors per position
Returns:
Tensor: matrix which contains the similarity from each box
to each anchor [N, M]
Tensor: vector which contains the matched box index for all
anchors (if background `BELOW_LOW_THRESHOLD` is used
and if it should be ignored `BETWEEN_THRESHOLDS` is used)
[M]
"""
num_gt = boxes.shape[0]
num_anchors = anchors.shape[0]
distances, boxes_center, anchors_center = box_center_dist(boxes, anchors) # num_boxes x anchors
# select candidates based on center distance
candidate_idx = []
start_idx = 0
for level, apl in enumerate(num_anchors_per_level):
end_idx = start_idx + apl
topk = min(self.num_candidates * num_anchors_per_loc, apl)
_, idx = distances[:, start_idx: end_idx].topk(topk, dim=1, largest=False)
# idx shape [num_boxes x topk]
candidate_idx.append(idx + start_idx)
start_idx = end_idx
# [num_boxes x num_candidates] (index of candidate anchors)
candidate_idx = torch.cat(candidate_idx, dim=1)
match_quality_matrix = self.similarity_fn(boxes, anchors) # [num_boxes x anchors]
candidate_ious = match_quality_matrix.gather(1, candidate_idx) # [num_boxes, n_candidates]
# compute adaptive iou threshold
iou_mean_per_gt = candidate_ious.mean(dim=1) # [num_boxes]
iou_std_per_gt = candidate_ious.std(dim=1) # [num_boxes]
iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt # [num_boxes]
is_pos = candidate_ious >= iou_thresh_per_gt[:, None] # [num_boxes x n_candidates]
if self.center_in_gt: # can discard all candidates in case of very small objects :/
# center point of selected anchors needs to lie within the ground truth
boxes_idx = torch.arange(num_gt, device=boxes.device, dtype=torch.long)[:, None]\
.expand_as(candidate_idx).contiguous() # [num_boxes x n_candidates]
is_in_gt = center_in_boxes(
anchors_center[candidate_idx.view(-1)], boxes[boxes_idx.view(-1)], eps=self.min_dist)
is_pos = is_pos & is_in_gt.view_as(is_pos) # [num_boxes x n_candidates]
# in case an anchor is assigned to multiple boxes, use the box with the highest IoU
# TODO: think about a better way to do this
for ng in range(num_gt):
candidate_idx[ng, :] += ng * num_anchors
ious_inf = torch.full_like(match_quality_matrix, -INF).view(-1)
index = candidate_idx.view(-1)[is_pos.view(-1)]
ious_inf[index] = match_quality_matrix.view(-1)[index]
ious_inf = ious_inf.view_as(match_quality_matrix)
matched_vals, matches = ious_inf.max(dim=0)
matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD
# print(f"Num matches {(matches >= 0).sum()}, Adapt IoU {iou_thresh_per_gt}")
return match_quality_matrix, matches
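# Illustrative sketch (not part of the original implementation): a minimal,
# hypothetical usage of ATSSMatcher.compute_matches on a single 2D image with
# one ground truth box and a single pyramid level. All values are made up.
def _example_atss_matching():  # pragma: no cover
    import torch
    matcher = ATSSMatcher(num_candidates=4, center_in_gt=False)
    gt_boxes = torch.tensor([[10., 10., 30., 30.]])           # [1, 4]
    # 16 anchors of size 20 x 20 on a coarse grid (x1, y1, x2, y2)
    anchors = torch.stack([
        torch.tensor([x, y, x + 20., y + 20.])
        for x in (0., 10., 20., 30.) for y in (0., 10., 20., 30.)
    ])                                                         # [16, 4]
    match_quality, matches = matcher.compute_matches(
        gt_boxes, anchors,
        num_anchors_per_level=[16], num_anchors_per_loc=1,
    )
    # matches[i] >= 0 marks anchor i as matched to gt box matches[i];
    # BELOW_LOW_THRESHOLD marks anchors assigned to the background.
    return match_quality, matches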
MatcherType = TypeVar('MatcherType', bound=Matcher)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from torch import Tensor
from torch.cuda.amp import autocast
from torchvision.ops.boxes import nms as nms_2d
from nndet._C import nms as nms_gpu
from nndet.core.boxes.ops import box_iou
def nms_cpu(boxes, scores, thresh):
"""
Performs non-maximum suppression for 3d boxes on cpu
Args:
boxes (Tensor): tensor with boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
scores (Tensor): score for each box [N]
thresh (float): IoU threshold above which overlapping boxes are discarded
Returns:
keep (Tensor): int64 tensor with the indices of the elements that have been kept by NMS,
sorted in decreasing order of scores
"""
ious = box_iou(boxes, boxes)
_, _idx = torch.sort(scores, descending=True)
keep = []
while _idx.nelement() > 0:
keep.append(_idx[0])
# get all elements that were not matched and discard all others.
non_matches = torch.where((ious[_idx[0]][_idx] <= thresh))[0]
_idx = _idx[non_matches]
return torch.tensor(keep).to(boxes).long()
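# Illustrative sketch (made-up boxes, assuming the (x1, y1, x2, y2, z1, z2)
# convention used above): greedy NMS keeps the highest scoring box, drops every
# remaining box whose IoU with it exceeds the threshold, and repeats.
def _example_nms_cpu():  # pragma: no cover
    import torch
    boxes = torch.tensor([
        [0., 0., 10., 10., 0., 10.],    # kept (highest score)
        [1., 1., 10., 10., 0., 10.],    # IoU with the first box = 0.81 -> suppressed
        [50., 50., 60., 60., 0., 10.],  # no overlap -> kept
    ])
    scores = torch.tensor([0.9, 0.8, 0.7])
    keep = nms_cpu(boxes, scores, thresh=0.5)
    # keep == tensor([0, 2])
    return keep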
@autocast(enabled=False)
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float):
"""
Performs non-maximum suppression
Args:
boxes (Tensor): tensor with boxes (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
scores (Tensor): score for each box [N]
iou_threshold (float): IoU threshold above which overlapping boxes are discarded
Returns:
keep (Tensor): int64 tensor with the indices of the elements that have been kept by NMS,
sorted in decreasing order of scores
"""
if boxes.shape[1] == 4:
# prefer torchvision in 2d because they have c++ cpu version
nms_fn = nms_2d
else:
if boxes.is_cuda:
nms_fn = nms_gpu
else:
nms_fn = nms_cpu
return nms_fn(boxes.float(), scores.float(), iou_threshold)
def batched_nms(boxes: Tensor, scores: Tensor, idxs: Tensor, iou_threshold: float):
"""
Performs non-maximum suppression in a batched fashion.
Each index value corresponds to a category, and NMS
will not be applied between elements of different categories.
Args:
boxes (Tensor): boxes where NMS will be performed. (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
scores (Tensor): scores for each one of the boxes [N]
idxs (Tensor): indices of the categories for each one of the boxes. [N]
iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold
Returns:
keep (Tensor): int64 tensor with the indices of the elements that have been kept by NMS,
sorted in decreasing order of scores
"""
if boxes.numel() == 0:
return torch.empty((0,), dtype=torch.int64, device=boxes.device)
# strategy: in order to perform NMS independently per class.
# we add an offset to all the boxes. The offset is dependent
# only on the class idx, and is large enough so that boxes
# from different classes do not overlap
max_coordinate = boxes.max()
offsets = idxs.to(boxes) * (max_coordinate + 1)
boxes_for_nms = boxes + offsets[:, None]
return nms(boxes_for_nms, scores, iou_threshold)
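# Illustrative sketch of the class-offset trick used above (made-up values): two
# boxes at the same location but with different class indices receive different
# offsets and are therefore never suppressed against each other.
def _example_batched_nms():  # pragma: no cover
    import torch
    boxes = torch.tensor([
        [0., 0., 10., 10.],
        [0., 0., 10., 10.],   # identical box, but a different class
    ])
    scores = torch.tensor([0.9, 0.8])
    idxs = torch.tensor([0, 1])
    keep = batched_nms(boxes, scores, idxs, iou_threshold=0.5)
    # keep == tensor([0, 1]) because the offset moves the second box away
    return keep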
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from torch import Tensor
from numpy import ndarray
from typing import Union, Sequence, Tuple
from torch.cuda.amp import autocast
def box_area_3d(boxes: Tensor) -> Tensor:
"""
Computes the area of a set of bounding boxes, which are specified by its
(x1, y1, x2, y2, z1, z2) coordinates.
Arguments:
boxes (Union[Tensor, ndarray]): boxes for which the area will be computed. They
are expected to be in (x1, y1, x2, y2, z1, z2) format. [N, 6]
Returns:
area (Union[Tensor, ndarray]): area for each box [N]
"""
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4])
def box_area_2d(boxes: Tensor) -> Tensor:
"""
Computes the area of a set of bounding boxes, which are specified by its
(x1, y1, x2, y2) coordinates.
Arguments:
boxes (Union[Tensor, ndarray]): boxes for which the area will be computed. They
are expected to be in (x1, y1, x2, y2) format. [N, 4]
Returns:
area (Union[Tensor, ndarray]): area for each box [N]
"""
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
def box_area(boxes: Union[Tensor, ndarray]) -> Union[Tensor, ndarray]:
"""
Computes the area of a set of bounding boxes
Args:
boxes (Union[Tensor, ndarray]): boxes of shape; (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
Returns:
Union[Tensor, ndarray]: area of boxes
See Also:
:func:`box_area_3d`, :func:`torchvision.ops.boxes.box_area`
"""
if boxes.shape[-1] == 4:
return box_area_2d(boxes)
else:
return box_area_3d(boxes)
@autocast(enabled=False)
def box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tensor:
"""
Return intersection-over-union (Jaccard index) of boxes.
(Works for Tensors and Numpy Arrays)
Arguments:
boxes1: boxes; (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
boxes2: boxes; (x1, y1, x2, y2, (z1, z2))[M, dim * 2]
eps: optional small constant for numerical stability
Returns:
iou (Tensor): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2; [N, M]
See Also:
:func:`box_iou_3d`, :func:`torchvision.ops.boxes.box_iou`
Notes:
Need to compute IoU in float32 (autocast=False) because the
volume/area can be too large
"""
# TODO: think about adding additional assert statements to check coordinates x1 <= x2, y1 <= y2, z1 <= z2
if boxes1.numel() == 0 or boxes2.numel() == 0:
return torch.tensor([]).to(boxes1)
if boxes1.shape[-1] == 4:
return box_iou_union_2d(boxes1.float(), boxes2.float(), eps=eps)[0]
else:
return box_iou_union_3d(boxes1.float(), boxes2.float(), eps=eps)[0]
@autocast(enabled=False)
def generalized_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tensor:
"""
Generalized box iou
Arguments:
boxes1: boxes; (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
boxes2: boxes; (x1, y1, x2, y2, (z1, z2))[M, dim * 2]
eps: optional small constant for numerical stability
Returns:
Tensor: the NxM matrix containing the pairwise
generalized IoU values for every element in boxes1 and boxes2; [N, M]
Notes:
Need to compute IoU in float32 (autocast=False) because the
volume/area can be too large
"""
if boxes1.nelement() == 0 or boxes2.nelement() == 0:
return torch.tensor([]).to(boxes1)
if boxes1.shape[-1] == 4:
return generalized_box_iou_2d(boxes1.float(), boxes2.float(), eps=eps)
else:
return generalized_box_iou_3d(boxes1.float(), boxes2.float(), eps=eps)
def box_iou_union_3d(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tuple[Tensor, Tensor]:
"""
Return intersection-over-union (Jaccard index) and union of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2, z1, z2) format.
Args:
boxes1: set of boxes (x1, y1, x2, y2, z1, z2)[N, 6]
boxes2: set of boxes (x1, y1, x2, y2, z1, z2)[M, 6]
eps: optional small constant for numerical stability
Returns:
Tensor[N, M]: the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
Tensor[N, M]: the NxM matrix containing the pairwise union
values
"""
vol1 = box_area_3d(boxes1)
vol2 = box_area_3d(boxes2)
x1 = torch.max(boxes1[:, None, 0], boxes2[:, 0]) # [N, M]
y1 = torch.max(boxes1[:, None, 1], boxes2[:, 1]) # [N, M]
x2 = torch.min(boxes1[:, None, 2], boxes2[:, 2]) # [N, M]
y2 = torch.min(boxes1[:, None, 3], boxes2[:, 3]) # [N, M]
z1 = torch.max(boxes1[:, None, 4], boxes2[:, 4]) # [N, M]
z2 = torch.min(boxes1[:, None, 5], boxes2[:, 5]) # [N, M]
inter = ((x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0) * (z2 - z1).clamp(min=0)) + eps # [N, M]
union = (vol1[:, None] + vol2 - inter)
return inter / union, union
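# Worked example (illustrative, made-up boxes): two boxes with volume 8 each and
# an intersection of 1 x 1 x 2 = 2 give IoU = 2 / (8 + 8 - 2) = 2 / 14 ≈ 0.143.
def _example_box_iou_union_3d():  # pragma: no cover
    import torch
    boxes1 = torch.tensor([[0., 0., 2., 2., 0., 2.]])
    boxes2 = torch.tensor([[1., 1., 3., 3., 0., 2.]])
    iou, union = box_iou_union_3d(boxes1, boxes2)
    # iou ≈ tensor([[0.1429]]), union == tensor([[14.]])
    return iou, union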
def generalized_box_iou_3d(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tensor:
"""
Computes the generalized box iou between given bounding boxes
Args:
boxes1: set of boxes (x1, y1, x2, y2, z1, z2)[N, 6]
boxes2: set of boxes (x1, y1, x2, y2, z1, z2)[M, 6]
eps: optional small constant for numerical stability
Returns:
Tensor[N, M]: the NxM matrix containing the pairwise
generalized IoU values for every element in boxes1 and boxes2
"""
iou, union = box_iou_union_3d(boxes1, boxes2)
x1 = torch.min(boxes1[:, None, 0], boxes2[:, 0]) # [N, M]
y1 = torch.min(boxes1[:, None, 1], boxes2[:, 1]) # [N, M]
x2 = torch.max(boxes1[:, None, 2], boxes2[:, 2]) # [N, M]
y2 = torch.max(boxes1[:, None, 3], boxes2[:, 3]) # [N, M]
z1 = torch.min(boxes1[:, None, 4], boxes2[:, 4]) # [N, M]
z2 = torch.max(boxes1[:, None, 5], boxes2[:, 5]) # [N, M]
vol = ((x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0) * (z2 - z1).clamp(min=0)) + eps # [N, M]
return iou - (vol - union) / vol
def box_iou_union_2d(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tuple[Tensor, Tensor]:
"""
Return intersection-over-union (Jaccard index) and union of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
boxes1: set of boxes (x1, y1, x2, y2)[N, 4]
boxes2: set of boxes (x1, y1, x2, y2)[M, 4]
eps: optional small constant for numerical stability
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
union (Tensor[N, M]): the NxM matrix containing the pairwise union
values
"""
area1 = box_area(boxes1)
area2 = box_area(boxes2)
x1 = torch.max(boxes1[:, None, 0], boxes2[:, 0]) # [N, M]
y1 = torch.max(boxes1[:, None, 1], boxes2[:, 1]) # [N, M]
x2 = torch.min(boxes1[:, None, 2], boxes2[:, 2]) # [N, M]
y2 = torch.min(boxes1[:, None, 3], boxes2[:, 3]) # [N, M]
inter = ((x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)) + eps # [N, M]
union = (area1[:, None] + area2 - inter)
return inter / union, union
def generalized_box_iou_2d(boxes1: Tensor, boxes2: Tensor, eps: float = 0) -> Tensor:
"""
Computes the generalized box iou between given bounding boxes
Args:
boxes1: set of boxes (x1, y1, x2, y2)[N, 4]
boxes2: set of boxes (x1, y1, x2, y2)[M, 4]
eps: optional small constant for numerical stability
Returns:
Tensor[N, M]: the NxM matrix containing the pairwise
generalized IoU values for every element in boxes1 and boxes2
"""
iou, union = box_iou_union_2d(boxes1, boxes2)
x1 = torch.min(boxes1[:, None, 0], boxes2[:, 0]) # [N, M]
y1 = torch.min(boxes1[:, None, 1], boxes2[:, 1]) # [N, M]
x2 = torch.max(boxes1[:, None, 2], boxes2[:, 2]) # [N, M]
y2 = torch.max(boxes1[:, None, 3], boxes2[:, 3]) # [N, M]
area = ((x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)) + eps # [N, M]
return iou - (area - union) / area
def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
"""
Remove boxes with at least one side smaller than min_size.
Arguments:
boxes (Tensor): boxes (x1, y1, x2, y2, (z1, z2)) [N, dim * 2]
min_size (float): minimum size
Returns:
keep (Tensor): indices of the boxes that have both sides
larger than min_size [N]
"""
if boxes.shape[1] == 4:
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
keep = (ws >= min_size) & (hs >= min_size)
else:
ws, hs, ds = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1], boxes[:, 5] - boxes[:, 4]
keep = (ws >= min_size) & (hs >= min_size) & (ds >= min_size)
keep = torch.where(keep)[0]
return keep
def box_center_dist(boxes1: Tensor, boxes2: Tensor, euclidean: bool = True) -> \
Tuple[Tensor, Tensor, Tensor]:
"""
Distance of center points between two sets of boxes
Arguments:
boxes1: boxes; (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
boxes2: boxes; (x1, y1, x2, y2, (z1, z2))[M, dim * 2]
euclidean: compute the Euclidean distance, otherwise the L1
distance is used
Returns:
Tensor: the NxM matrix containing the pairwise
distances for every element in boxes1 and boxes2; [N, M]
Tensor: center points of boxes1
Tensor: center points of boxes2
"""
center1 = box_center(boxes1) # [N, dims]
center2 = box_center(boxes2) # [M, dims]
if euclidean:
dists = (center1[:, None] - center2[None]).pow(2).sum(-1).sqrt()
else:
# before sum: [N, M, dims]
dists = (center1[:, None] - center2[None]).sum(-1)
return dists, center1, center2
def center_in_boxes(center: Tensor, boxes: Tensor, eps: float = 0.01) -> Tensor:
"""
Checks which center points are within boxes
Args:
center: center points [N, dims]
boxes: boxes [N, dims * 2]
eps: minimum distance to the border of the boxes
Returns:
Tensor: boolean array indicating which center points are within
the boxes [N]
"""
axes = []
axes.append(center[:, 0] - boxes[:, 0])
axes.append(center[:, 1] - boxes[:, 1])
axes.append(boxes[:, 2] - center[:, 0])
axes.append(boxes[:, 3] - center[:, 1])
if center.shape[1] == 3:
axes.append(center[:, 2] - boxes[:, 4])
axes.append(boxes[:, 5] - center[:, 2])
return torch.stack(axes, dim=1).min(dim=1)[0] > eps
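# Illustrative sketch (made-up values): a center well inside its box passes,
# while a center closer than `eps` to the box border does not.
def _example_center_in_boxes():  # pragma: no cover
    import torch
    center = torch.tensor([[5., 5.], [0.005, 5.]])
    boxes = torch.tensor([[0., 0., 10., 10.], [0., 0., 10., 10.]])
    inside = center_in_boxes(center, boxes, eps=0.01)
    # inside == tensor([True, False])
    return inside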
def box_center(boxes: Tensor) -> Tensor:
"""
Compute center point of boxes
Args:
boxes: bounding boxes (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]
Returns:
Tensor: center points [N, dims]
"""
centers = [(boxes[:, 2] + boxes[:, 0]) / 2., (boxes[:, 3] + boxes[:, 1]) / 2.]
if boxes.shape[1] == 6:
centers.append((boxes[:, 5] + boxes[:, 4]) / 2.)
return torch.stack(centers, dim=1)
def permute_boxes(boxes: Union[Tensor, ndarray],
dims: Sequence[int] = None) -> Union[Tensor, ndarray]:
"""
Change ordering of axis of boxes
Args:
boxes: boxes [N, dims * 2](x1, y1, x2, y2(, z1, z2))
dims: the desired ordering of dimensions; By default the dimensions
are reversed
Returns:
Tensor: boxes with permuted axes [N, dims * 2]
"""
if dims is None:
dims = list(range(boxes.shape[1] // 2))[::-1]
if 2 * len(dims) != boxes.shape[1]:
raise TypeError(f"Need same number of dimensions, found dims {dims} "
f"but boxes with shape {boxes.shape}")
indexing = [[0, 2], [1, 3]]
if boxes.shape[1] == 6:
indexing.append([4, 5])
new_axis = [indexing[dims[0]][0], indexing[dims[1]][0],
indexing[dims[0]][1], indexing[dims[1]][1]]
for d in dims[2:]:
new_axis.extend(indexing[d])
return boxes[:, new_axis]
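# Illustrative sketch: by default the spatial axes are reversed, so for a 2D box
# the x and y coordinates swap places.
def _example_permute_boxes():  # pragma: no cover
    import torch
    boxes = torch.tensor([[1., 2., 3., 4.]])   # (x1, y1, x2, y2)
    permuted = permute_boxes(boxes)
    # permuted == tensor([[2., 1., 4., 3.]])   # (y1, x1, y2, x2)
    return permuted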
def expand_to_boxes(data: Union[Tensor, ndarray]) -> Union[Tensor, ndarray]:
"""
Expand x,y,z data to box format
Args:
data (Tensor): data to expand (N, dim)[:, (x, y, [z])]
Returns:
Tensor: expanded tensors
"""
idx = [0, 1, 0, 1]
if (len(data.shape) == 1 and data.shape[0] == 3) or (len(data.shape) == 2 and data.shape[1] == 3):
idx.extend((2, 2))
if len(data.shape) == 1:
data = data[None]
return data[:, idx]
def box_size(boxes: Tensor) -> Tensor:
"""
Compute length of boxes along all dimensions
Args:
boxes (Tensor): boxes (x1, y1, x2, y2, z1, z2)[N, dim * 2]
Returns:
Tensor: size along axis (x, y, (z))[N, dim]
"""
dists = []
dists.append(boxes[:, 2] - boxes[:, 0])
dists.append(boxes[:, 3] - boxes[:, 1])
if boxes.shape[1] // 2 == 3:
dists.append(boxes[:, 5] - boxes[:, 4])
return torch.stack(dists, axis=1)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from numpy import ndarray
def box_area_np(boxes: ndarray) -> ndarray:
"""
See Also:
:func:`nndet.core.boxes.ops.box_area`
"""
if boxes.shape[-1] == 4:
return box_area_2d_np(boxes)
else:
return box_area_3d_np(boxes)
def box_area_3d_np(boxes: np.ndarray) -> np.ndarray:
"""
See Also:
`nndet.core.boxes.ops.box_area_3d`
"""
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4])
def box_area_2d_np(boxes: np.ndarray) -> np.ndarray:
"""
See Also:
`nndet.core.boxes.ops.box_area_2d`
"""
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
def box_iou_np(boxes1: ndarray, boxes2: ndarray) -> ndarray:
"""
Return intersection-over-union (Jaccard index) of boxes.
(Works for Numpy Arrays)
Arguments:
boxes1 (ndarray): boxes; (x1, y1, x2, y2, (z1, z2))[N, dim * 2]
boxes2 (ndarray): boxes; (x1, y1, x2, y2, (z1, z2))[M, dim * 2]
Returns:
iou (ndarray): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2; [N, M]
See Also:
:func:`box_iou_3d`, :func:`torchvision.ops.boxes.box_iou`
"""
# TODO: think about adding additional assert statements to check coordinates x1 <= x2, y1 <= y2, z1 <= z2
if boxes1.shape[-1] == 4:
return box_iou_2d_np(boxes1, boxes2)
else:
return box_iou_3d_np(boxes1, boxes2)
def box_iou_2d_np(boxes1: ndarray, boxes2: ndarray) -> ndarray:
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
boxes1 (ndarray): set of boxes (x1, y1, x2, y2)[N, 4]
boxes2 (ndarray): set of boxes (x1, y1, x2, y2)[M, 4]
Returns:
iou (ndarray[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
"""
area1 = box_area_2d_np(boxes1)
area2 = box_area_2d_np(boxes2)
x1 = np.maximum(boxes1[:, None, 0], boxes2[:, 0]) # [N, M]
y1 = np.maximum(boxes1[:, None, 1], boxes2[:, 1]) # [N, M]
x2 = np.minimum(boxes1[:, None, 2], boxes2[:, 2]) # [N, M]
y2 = np.minimum(boxes1[:, None, 3], boxes2[:, 3]) # [N, M]
inter = np.clip((x2 - x1), a_min=0, a_max=None) * np.clip((y2 - y1), a_min=0, a_max=None) # [N, M]
return inter / (area1[:, None] + area2 - inter)
def box_iou_3d_np(boxes1: ndarray, boxes2: ndarray) -> ndarray:
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2, z1, z2) format.
Arguments:
boxes1 (ndarray): set of boxes (x1, y1, x2, y2, z1, z2)[N, 6]
boxes2 (ndarray): set of boxes (x1, y1, x2, y2, z1, z2)[M, 6]
Returns:
iou (ndarray[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
"""
area1 = box_area_3d_np(boxes1)
area2 = box_area_3d_np(boxes2)
x1 = np.maximum(boxes1[:, None, 0], boxes2[:, 0]) # [N, M]
y1 = np.maximum(boxes1[:, None, 1], boxes2[:, 1]) # [N, M]
x2 = np.minimum(boxes1[:, None, 2], boxes2[:, 2]) # [N, M]
y2 = np.minimum(boxes1[:, None, 3], boxes2[:, 3]) # [N, M]
z1 = np.maximum(boxes1[:, None, 4], boxes2[:, 4]) # [N, M]
z2 = np.minimum(boxes1[:, None, 5], boxes2[:, 5]) # [N, M]
inter = np.clip((x2 - x1), a_min=0, a_max=None) * np.clip((y2 - y1), a_min=0, a_max=None) * \
np.clip((z2 - z1), a_min=0, a_max=None) # [N, M]
return inter / (area1[:, None] + area2 - inter)
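# Illustrative sketch mirroring the torch example above (made-up boxes): the
# same pair of boxes yields IoU = 2 / 14 ≈ 0.143 with the pure NumPy version.
def _example_box_iou_3d_np():  # pragma: no cover
    boxes1 = np.array([[0., 0., 2., 2., 0., 2.]])
    boxes2 = np.array([[1., 1., 3., 3., 0., 2.]])
    return box_iou_3d_np(boxes1, boxes2)  # array([[0.14285714]])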
def box_size_np(boxes: ndarray) -> ndarray:
"""
Compute length of boxes along all dimensions
Args:
boxes (ndarray): boxes (x1, y1, x2, y2, z1, z2)[N, dim * 2]
Returns:
ndarray: size along axis (x, y, (z))[N, dim]
"""
dists = []
dists.append(boxes[:, 2] - boxes[:, 0])
dists.append(boxes[:, 3] - boxes[:, 1])
if boxes.shape[1] // 2 == 3:
dists.append(boxes[:, 5] - boxes[:, 4])
return np.stack(dists, axis=-1)
def box_center_np(boxes: np.ndarray) -> np.ndarray:
"""
Compute center point of boxes
Args:
boxes: bounding boxes (x1, y1, x2, y2, (z1, z2)) [N, dims * 2]
Returns:
ndarray: center points [N, dims]
"""
centers = [(boxes[:, 2] + boxes[:, 0]) / 2., (boxes[:, 3] + boxes[:, 1]) / 2.]
if boxes.shape[1] == 6:
centers.append((boxes[:, 5] + boxes[:, 4]) / 2.)
return np.stack(centers, axis=1)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
from loguru import logger
from abc import ABC
from typing import List
from torch import Tensor
from torchvision.models.detection._utils import BalancedPositiveNegativeSampler
class AbstractSampler(ABC):
def __call__(self, target_labels: List[Tensor], fg_probs: Tensor):
"""
Select positive and negative anchors
Args:
target_labels (List[Tensor]): labels for each anchor per image, List[[A]],
where A is the number of anchors in one image
fg_probs (Tensor): maximum foreground probability per anchor, [R]
where R is the sum of all anchors inside one batch
Returns:
List[Tensor]: binary mask for positive anchors, List[[A]]
List[Tensor]: binary mask for negative anchors, List[[A]]
"""
raise NotImplementedError
class NegativeSampler(BalancedPositiveNegativeSampler, AbstractSampler):
def __call__(self, target_labels: List[Tensor], fg_probs: Tensor):
"""
Randomly sample negatives and positives until batch_size_per_img
is reached
If not enough positive samples are found, it will be padded with
negative samples
"""
return super(NegativeSampler, self).__call__(target_labels)
class HardNegativeSamplerMixin(ABC):
def __init__(self, pool_size: float = 10):
"""
Create a pool from the highest scoring false positives and sample
a defined number of negatives from it
Args:
pool_size (float): hard negatives are sampled from a pool of size:
batch_size_per_image * (1 - positive_fraction) * pool_size
"""
self.pool_size = pool_size
def select_negatives(self, negative: Tensor, num_neg: int,
img_labels: Tensor, img_fg_probs: Tensor):
"""
Select negative anchors
Args:
negative (Tensor): indices of negative anchors [P],
where P is the number of negative anchors
num_neg (int): number of negative anchors to sample
img_labels (Tensor): labels for all anchors in an image [A],
where A is the number of anchors in one image
img_fg_probs (Tensor): maximum foreground probability per anchor [A],
where A is the number of anchors in one image
Returns:
Tensor: binary mask of negative anchors to choose [A],
where A is the number of anchors in one image
"""
pool = int(num_neg * self.pool_size)
pool = min(negative.numel(), pool) # protect against not enough negatives
# select pool of highest scoring false positives
_, negative_idx_pool = img_fg_probs[negative].topk(pool, sorted=True)
negative = negative[negative_idx_pool]
# select negatives from pool
perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
neg_idx_per_image = negative[perm2]
neg_idx_per_image_mask = torch.zeros_like(img_labels, dtype=torch.uint8)
neg_idx_per_image_mask[neg_idx_per_image] = 1
return neg_idx_per_image_mask
class HardNegativeSampler(HardNegativeSamplerMixin):
def __init__(self, batch_size_per_image: int, positive_fraction: float,
min_neg: int = 0, pool_size: float = 10):
"""
Create a pool from the highest scoring false positives and sample
a defined number of negatives from it
Args:
batch_size_per_image (int): number of elements to be selected per image
positive_fraction (float): percentage of positive elements per batch
pool_size (float): hard negatives are sampled from a pool of size:
batch_size_per_image * (1 - positive_fraction) * pool_size
"""
super().__init__(pool_size=pool_size)
self.min_neg = min_neg
self.batch_size_per_image = batch_size_per_image
self.positive_fraction = positive_fraction
def __call__(self, target_labels: List[Tensor], fg_probs: Tensor):
"""
Select hard negatives from the list of anchors per image
Args:
target_labels (List[Tensor]): labels for each anchor per image, List[[A]],
where A is the number of anchors in one image
fg_probs (Tensor): maximum foreground probability per anchor, [R]
where R is the sum of all anchors inside one batch
Returns:
List[Tensor]: binary mask for positive anchors, List[[A]]
List[Tensor]: binary mask for negative anchors, List[[A]]
"""
anchors_per_image = [anchors_in_image.shape[0] for anchors_in_image in target_labels]
fg_probs = fg_probs.split(anchors_per_image, 0)
pos_idx = []
neg_idx = []
for img_labels, img_fg_probs in zip(target_labels, fg_probs):
positive = torch.where(img_labels >= 1)[0]
negative = torch.where(img_labels == 0)[0]
num_pos = self.get_num_pos(positive)
pos_idx_per_image_mask = self.select_positives(
positive, num_pos, img_labels, img_fg_probs)
pos_idx.append(pos_idx_per_image_mask)
num_neg = self.get_num_neg(negative, num_pos)
neg_idx_per_image_mask = self.select_negatives(
negative, num_neg, img_labels, img_fg_probs)
neg_idx.append(neg_idx_per_image_mask)
return pos_idx, neg_idx
def get_num_pos(self, positive: torch.Tensor) -> int:
"""
Number of positive samples to draw
Args:
positive: indices of positive anchors
Returns:
int: number of positive samples
"""
# positive anchor sampling
num_pos = int(self.batch_size_per_image * self.positive_fraction)
# protect against not enough positive examples
num_pos = min(positive.numel(), num_pos)
return num_pos
def get_num_neg(self, negative: torch.Tensor, num_pos: int) -> int:
"""
Sample enough negatives to fill up :param:`self.batch_size_per_image`
Args:
negative: indices of negative anchors
num_pos: number of positive samples to draw
Returns:
int: number of negative samples
"""
# always assume at least one pos anchor was sampled
num_neg = int(max(1, num_pos) * abs(1 - 1. / float(self.positive_fraction)))
# protect against not enough negative examples and sample at least one neg if possible
num_neg = min(negative.numel(), max(num_neg, self.min_neg))
return num_neg
def select_positives(self, positive: Tensor, num_pos: int,
img_labels: Tensor, img_fg_probs: Tensor):
"""
Select positive anchors
Args:
positive (Tensor): indices of positive anchors [P],
where P is the number of positive anchors
num_pos (int): number of positive anchors to sample
img_labels (Tensor): labels for all anchors in an image [A],
where A is the number of anchors in one image
img_fg_probs (Tensor): maximum foreground probability per anchor [A],
where A is the number of anchors in one image
Returns:
Tensor: binary mask of positive anchors to choose [A],
where A is the number of anchors in one image
"""
perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
pos_idx_per_image = positive[perm1]
pos_idx_per_image_mask = torch.zeros_like(img_labels, dtype=torch.uint8)
pos_idx_per_image_mask[pos_idx_per_image] = 1
return pos_idx_per_image_mask
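# Illustrative usage sketch (made-up labels, probabilities and hyperparameters):
# sample positive and hard negative anchors for a single image. The returned
# binary masks select the anchors that contribute to the classification loss.
def _example_hard_negative_sampler():  # pragma: no cover
    import torch
    sampler = HardNegativeSampler(
        batch_size_per_image=32, positive_fraction=0.25, min_neg=1, pool_size=10)
    target_labels = [torch.randint(0, 2, (100,))]   # anchor labels of one image
    fg_probs = torch.rand(100)                      # max foreground prob per anchor
    pos_masks, neg_masks = sampler(target_labels, fg_probs)
    # pos_masks[0] / neg_masks[0] are binary masks over the 100 anchors
    return pos_masks, neg_masks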
class HardNegativeSamplerBatched(HardNegativeSampler):
"""
Samples negatives and positives on a per batch basis
(default sampler only does this on a per image basis)
Note:
:attr:`batch_size_per_image` is manipulated to sample the correct
number of samples per batch, use :attr:`_batch_size_per_image`
to get the number of anchors per image
"""
def __init__(self, batch_size_per_image: int, positive_fraction: float,
min_neg: int = 0, pool_size: float = 10):
"""
Args:
batch_size_per_image (int): number of elements to be selected per image
positive_fraction (float): percentage of positive elements per batch
pool_size (float): hard negatives are sampled from a pool of size:
batch_size_per_image * (1 - positive_fraction) * pool_size
"""
super().__init__(min_neg=min_neg, batch_size_per_image=batch_size_per_image,
positive_fraction=positive_fraction, pool_size=pool_size)
self._batch_size_per_image = batch_size_per_image
logger.info("Sampling hard negatives on a per batch basis")
def __call__(self, target_labels: List[Tensor], fg_probs: Tensor):
"""
Select hard negatives from the list of anchors per image
Args:
target_labels (List[Tensor]): labels for each anchor per image, List[[A]],
where A is the number of anchors in one image
fg_probs (Tensor): maximum foreground probability per anchor, [R]
where R is the sum of all anchors inside one batch
Returns:
List[Tensor]: binary mask for positive anchors, List[[A]]
List[Tensor]: binary mask for negative anchors, List[[A]]
"""
batch_size = len(target_labels)
self.batch_size_per_image = self._batch_size_per_image * batch_size
target_labels_batch = torch.cat(target_labels, dim=0)
positive = torch.where(target_labels_batch >= 1)[0]
negative = torch.where(target_labels_batch == 0)[0]
num_pos = self.get_num_pos(positive)
pos_idx = self.select_positives(
positive, num_pos, target_labels_batch, fg_probs)
num_neg = self.get_num_neg(negative, num_pos)
neg_idx = self.select_negatives(
negative, num_neg, target_labels_batch, fg_probs)
# Comb Head with sampling concatenates masks after sampling so do not split them here
# anchors_per_image = [anchors_in_image.shape[0] for anchors_in_image in target_labels]
# return pos_idx.split(anchors_per_image, 0), neg_idx.split(anchors_per_image, 0)
return [pos_idx], [neg_idx]
class BalancedHardNegativeSampler(HardNegativeSampler):
def get_num_neg(self, negative: torch.Tensor, num_pos: int) -> int:
"""
Sample same number of negatives as positives but at least one
Args:
negative: indices of negative anchors
num_pos: number of positive samples to draw
Returns:
int: number of negative samples
"""
# protect against not enough negative examples and sample at least one neg if possible
num_neg = min(negative.numel(), max(num_pos, 1))
return num_neg
class HardNegativeSamplerFgAll(HardNegativeSamplerMixin):
def __init__(self, negative_ratio: float = 1, pool_size: float = 10):
"""
Use all positive anchors for loss and sample corresponding number
of hard negatives
Args:
negative_ratio (float): ratio of negative to positive sample;
(samples negative_ratio * positive_anchors examples)
pool_size (float): hard negatives are sampled from a pool of size:
batch_size_per_image * (1 - positive_fraction) * pool_size
"""
super().__init__(pool_size=pool_size)
self.negative_ratio = negative_ratio
def __call__(self, target_labels: List[Tensor], fg_probs: Tensor):
"""
Select hard negatives from the list of anchors per image
Args:
target_labels (List[Tensor]): labels for each anchor per image, List[[A]],
where A is the number of anchors in one image
fg_probs (Tensor): maximum foreground probability per anchor, [R]
where R is the sum of all anchors inside one batch
Returns:
List[Tensor]: binary mask for positive anchors, List[[A]]
List[Tensor]: binary mask for negative anchors, List[[A]]
"""
anchors_per_image = [anchors_in_image.shape[0] for anchors_in_image in target_labels]
fg_probs = fg_probs.split(anchors_per_image, 0)
pos_idx = []
neg_idx = []
for img_labels, img_fg_probs in zip(target_labels, fg_probs):
negative = torch.where(img_labels == 0)[0]
# positive anchor sampling
pos_idx_per_image_mask = (img_labels >= 1).to(dtype=torch.uint8)
pos_idx.append(pos_idx_per_image_mask)
num_neg = int(self.negative_ratio * pos_idx_per_image_mask.sum())
# protect against not enough negative examples and sample at least one neg if possible
num_neg = min(negative.numel(), max(num_neg, 1))
neg_idx_per_image_mask = self.select_negatives(
negative, num_neg, img_labels, img_fg_probs)
neg_idx.append(neg_idx_per_image_mask)
return pos_idx, neg_idx
import torch
import torch.nn as nn
from torch import Tensor
from typing import List, Tuple, Dict, Any, Optional, Union
from nndet.arch.abstract import AbstractModel
from nndet.core import boxes as box_utils
from nndet.arch.encoder.abstract import EncoderType
from nndet.arch.decoder.base import DecoderType
from nndet.arch.heads.segmenter import SegmenterType
from nndet.arch.heads.comb import HeadType
from nndet.core.boxes.anchors import AnchorGeneratorType
class BaseRetinaNet(AbstractModel):
def __init__(self,
dim: int,
# modules
encoder: EncoderType,
decoder: DecoderType,
head: HeadType,
num_classes: int,
anchor_generator: AnchorGeneratorType,
matcher: box_utils.MatcherType,
decoder_levels: tuple = (2, 3, 4, 5),
# post-processing
score_thresh: float = None,
detections_per_img: int = 100,
topk_candidates: int = 10000,
remove_small_boxes: float = 1e-2,
nms_thresh: float = 0.9,
# optional
segmenter: Optional[SegmenterType] = None,
):
"""
Base Retina(U)Net
Can be subclasses to add specific configurations to it
Args:
dim: number of spatial dimensions
encoder: encoder module
decoder: decoder module
head: head module
num_classes: number of foreground classes
anchor_generator: generate anchors
matcher: match ground truth boxes and anchors
decoder_levels: decoder levels to use for detection prediction
score_thresh: minimum output probability
detections_per_img: max detections per image
topk_candidates: select only topk candidates for nms computation
remove_small_boxes: remove small bounding boxes
nms_thresh: non maximum suppression threshold
segmenter: segmentation module
"""
super().__init__()
assert dim in [2, 3]
self.dim = dim
self.decoder_levels = decoder_levels
self.encoder = encoder
self.decoder = decoder
self.head = head
self.num_foreground_classes = num_classes
self.anchor_generator = anchor_generator
self.proposal_matcher = matcher
self.score_thresh = score_thresh
self.topk_candidates = topk_candidates
self.detections_per_img = detections_per_img
self.remove_small_boxes = remove_small_boxes
self.nms_thresh = nms_thresh
self.segmenter = segmenter
def train_step(self,
images: Tensor,
targets: dict,
evaluation: bool,
batch_num: int,
) -> Tuple[
Dict[str, torch.Tensor], Optional[Dict]]:
"""
Perform a single training step (forward pass + loss computation)
Args:
images: batch of images
targets: labels for training
`target_boxes` (List[Tensor]): ground truth bounding boxes
(x1, y1, x2, y2, (z1, z2))[X, dim * 2], X= number of ground
truth boxes in image
`target_classes` (List[Tensor]): ground truth class per box
(classes start from 0) [X], X= number of ground truth
boxes in image
`target_seg`(Tensor): segmentation ground truth
(only needed if :param:`segmenter`
was provided in init) (classes start from 1, 0 background)
evaluation (bool): compute final predictions (includes detection
postprocessing)
batch_num (int): batch index inside epoch
Returns:
Dict[str, torch.Tensor]: individual loss components for back
propagation and logging
Optional[Dict]: predictions for metric calculation (None if
evaluation is disabled)
'pred_boxes': List[Tensor]: predicted bounding boxes for each
image List[[R, dim * 2]]
'pred_scores': List[Tensor]: predicted probability for the
class List[[R]]
'pred_labels': List[Tensor]: predicted class List[[R]]
'pred_seg': Tensor: predicted segmentation [N, dims]
"""
# import napari
# with napari.gui_qt():
# viewer = napari.view_image(images.detach().cpu().numpy())
# viewer.add_labels(seg_targets[:, None].detach().cpu().numpy())
target_boxes: List[Tensor] = targets["target_boxes"]
target_classes: List[Tensor] = targets["target_classes"]
target_seg: Tensor = targets["target_seg"]
pred_detection, anchors, pred_seg = self(images)
labels, matched_gt_boxes = self.assign_targets_to_anchors(
anchors, target_boxes, target_classes)
losses = {}
head_losses, pos_idx, neg_idx = self.head.compute_loss(
pred_detection, labels, matched_gt_boxes, anchors)
losses.update(head_losses)
if self.segmenter is not None:
losses.update(self.segmenter.compute_loss(pred_seg, target_seg))
if evaluation:
prediction = self.postprocess_for_inference(
images=images,
pred_detection=pred_detection,
pred_seg=pred_seg,
anchors=anchors,
)
else:
prediction = None
# self.save_matched_anchors(images=images, target_boxes=target_boxes,
# anchors=anchors, pos_idx=pos_idx,
# neg_idx=neg_idx, seg=seg_targets)
return losses, prediction
@torch.no_grad()
def postprocess_for_inference(self,
images: torch.Tensor,
pred_detection: Dict[str, torch.Tensor],
pred_seg: Dict[str, torch.Tensor],
anchors: List[torch.Tensor],
) -> Dict[str, Union[List[Tensor], Tensor]]:
"""
Postprocess predictions for inference
Args:
images: input images
pred_detection: detection predictions
pred_seg: segmentation predictions
anchors: anchors
Returns:
Dict: post processed predictions
'pred_boxes': List[Tensor]: predicted bounding boxes for each
image List[[R, dim * 2]]
'pred_scores': List[Tensor]: predicted probability for
the class List[[R]]
'pred_labels': List[Tensor]: predicted class List[[R]]
'pred_seg': Tensor: predicted segmentation [N, C, dims]
"""
image_shapes = [images.shape[2:]] * images.shape[0]
boxes, probs, labels = self.postprocess_detections(
pred_detection=pred_detection,
anchors=anchors,
image_shapes=image_shapes,
)
prediction = {"pred_boxes": boxes, "pred_scores": probs, "pred_labels": labels}
if self.segmenter is not None:
prediction["pred_seg"] = self.segmenter.postprocess_for_inference(pred_seg)["pred_seg"]
return prediction
def forward(self,
inp: torch.Tensor,
) -> Tuple[Dict[str, torch.Tensor], List[torch.Tensor], Dict[str, torch.Tensor]]:
"""
Compute predicted bounding boxes, scores and segmentations
Args:
inp (torch.Tensor): batch of input images
Returns:
dict: predictions from head. Typically includes:
`box_deltas`(Tensor): bounding box offsets
[Num_Anchors_Batch, (dim * 2)]
`box_logits`(Tensor): classification logits
[Num_Anchors_Batch, (num_classes)]
List[torch.Tensor]: list of anchors (for each image inside the
batch)
dict: segmentation prediction. None if a plain RetinaNet (without a segmenter) is configured.
Typically includes:
`seg_logits`: segmentation logits
"""
features_maps_all = self.decoder(self.encoder(inp))
feature_maps_head = [features_maps_all[i] for i in self.decoder_levels]
pred_detection = self.head(feature_maps_head)
anchors = self.anchor_generator(inp, feature_maps_head)
pred_seg = self.segmenter(features_maps_all) if self.segmenter is not None else None
return pred_detection, anchors, pred_seg
@torch.no_grad()
def assign_targets_to_anchors(self,
anchors: List[torch.Tensor],
target_boxes: List[torch.Tensor],
target_classes: List[torch.Tensor]) -> Tuple[
List[torch.Tensor], List[torch.Tensor]]:
"""
Compute labels and matched ground truth for each anchor
Adapted from torchvision https://github.com/pytorch/vision
Args:
anchors (List[torch.Tensor[float]]): anchors (!)per image(!)
List[[N, dim * 2]], N=number of anchors per image
target_boxes (List[torch.Tensor[float]]): ground truth boxes
(!)per image(!)
List[[X, dim * 2]], X=number of gt per image
target_classes (List[torch.Tensor): ground truth classes
(!)per image(!) (classes start from 0)
List[[X]], X=number of gt per image
Returns:
List[torch.Tensor]: labels ([1, K]: foreground classes, 0: background,
-1: between) List[[N]], N=number of anchors per image
List[torch.Tensor]: matched gt box List[[N, dim * 2]],
N=number of anchors per image
"""
labels = []
matched_gt_boxes = []
for anchors_per_image, gt_boxes, gt_classes in zip(anchors, target_boxes, target_classes):
# indices of ground truth box for each proposal
match_quality_matrix, matched_idxs = self.proposal_matcher(
gt_boxes, anchors_per_image,
num_anchors_per_level=self.anchor_generator.get_num_acnhors_per_level(),
num_anchors_per_loc=self.anchor_generator.num_anchors_per_location()[0])
# get the corresponding GT targets for each proposal
# NB: need to clamp the indices because we can have a single
# GT in the image, and matched_idxs can be -2, which goes
# out of bounds
if match_quality_matrix.numel() > 0:
matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]
# Positive (negative indices can be ignored because they are overwritten in the next step)
# this influences how background class is handled in the input!!!! (here +1 for background)
labels_per_image = gt_classes[matched_idxs.clamp(min=0)].to(dtype=anchors_per_image.dtype)
labels_per_image = labels_per_image + 1
else:
num_anchors_per_image = anchors_per_image.shape[0]
# no ground truth => no matches, all background
matched_gt_boxes_per_image = torch.zeros_like(anchors_per_image)
labels_per_image = torch.zeros(num_anchors_per_image).to(anchors_per_image)
# Background (negative examples)
bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
labels_per_image[bg_indices] = 0.0
# discard indices that are between thresholds
inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
labels_per_image[inds_to_discard] = -1.0
labels.append(labels_per_image)
matched_gt_boxes.append(matched_gt_boxes_per_image)
return labels, matched_gt_boxes
def postprocess_detections(self,
pred_detection: Dict[str, Tensor],
anchors: List[Tensor],
image_shapes: List[Tuple[int]],
) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
"""
Postprocess bounding box deltas and logits to generate final boxes and
scores
Adapted from torchvision https://github.com/pytorch/vision
Args:
pred_detection: detection predictions for loss computation
`box_logits`: classification logits for each anchor [N]
`box_deltas`: offsets for each anchor
(x1, y1, x2, y2, (z1, z2))[N, dim * 2]
anchors: proposals for each image
image_shapes: shape of each image
Returns:
List[Tensor]: final boxes [R, dim * 2]
List[Tensor]: final scores (for final class) [R]
List[Tensor]: final class label [R]
"""
boxes_per_image = [len(boxes_in_image) for boxes_in_image in anchors]
pred_detection = self.head.postprocess_for_inference(pred_detection, anchors)
pred_boxes, pred_probs = pred_detection["pred_boxes"], pred_detection["pred_probs"]
# split boxes and scores per image
pred_boxes = pred_boxes.split(boxes_per_image, 0)
pred_probs = pred_probs.split(boxes_per_image, 0)
all_boxes, all_probs, all_labels = [], [], []
# iterate over images
for boxes, probs, image_shape in zip(pred_boxes, pred_probs, image_shapes):
boxes, probs, labels = self.postprocess_detections_single_image(boxes, probs, image_shape)
all_boxes.append(boxes)
all_probs.append(probs)
all_labels.append(labels)
return all_boxes, all_probs, all_labels
def postprocess_detections_single_image(
self,
boxes: Tensor,
probs: Tensor,
image_shape: Tuple[int],
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Postprocess bounding box deltas and probabilities for a single image
Adapted from torchvision https://github.com/pytorch/vision
Args:
boxes: predicted deltas for proposals [N, dim * 2]
probs: predicted logits for boxes [N, C]
image_shape: shape of image
Returns:
Tensor: final boxes [R, dim * 2]
Tensor: final scores (for final class) [R]
Tensor: final class label [R]
"""
assert boxes.shape[0] == probs.shape[0]
boxes = box_utils.clip_boxes_to_image_(boxes, image_shape)
probs = probs.flatten()
if self.topk_candidates is not None:
num_topk = min(self.topk_candidates, boxes.size(0))
probs, idx = probs.sort(descending=True)
probs, idx = probs[:num_topk], idx[:num_topk]
else:
idx = torch.arange(probs.numel())
if self.score_thresh is not None:
keep_idxs = probs > self.score_thresh
probs, idx = probs[keep_idxs], idx[keep_idxs]
anchor_idxs = idx // self.num_foreground_classes
labels = idx % self.num_foreground_classes
boxes = boxes[anchor_idxs]
if self.remove_small_boxes is not None:
keep = box_utils.remove_small_boxes(boxes, min_size=self.remove_small_boxes)
boxes, probs, labels = boxes[keep], probs[keep], labels[keep]
keep = box_utils.batched_nms(boxes, probs, labels, self.nms_thresh)
if self.detections_per_img is not None:
keep = keep[:self.detections_per_img]
return boxes[keep], probs[keep], labels[keep]
# @torch.no_grad()
# def save_matched_anchors(self, **kwargs):
# logger = get_logger("mllogger")
# logger.save_pickle("anchor_matching",
# to_device(kwargs, device="cpu", detach=True))
@torch.no_grad()
def inference_step(self,
images: Tensor,
**kwargs,
) -> Dict[str, Any]:
"""
Perform inference for a batch of images
Args:
images: batch of input images [N, C, W, H, (D)]
Returns:
Dict:
'pred_boxes': List[Tensor]: predicted bounding boxes for each
image List[[R, dim * 2]]
'pred_scores': List[Tensor]: predicted probability for
the class List[[R]]
'pred_labels': List[Tensor]: predicted class List[[R]]
'pred_seg': Tensor: predicted segmentation [N, C, dims]
"""
pred_detection, anchors, pred_seg = self(images)
prediction = self.postprocess_for_inference(
images=images,
pred_detection=pred_detection,
pred_seg=pred_seg,
anchors=anchors,
)
return prediction
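# Illustrative standalone sketch (made-up numbers) of the index arithmetic used
# in `postprocess_detections_single_image`: probabilities of shape
# [num_anchors, num_fg_classes] are flattened, so a flat index decomposes into
# an anchor index (idx // num_fg_classes) and a class label (idx % num_fg_classes).
def _example_flat_index_decomposition():  # pragma: no cover
    import torch
    num_anchors, num_fg_classes = 5, 3
    probs = torch.rand(num_anchors, num_fg_classes).flatten()
    top_probs, idx = probs.topk(4)
    anchor_idxs = idx // num_fg_classes   # which anchor each score belongs to
    labels = idx % num_fg_classes         # which foreground class it belongs to
    return top_probs, anchor_idxs, labels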
/* adapted from
https://github.com/pytorch/vision/blob/master/torchvision/csrc/nms.h on Nov 15 2019
no cpu support, but could be added with this interface.
*/
//#include "cpu/vision_cpu.h"
#include <torch/types.h>
at::Tensor nms_cuda(const at::Tensor& dets, const at::Tensor& scores, float iou_threshold);
at::Tensor nms(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
if (dets.device().is_cuda()) {
if (dets.numel() == 0) {
//at::cuda::CUDAGuard device_guard(dets.device());
return at::empty({0}, dets.options().dtype(at::kLong));
}
return nms_cuda(dets, scores, iou_threshold);
}
AT_ERROR("Not compiled with CPU support");
//at::Tensor result = nms_cpu(dets, scores, iou_threshold);
//return result;
}
#pragma once
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
i += (blockDim.x * gridDim.x))
/*
NMS implementation in CUDA from pytorch framework
(https://github.com/pytorch/vision/tree/master/torchvision/csrc/cuda on Nov 13 2019)
Adapted for additional 3D capability by G. Ramien, DKFZ Heidelberg
Parts of this code are from torchvision and thus licensed under
BSD 3-Clause License
Copyright (c) Soumith Chintala 2016,
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include "cuda_helpers.h"
#include <iostream>
#include <vector>
int const threadsPerBlock = sizeof(unsigned long long) * 8;
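// threadsPerBlock equals 64 (the number of bits in an unsigned long long), so the
// suppression decisions of one box against one column block of boxes fit into a
// single 64-bit word of dev_mask (one bit per potentially suppressed box).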
template <typename T>
__device__ inline float devIoU(T const* const a, T const* const b) {
// a, b hold box coords as (y1, x1, y2, x2) with y1 < y2 etc.
T bottom = max(a[0], b[0]), top = min(a[2], b[2]);
T left = max(a[1], b[1]), right = min(a[3], b[3]);
T width = max(right - left, (T)0), height = max(top - bottom, (T)0);
T interS = width * height;
T Sa = (a[2] - a[0]) * (a[3] - a[1]);
T Sb = (b[2] - b[0]) * (b[3] - b[1]);
return interS / (Sa + Sb - interS);
}
template <typename T>
__device__ inline float devIoU_3d(T const* const a, T const* const b) {
// a, b hold box coords as (y1, x1, y2, x2, z1, z2) with y1 < y2 etc.
// get coordinates of intersection, calc intersection
T bottom = max(a[0], b[0]), top = min(a[2], b[2]);
T left = max(a[1], b[1]), right = min(a[3], b[3]);
T front = max(a[4], b[4]), back = min(a[5], b[5]);
T width = max(right - left, (T)0), height = max(top - bottom, (T)0);
T depth = max(back - front, (T)0);
T interS = width * height * depth;
// calc separate boxes volumes
T Sa = (a[2] - a[0]) * (a[3] - a[1]) * (a[5] - a[4]);
T Sb = (b[2] - b[0]) * (b[3] - b[1]) * (b[5] - b[4]);
return interS / (Sa + Sb - interS);
}
template <typename T>
__global__ void nms_kernel(const int n_boxes, const float iou_threshold, const T* dev_boxes,
unsigned long long* dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
__shared__ T block_boxes[threadsPerBlock * 4];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 4 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0];
block_boxes[threadIdx.x * 4 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1];
block_boxes[threadIdx.x * 4 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2];
block_boxes[threadIdx.x * 4 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 4;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (devIoU<T>(cur_box, block_boxes + i * 4) > iou_threshold) {
t |= 1ULL << i;
}
}
const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
template <typename T>
__global__ void nms_kernel_3d(const int n_boxes, const float iou_threshold, const T* dev_boxes,
unsigned long long* dev_mask) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
__shared__ T block_boxes[threadsPerBlock * 6];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 6 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 0];
block_boxes[threadIdx.x * 6 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 1];
block_boxes[threadIdx.x * 6 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 2];
block_boxes[threadIdx.x * 6 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 3];
block_boxes[threadIdx.x * 6 + 4] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 4];
block_boxes[threadIdx.x * 6 + 5] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 5];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 6;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (devIoU_3d<T>(cur_box, block_boxes + i * 6) > iou_threshold) {
t |= 1ULL << i;
}
}
const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
at::Tensor nms_cuda(const at::Tensor& dets, const at::Tensor& scores, float iou_threshold) {
/* dets expected as (n_dets, dim) where dim=4 in 2D, dim=6 in 3D */
AT_ASSERTM(dets.type().is_cuda(), "dets must be a CUDA tensor");
AT_ASSERTM(scores.type().is_cuda(), "scores must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(dets.device());
bool is_3d = dets.size(1) == 6;
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto dets_sorted = dets.index_select(0, order_t);
int dets_num = dets.size(0);
const int col_blocks = at::cuda::ATenCeilDiv(dets_num, threadsPerBlock);
at::Tensor mask =
at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (is_3d) {
//std::cout << "performing NMS on 3D boxes in CUDA" << std::endl;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
dets_sorted.type(), "nms_kernel_cuda", [&] {
nms_kernel_3d<scalar_t><<<blocks, threads, 0, stream>>>(
dets_num,
iou_threshold,
dets_sorted.data_ptr<scalar_t>(),
(unsigned long long*)mask.data_ptr<int64_t>());
});
}
else {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
dets_sorted.type(), "nms_kernel_cuda", [&] {
nms_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
dets_num,
iou_threshold,
dets_sorted.data_ptr<scalar_t>(),
(unsigned long long*)mask.data_ptr<int64_t>());
});
}
at::Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host = (unsigned long long*)mask_cpu.data_ptr<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
at::Tensor keep =
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();
int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
AT_CUDA_CHECK(cudaGetLastError());
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
.to(order_t.device(), keep.scalar_type())});
}
#include <torch/extension.h>
#include "cpu/nms.cpp"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nms", &nms, "NMS C++ and/or CUDA");
}