# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Implements the Generalized R-CNN framework
"""

from collections import OrderedDict
from typing import Union
import torch
from torch import nn
import warnings
from torch.jit.annotations import Tuple, List, Dict, Optional
from torch import Tensor


class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.

    Arguments:
        backbone (nn.Module): called on the batched image tensor
            (``images.tensors``); returns either a single Tensor or a
            dict of feature maps. A bare Tensor is wrapped into an
            ``OrderedDict`` under the key ``'0'``.
        rpn (nn.Module): called as ``rpn(images, features, targets)``; returns
            the proposals together with a dict of proposal losses.
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model. It must also provide a ``postprocess`` method, which is
            applied to the detections before they are returned.
    """

    def __init__(self, backbone, rpn, roi_heads, transform):
        super().__init__()
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Union[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Select the value returned by ``forward`` outside of scripting:
        the losses while ``self.training`` is set, the detections otherwise.
        """
        if self.training:
            return losses

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Arguments:
            images (list[Tensor]): images to be processed
            targets (list[Dict[str, Tensor]]): ground-truth annotations for each
                image (optional). Each dict must contain a ``"boxes"`` entry;
                in training mode it is validated to be a Tensor of shape [N, 4].

        Returns:
            result (dict[str, Tensor] or list[dict[str, Tensor]]): the output from the model.
                In eager mode it is the dict of losses (RPN and detector
                losses merged) during training, and the list of
                post-processed detections otherwise.
                When scripting, it is always the tuple ``(losses, detections)``.

        Raises:
            ValueError: if the model is in training mode and ``targets`` is
                None, if a target's boxes are not a Tensor of shape [N, 4],
                or if any box (after the transform) does not have strictly
                positive width and height.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")
        if self.training:
            # Kept as an assert so TorchScript can refine the Optional type.
            assert targets is not None
            for target in targets:
                boxes = target["boxes"]
                if isinstance(boxes, torch.Tensor):
                    if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                        raise ValueError("Expected target boxes to be a tensor "
                                         "of shape [N, 4], got {:}.".format(
                                             boxes.shape))
                else:
                    raise ValueError("Expected target boxes to be of type "
                                     "Tensor, got {:}.".format(type(boxes)))

        # Remember the (height, width) of every input image so that the
        # detections can be mapped back by ``transform.postprocess``.
        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                # A box is degenerate when its last two coordinates are not
                # strictly greater than its first two.
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    raise ValueError("All bounding boxes should have positive height and width."
                                     " Found invalid box {} for target at index {}."
                                     .format(degen_bb, target_idx))

        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            # Normalise a single feature map to the dict form used downstream.
            features = OrderedDict([('0', features)])
        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return (losses, detections)
        else:
            return self.eager_outputs(losses, detections)