_utils.py 1.17 KB
Newer Older
1
from collections import OrderedDict
limm's avatar
limm committed
2
from typing import Dict, Optional
3

limm's avatar
limm committed
4
from torch import nn, Tensor
5
6
from torch.nn import functional as F

limm's avatar
limm committed
7
8
from ...utils import _log_api_usage_once

9
10

class _SimpleSegmentationModel(nn.Module):
    """Chain a feature backbone with a main (and optional auxiliary) prediction head.

    The backbone is called on the input, and its result is indexed with the key
    ``"out"``. When an auxiliary head is configured, the result is also indexed
    with ``"aux"``. Each head's output is bilinearly resized back to the spatial
    size (last two dimensions) of the input.

    Args:
        backbone: Module whose output is indexed by ``"out"`` (and ``"aux"``
            when ``aux_classifier`` is given).
        classifier: Head applied to the ``"out"`` features.
        aux_classifier: Optional head applied to the ``"aux"`` features; when
            ``None`` the ``"aux"`` features are never read.
    """

    # NOTE(review): declared via TorchScript's ``__constants__`` mechanism,
    # presumably so the optional head can be treated as a constant (and the
    # ``is not None`` branch resolved) when scripting — confirm before removing.
    __constants__ = ["aux_classifier"]

    def __init__(self, backbone: nn.Module, classifier: nn.Module, aux_classifier: Optional[nn.Module] = None) -> None:
        super().__init__()
        _log_api_usage_once(self)
        # Registration order (backbone, classifier, aux_classifier) fixes the
        # order of submodules / state_dict keys, so it is kept as-is.
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        """Run the backbone and heads, resizing predictions to the input size.

        Args:
            x: Input tensor. It presumably is a batched image (N, C, H, W),
                since bilinear interpolation is applied; confirm with callers.

        Returns:
            An ``OrderedDict`` with key ``"out"``, plus ``"aux"`` when an
            auxiliary head is set. Each value is resized to ``x.shape[-2:]``.
        """
        # Spatial size of the input; every head's output is upsampled to it.
        target_hw = x.shape[-2:]
        # contract: the backbone returns a dict of tensors keyed by name.
        feature_maps = self.backbone(x)

        predictions = OrderedDict()

        main_logits = self.classifier(feature_maps["out"])
        predictions["out"] = F.interpolate(main_logits, size=target_hw, mode="bilinear", align_corners=False)

        # The "aux" entry is only looked up when an auxiliary head exists.
        if self.aux_classifier is not None:
            aux_logits = self.aux_classifier(feature_maps["aux"])
            predictions["aux"] = F.interpolate(aux_logits, size=target_hw, mode="bilinear", align_corners=False)

        return predictions