s3d.py 7.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from functools import partial
from typing import Any, Callable, Optional

import torch
from torch import nn
from torchvision.ops.misc import Conv3dNormActivation

from ...transforms._presets import VideoClassification
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _ovewrite_named_param


# Public names exported by ``from ... import *``: the model class, its weight enum and the builder.
__all__ = [
    "S3D",
    "S3D_Weights",
    "s3d",
]


class TemporalSeparableConv(nn.Sequential):
    """Factorized 3D convolution: a spatial stage followed by a temporal stage.

    Stage one convolves with a ``(1, k, k)`` kernel (height/width only). Stage two
    convolves with a ``(k, 1, 1)`` kernel (time only). Both stages are
    ``Conv3dNormActivation`` blocks built with ``bias=False`` and the given ``norm_layer``.

    Args:
        in_planes: number of input channels.
        out_planes: number of channels produced by both stages.
        kernel_size: kernel extent along each convolved axis.
        stride: stride along height/width in stage one and along time in stage two.
        padding: padding along the same axes that ``stride`` applies to.
        norm_layer: normalization layer factory forwarded to both stages.
    """

    def __init__(
        self,
        in_planes: int,
        out_planes: int,
        kernel_size: int,
        stride: int,
        padding: int,
        norm_layer: Callable[..., nn.Module],
    ):
        shared_kwargs = {"bias": False, "norm_layer": norm_layer}

        # Stage 1: spatial convolution (kernel extent 1 along the time axis).
        spatial_stage = Conv3dNormActivation(
            in_planes,
            out_planes,
            kernel_size=(1, kernel_size, kernel_size),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            **shared_kwargs,
        )
        # Stage 2: temporal convolution (kernel extent 1 along height and width).
        temporal_stage = Conv3dNormActivation(
            out_planes,
            out_planes,
            kernel_size=(kernel_size, 1, 1),
            stride=(stride, 1, 1),
            padding=(padding, 0, 0),
            **shared_kwargs,
        )

        # Registration order (indices "0" then "1") matches the pretrained state_dict layout.
        super().__init__(spatial_stage, temporal_stage)


