# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Implements the Generalized R-CNN framework
"""

from collections import OrderedDict
import torch
from torch import nn
import warnings
from torch.jit.annotations import Tuple, List, Dict, Optional
from torch import Tensor


class GeneralizedRCNN(nn.Module):
    """
    Main class for Generalized R-CNN.

    Chains four sub-modules: ``transform`` -> ``backbone`` -> ``rpn`` -> ``roi_heads``,
    then hands the detections back to ``transform.postprocess``.

    Arguments:
        backbone (nn.Module): called with the batched image tensor (``images.tensors``).
            A bare Tensor result is wrapped into an ``OrderedDict`` under the key ``'0'``;
            any other result is passed on unchanged (presumably a mapping of feature maps).
        rpn (nn.Module): called with the transformed images, the features and the targets;
            returns the proposals and a dict of proposal losses.
        roi_heads (nn.Module): takes the features + the proposals from the RPN and computes
            detections / masks from it. Returns the detections and a dict of detector losses.
        transform (nn.Module): performs the data transformation from the inputs to feed into
            the model. Its result must expose ``tensors`` and ``image_sizes``, and the module
            must provide a ``postprocess(detections, image_sizes, original_image_sizes)`` method.
    """

    def __init__(self, backbone, rpn, roi_heads, transform):
        super(GeneralizedRCNN, self).__init__()
        self.transform = transform
        self.backbone = backbone
        self.rpn = rpn
        self.roi_heads = roi_heads
        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Select the eager-mode (non-scripted) output of the model.

        Arguments:
            losses (Dict[str, Tensor]): merged RPN and detector losses
            detections (List[Dict[str, Tensor]]): post-processed detections

        Returns:
            ``losses`` when the module is in training mode, ``detections`` otherwise.
        """
        # NOTE(review): the declared Tuple return type does not describe the value actually
        # returned here (a dict or a list). It presumably mirrors the tuple returned by the
        # scripting branch of forward() so that both branches agree for the TorchScript
        # compiler -- confirm before changing it.
        if self.training:
            return losses

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Arguments:
            images (list[Tensor]): images to be processed
            targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[Dict[str, Tensor]] or Dict[str, Tensor]): the output from the model.
                During training, it returns a Dict[str, Tensor] which contains the losses.
                During testing, it returns a list[Dict[str, Tensor]] holding the detections,
                with fields like `scores`, `labels` and `mask` (for Mask R-CNN models).
                Under TorchScript scripting it always returns the tuple
                ``(losses, detections)``, which is what the type comment above describes.

        Raises:
            ValueError: if the module is in training mode and ``targets`` is None.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        # Remember the last two dimensions of every input image (presumably height and
        # width) before the transform runs; they are handed to transform.postprocess below.
        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
        for img in images:
            val = img.shape[-2:]
            assert len(val) == 2
            original_image_sizes.append((val[0], val[1]))

        images, targets = self.transform(images, targets)
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            # Normalise a single feature map to the mapping form used downstream.
            features = OrderedDict([('0', features)])

        proposals, proposal_losses = self.rpn(images, features, targets)
        detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        # Detector losses are inserted first; an RPN loss with the same key would override.
        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if torch.jit.is_scripting():
            # Warn only once per module instance.
            if not self._has_warned:
                warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return (losses, detections)
        else:
            return self.eager_outputs(losses, detections)