generalized_rcnn.py 4.48 KB
Newer Older
1
2
3
4
"""
Implements the Generalized R-CNN framework
"""

5
import warnings
6
from collections import OrderedDict
7
8
from typing import Tuple, List, Dict, Optional, Union

9
import torch
10
from torch import nn, Tensor
11

12
13
from ...utils import _log_api_usage_once

14
15
16
17
18

class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.

    Args:
        backbone (nn.Module): called on the batched image tensor produced by ``transform``
            to compute the features. If it returns a single Tensor, that tensor is wrapped
            into an ``OrderedDict`` under the key ``"0"``.
        rpn (nn.Module): called with the transformed images, the features and the targets;
            returns the proposals and a dict of proposal losses.
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model
    """

    def __init__(self, backbone: nn.Module, rpn: nn.Module, roi_heads: nn.Module, transform: nn.Module) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(
        self,
        losses: Dict[str, Tensor],
        detections: List[Dict[str, Tensor]],
    ) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]:
        """
        Select the eager-mode output: the losses while training, the detections otherwise.

        Args:
            losses (Dict[str, Tensor]): merged detector and proposal losses
            detections (List[Dict[str, Tensor]]): post-processed detections

        Returns:
            ``losses`` if ``self.training`` is set, else ``detections``.
        """
        if self.training:
            return losses

        return detections

    def forward(
        self,
        images: List[Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Union[Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]], Dict[str, Tensor], List[Dict[str, Tensor]]]:
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[Dict[str, Tensor]] or Dict[str, Tensor]): the output from the model.
                During training, it returns a Dict[str, Tensor] which contains the losses.
                During testing, it returns a list[Dict[str, Tensor]] of detections that
                contains additional fields like `scores`, `labels` and `mask`
                (for Mask R-CNN models).
                When scripted, it always returns a ``(losses, detections)`` tuple.

        Raises:
            ValueError: if the model is in training mode and ``targets`` is None, if a
                target's ``"boxes"`` entry is not a Tensor of shape [N, 4], or if any
                target box does not have positive height and width.
            AssertionError: if an image has fewer than two dimensions.
        """
        if self.training:
            # Raising inside the ``None`` check also narrows ``targets`` to a list below.
            if targets is None:
                raise ValueError("In training mode, targets should be passed")
            for target in targets:
                boxes = target["boxes"]
                if isinstance(boxes, torch.Tensor):
                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                        raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
                else:
                    raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

        # Remember the pre-transform (H, W) of every image so that the detections can be
        # mapped back to the original image sizes by ``transform.postprocess``.
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                # A box is degenerate when either of its last two coordinates is not
                # strictly greater than the matching one of its first two coordinates.
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    raise ValueError(
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}."
                    )

        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            # Downstream modules receive a mapping of feature maps; wrap a lone tensor.
            features = OrderedDict([("0", features)])
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        else:
            return self.eager_outputs(losses, detections)