class SepInceptionBlock3D(nn.Module):
    """Inception-style block with four parallel branches concatenated along channels.

    Branches:
        0. 1x1x1 conv -> ``b0_out`` channels
        1. 1x1x1 conv to ``b1_mid`` -> ``TemporalSeparableConv`` (k=3) -> ``b1_out`` channels
        2. 1x1x1 conv to ``b2_mid`` -> ``TemporalSeparableConv`` (k=3) -> ``b2_out`` channels
        3. 3x3x3 max pool (stride 1, padding 1) -> 1x1x1 conv -> ``b3_out`` channels

    The output has ``b0_out + b1_out + b2_out + b3_out`` channels.
    """

    def __init__(
        self,
        in_planes: int,
        b0_out: int,
        b1_mid: int,
        b1_out: int,
        b2_mid: int,
        b2_out: int,
        b3_out: int,
        norm_layer: Callable[..., nn.Module],
    ):
        super().__init__()

        def pointwise(cin: int, cout: int) -> nn.Module:
            # 1x1x1 projection used at the head of every branch.
            return Conv3dNormActivation(cin, cout, kernel_size=1, stride=1, norm_layer=norm_layer)

        def separable(cin: int, cout: int) -> nn.Module:
            # Size-preserving (stride 1, padding 1) factorized 3x3x3 convolution.
            return TemporalSeparableConv(cin, cout, kernel_size=3, stride=1, padding=1, norm_layer=norm_layer)

        # Attribute names become state_dict keys, so they must stay "branch0".."branch3".
        self.branch0 = pointwise(in_planes, b0_out)
        self.branch1 = nn.Sequential(pointwise(in_planes, b1_mid), separable(b1_mid, b1_out))
        self.branch2 = nn.Sequential(pointwise(in_planes, b2_mid), separable(b2_mid, b2_out))
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            pointwise(in_planes, b3_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Every branch sees the same input; results are stacked along the channel dim.
        return torch.cat([self.branch0(x), self.branch1(x), self.branch2(x), self.branch3(x)], dim=1)


class S3D(nn.Module):
    """S3D main class.

    Args:
        num_classes (int): number of classes for the classification task. Default: 400.
        dropout (float): dropout probability applied before the final classifier. Default: 0.2.
        norm_layer (Optional[Callable]): Module specifying the normalization layer to use.
            Default: ``nn.BatchNorm3d`` with ``eps=0.001`` and ``momentum=0.001``.

    Inputs:
        x (Tensor): batch of videos with dimensions (batch, channel, time, height, width);
            ``channel`` must be 3 (the stem convolution expects 3 input planes).

    Returns:
        Tensor of shape (batch, num_classes) with the class logits.
    """

    def __init__(
        self,
        num_classes: int = 400,
        dropout: float = 0.2,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm3d, eps=0.001, momentum=0.001)

        self.features = nn.Sequential(
            # Stem: strided factorized 7x7x7 conv, then spatial-only downsampling.
            TemporalSeparableConv(3, 64, 7, 2, 3, norm_layer),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
            Conv3dNormActivation(
                64,
                64,
                kernel_size=1,
                stride=1,
                norm_layer=norm_layer,
            ),
            TemporalSeparableConv(64, 192, 3, 1, 1, norm_layer),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
            # Each block's in_planes equals the previous block's b0_out + b1_out + b2_out + b3_out.
            SepInceptionBlock3D(192, 64, 96, 128, 16, 32, 32, norm_layer),
            SepInceptionBlock3D(256, 128, 128, 192, 32, 96, 64, norm_layer),
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
            SepInceptionBlock3D(480, 192, 96, 208, 16, 48, 64, norm_layer),
            SepInceptionBlock3D(512, 160, 112, 224, 24, 64, 64, norm_layer),
            SepInceptionBlock3D(512, 128, 128, 256, 24, 64, 64, norm_layer),
            SepInceptionBlock3D(512, 112, 144, 288, 32, 64, 64, norm_layer),
            SepInceptionBlock3D(528, 256, 160, 320, 32, 128, 128, norm_layer),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)),
            SepInceptionBlock3D(832, 256, 160, 320, 32, 128, 128, norm_layer),
            SepInceptionBlock3D(832, 384, 192, 384, 48, 128, 128, norm_layer),
        )
        # Fixed (time, height, width) = (2, 7, 7) window: the feature map from ``self.features``
        # must be at least this large, which constrains the minimum input clip size.
        self.avgpool = nn.AvgPool3d(kernel_size=(2, 7, 7), stride=1)
        # A 1x1x1 convolution acts as a per-position linear classifier over the 1024 channels.
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Conv3d(1024, num_classes, kernel_size=1, stride=1, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        # Average the per-position logits over the remaining (time, height, width) positions.
        x = torch.mean(x, dim=(2, 3, 4))
        return x


class S3D_Weights(WeightsEnum):
    # Checkpoint whose classifier covers the Kinetics-400 categories; accuracies are listed in
    # ``_metrics`` and the evaluation setup in ``_docs``.
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/s3d-d76dad2f.pth",
        # Inference preprocessing: resize to 256x256, then crop 224x224.
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256, 256),
        ),
        meta={
            "min_size": (224, 224),
            "min_temporal_size": 14,
            "categories": _KINETICS400_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification#s3d",
            "_docs": (
                "The weights aim to approximate the accuracy of the paper. The accuracies are estimated on clip-level "
                "with parameters `frame_rate=15`, `clips_per_video=1`, and `clip_len=128`."
            ),
            "num_params": 8320048,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 68.368,
                    "acc@5": 88.050,
                }
            },
        },
    )
    # Alias for the recommended weights of this model.
    DEFAULT = KINETICS400_V1


@register_model()
def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwargs: Any) -> S3D:
    """Construct Separable 3D CNN model.

    Reference: `Rethinking Spatiotemporal Feature Learning <https://arxiv.org/abs/1712.04851>`__.

    Args:
        weights (:class:`~torchvision.models.video.S3D_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.S3D_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.S3D`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/s3d.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.S3D_Weights
        :members:
    """
    weights = S3D_Weights.verify(weights)

    if weights is None:
        # No checkpoint requested: build the model straight from the caller's kwargs.
        return S3D(**kwargs)

    # The classifier head must match the checkpoint, so pin num_classes to its category count.
    _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    model = S3D(**kwargs)
    model.load_state_dict(weights.get_state_dict(progress=progress))
    return model