"tests/vscode:/vscode.git/clone" did not exist on "af0aaddf89c827bd8ed3f4f322066a478b9cc43e"
_utils.py 1.46 KB
Newer Older
1
from collections import OrderedDict
2
from typing import Optional, Dict
3

4
from torch import nn, Tensor
5
6
from torch.nn import functional as F

7
8
from ..._internally_replaced_utils import load_state_dict_from_url

9
10

class _SimpleSegmentationModel(nn.Module):
11
12
13
    __constants__ = ["aux_classifier"]

    def __init__(self, backbone: nn.Module, classifier: nn.Module, aux_classifier: Optional[nn.Module] = None) -> None:
14
        super().__init__()
15
16
17
18
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

19
    def forward(self, x: Tensor) -> Dict[str, Tensor]:
20
21
22
23
24
25
26
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.backbone(x)

        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
27
        x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
28
29
30
31
32
        result["out"] = x

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
33
            x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
34
35
36
            result["aux"] = x

        return result
37
38
39
40


def _load_weights(arch: str, model: nn.Module, model_url: Optional[str], progress: bool) -> None:
    if model_url is None:
41
        raise ValueError(f"No checkpoint is available for {arch}")
42
43
    state_dict = load_state_dict_from_url(model_url, progress=progress)
    model.load_state_dict(state_dict)