from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union

import torch.nn as nn
from torch import Tensor

from ...transforms._presets import VideoClassification
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "VideoResNet",
    "R3D_18_Weights",
    "MC3_18_Weights",
    "R2Plus1D_18_Weights",
    "r3d_18",
    "mc3_18",
    "r2plus1d_18",
]


class Conv3DSimple(nn.Conv3d):
    def __init__(
        self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
    ) -> None:

        super().__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        return stride, stride, stride


class Conv2Plus1D(nn.Sequential):
    def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None:
        super().__init__(
            nn.Conv3d(
                in_planes,
                midplanes,
                kernel_size=(1, 3, 3),
                stride=(1, stride, stride),
                padding=(0, padding, padding),
                bias=False,
            ),
            nn.BatchNorm3d(midplanes),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False
            ),
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        return stride, stride, stride
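
# Conv2Plus1D is the only builder that actually uses ``midplanes``: the (2+1)D
# factorization places that many channels between the spatial (1x3x3) and
# temporal (3x1x1) convolutions. The blocks below choose midplanes so that the
# factorized pair has approximately as many weights as the single 3x3x3
# convolution it replaces, following the R(2+1)D paper. A hedged arithmetic
# check (illustration only, not used by the model code):
#
#   full 3x3x3 conv:  in_planes * out_planes * 27            weights
#   (2+1)D pair:      mid * (9 * in_planes + 3 * out_planes)  weights
#
#   >>> in_planes = out_planes = 64
#   >>> mid = (in_planes * out_planes * 27) // (9 * in_planes + 3 * out_planes)
#   >>> mid
#   144
#   >>> mid * (9 * in_planes + 3 * out_planes)  # vs. 64 * 64 * 27 == 110592
#   110592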


class Conv3DNoTemporal(nn.Conv3d):
    def __init__(
        self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
    ) -> None:

        super().__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        return 1, stride, stride
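
# Note the asymmetry between the builders: Conv3DSimple and Conv2Plus1D stride
# the temporal dimension together with the spatial ones, while Conv3DNoTemporal
# reports a (1, s, s) downsample stride and leaves the temporal resolution
# untouched. A hedged shape check (illustrative toy sizes, not part of the
# model definition):
#
#   >>> import torch
#   >>> conv = Conv3DNoTemporal(64, 128, stride=2)
#   >>> conv(torch.randn(1, 64, 8, 56, 56)).shape  # (N, C, T, H, W)
#   torch.Size([1, 128, 8, 28, 28])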


class BasicBlock(nn.Module):

    expansion = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        super().__init__()
        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes))
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:

        super().__init__()
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        # 1x1x1
        self.conv1 = nn.Sequential(
            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
        )
        # Second kernel
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True)
        )

        # 1x1x1
        self.conv3 = nn.Sequential(
            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes * self.expansion),
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class BasicStem(nn.Sequential):
    """The default conv-batchnorm-relu stem"""

    def __init__(self) -> None:
        super().__init__(
            nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )


class R2Plus1dStem(nn.Sequential):
    """R(2+1)D stem is different than the default one as it uses separated 3D convolution"""

    def __init__(self) -> None:
        super().__init__(
            nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False),
            nn.BatchNorm3d(45),
            nn.ReLU(inplace=True),
            nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )


class VideoResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
        layers: List[int],
        stem: Callable[..., nn.Module],
        num_classes: int = 400,
        zero_init_residual: bool = False,
    ) -> None:
        """Generic resnet video generator.

        Args:
            block (Type[Union[BasicBlock, Bottleneck]]): resnet building block
            conv_makers (Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): convolution
                builder class for each of the four layers
            layers (List[int]): number of blocks per layer
            stem (Callable[..., nn.Module]): module specifying the ResNet stem.
            num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
        """
        super().__init__()
        _log_api_usage_once(self)
        self.inplanes = 64

        self.stem = stem()

        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # init weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    # Bottleneck has no ``bn3`` attribute; the last BN of the residual
                    # branch lives inside the conv3 Sequential(Conv3d, BatchNorm3d)
                    nn.init.constant_(m.conv3[1].weight, 0)  # type: ignore[union-attr, arg-type]

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten the pooled features before the fc layer
        x = x.flatten(1)
        x = self.fc(x)

        return x

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]],
        planes: int,
        blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))

        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder))

        return nn.Sequential(*layers)
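
# A hedged sketch of wiring VideoResNet up by hand, mirroring what the builder
# functions below do (the sizes are illustrative, not a supported
# configuration):
#
#   >>> import torch
#   >>> net = VideoResNet(
#   ...     block=BasicBlock,
#   ...     conv_makers=[Conv3DSimple] * 4,
#   ...     layers=[2, 2, 2, 2],
#   ...     stem=BasicStem,
#   ...     num_classes=10,
#   ... )
#   >>> net(torch.randn(1, 3, 16, 112, 112)).shape  # (N, C, T, H, W) clip
#   torch.Size([1, 10])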


def _video_resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
    layers: List[int],
    stem: Callable[..., nn.Module],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> VideoResNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = VideoResNet(block, conv_makers, layers, stem, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


_COMMON_META = {
    "min_size": (1, 1),
    "categories": _KINETICS400_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
    "_docs": (
        "The weights reproduce closely the accuracy of the paper. The accuracies are estimated on video-level "
        "with parameters `frame_rate=15`, `clips_per_video=5`, and `clip_len=16`."
    ),
}


class R3D_18_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
        transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
        meta={
            **_COMMON_META,
            "num_params": 33371472,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 63.200,
                    "acc@5": 83.479,
                }
            },
        },
    )
    DEFAULT = KINETICS400_V1


class MC3_18_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
        transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
        meta={
            **_COMMON_META,
            "num_params": 11695440,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 63.960,
                    "acc@5": 84.130,
                }
            },
        },
    )
    DEFAULT = KINETICS400_V1


class R2Plus1D_18_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
        transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
        meta={
            **_COMMON_META,
            "num_params": 31505325,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 67.463,
                    "acc@5": 86.175,
                }
            },
        },
    )
    DEFAULT = KINETICS400_V1
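
# Each weights enum above bundles a checkpoint URL with its preprocessing
# preset and metadata. A hedged sketch of inspecting one (values taken from
# the meta dictionaries above):
#
#   >>> weights = R2Plus1D_18_Weights.DEFAULT
#   >>> weights.meta["_metrics"]["Kinetics-400"]["acc@1"]
#   67.463
#   >>> preprocess = weights.transforms()  # VideoClassification preset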


@register_model()
@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Construct 18 layer Resnet3D model.

    .. betastatus:: video module

    Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.

    Args:
        weights (:class:`~torchvision.models.video.R3D_18_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.R3D_18_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.R3D_18_Weights
        :members:
    """
    weights = R3D_18_Weights.verify(weights)

    return _video_resnet(
        BasicBlock,
        [Conv3DSimple] * 4,
        [2, 2, 2, 2],
        BasicStem,
        weights,
        progress,
        **kwargs,
    )
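
# A hedged end-to-end sketch that applies to all three builders in this file
# (r3d_18 shown; mc3_18 and r2plus1d_18 below follow the same pattern). The
# random clip stands in for a real preprocessed (N, C, T, H, W) batch:
#
#   >>> import torch
#   >>> model = r3d_18(weights=R3D_18_Weights.DEFAULT).eval()
#   >>> clip = torch.randn(1, 3, 16, 112, 112)
#   >>> with torch.inference_mode():
#   ...     scores = model(clip).softmax(dim=1)
#   >>> scores.shape  # one score per Kinetics-400 class
#   torch.Size([1, 400])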


@register_model()
@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Construct 18 layer Mixed Convolution network as in

    .. betastatus:: video module

    Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.

    Args:
        weights (:class:`~torchvision.models.video.MC3_18_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.MC3_18_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.MC3_18_Weights
        :members:
    """
    weights = MC3_18_Weights.verify(weights)

    return _video_resnet(
        BasicBlock,
        [Conv3DSimple] + [Conv3DNoTemporal] * 3,  # type: ignore[list-item]
        [2, 2, 2, 2],
        BasicStem,
        weights,
        progress,
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Construct 18 layer deep R(2+1)D network as in

    .. betastatus:: video module

    Reference: `A Closer Look at Spatiotemporal Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`__.

    Args:
        weights (:class:`~torchvision.models.video.R2Plus1D_18_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.R2Plus1D_18_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.resnet.VideoResNet`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.R2Plus1D_18_Weights
        :members:
    """
    weights = R2Plus1D_18_Weights.verify(weights)

    return _video_resnet(
        BasicBlock,
        [Conv2Plus1D] * 4,
        [2, 2, 2, 2],
        R2Plus1dStem,
        weights,
        progress,
        **kwargs,
    )


# The dictionary below is an internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs


model_urls = _ModelURLs(
    {
        "r3d_18": R3D_18_Weights.KINETICS400_V1.url,
        "mc3_18": MC3_18_Weights.KINETICS400_V1.url,
        "r2plus1d_18": R2Plus1D_18_Weights.KINETICS400_V1.url,
    }
)