resnet.py 13.5 KB
Newer Older
1
from functools import partial
2
from typing import Tuple, Optional, Callable, List, Sequence, Type, Any, Union
3

4
5
6
import torch.nn as nn
from torch import Tensor

7
from ...transforms._presets import VideoClassification
8
from ...utils import _log_api_usage_once
9
10
11
from .._api import WeightsEnum, Weights
from .._meta import _KINETICS400_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
12
13


14
15
16
17
18
19
20
21
22
# Public interface of this module: the model class, its weight enums, and the
# three 18-layer video model builders.
__all__ = [
    "VideoResNet",
    "R3D_18_Weights",
    "MC3_18_Weights",
    "R2Plus1D_18_Weights",
    "r3d_18",
    "mc3_18",
    "r2plus1d_18",
]
23
24
25


class Conv3DSimple(nn.Conv3d):
    """A plain 3x3x3 convolution: full spatio-temporal kernel, no bias."""

    def __init__(
        self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
    ) -> None:
        # `midplanes` is unused here; it is accepted only so that every conv
        # builder shares the same call signature.
        super().__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        """Stride tuple for the shortcut path: uniform in time and space."""
        return (stride,) * 3
42
43
44


class Conv2Plus1D(nn.Sequential):
    """Factorized (2+1)D convolution: a spatial 1x3x3 conv followed by a
    temporal 3x1x1 conv, with BatchNorm + ReLU in between."""

    def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None:
        # `midplanes` is the hidden width between the two factorized convs.
        spatial = nn.Conv3d(
            in_planes,
            midplanes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False,
        )
        temporal = nn.Conv3d(
            midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False
        )
        # Module order (and hence state_dict keys) matches the 4-slot layout:
        # conv-spatial, BN, ReLU, conv-temporal.
        super().__init__(spatial, nn.BatchNorm3d(midplanes), nn.ReLU(inplace=True), temporal)

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        """Stride tuple for the shortcut path: uniform in time and space."""
        return (stride,) * 3
65
66
67


class Conv3DNoTemporal(nn.Conv3d):
    """A 1x3x3 convolution: spatial-only, no mixing along the time axis."""

    def __init__(
        self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1
    ) -> None:
        # `midplanes` is unused here; kept for signature compatibility with
        # the other conv builders.
        super().__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        """Stride tuple for the shortcut path: never downsample along time."""
        return 1, stride, stride
84
85
86
87
88
89


class BasicBlock(nn.Module):
    """Two-convolution residual block for video ResNets (expansion = 1)."""

    expansion = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Hidden width used by factorized (2+1)D conv builders so that their
        # parameter count matches a full 3x3x3 convolution.
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True),
        )
        # Second conv has no trailing ReLU: activation happens after the
        # residual addition.
        self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes))
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv1 -> conv2, add the (possibly downsampled) shortcut,
        then ReLU."""
        identity = x if self.downsample is None else self.downsample(x)

        out = self.conv2(self.conv1(x))
        out += identity
        return self.relu(out)


class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1x1 reduce -> builder conv -> 1x1x1
    expand (expansion = 4)."""

    expansion = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Hidden width for factorized (2+1)D conv builders (parameter-matched
        # to a full 3x3x3 convolution).
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        # Channel reduction.
        self.conv1 = nn.Sequential(
            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True),
        )
        # Main (possibly strided) spatio-temporal convolution.
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True),
        )
        # Channel expansion; no ReLU before the residual addition.
        self.conv3 = nn.Sequential(
            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes * self.expansion),
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv1 -> conv2 -> conv3, add the (possibly downsampled)
        shortcut, then ReLU."""
        identity = x if self.downsample is None else self.downsample(x)

        out = self.conv3(self.conv2(self.conv1(x)))
        out += identity
        return self.relu(out)


class BasicStem(nn.Sequential):
    """The default conv-batchnorm-relu stem: a 3x7x7 convolution with
    spatial stride 2 (time kept at stride 1)."""

    def __init__(self) -> None:
        conv = nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
        super().__init__(conv, nn.BatchNorm3d(64), nn.ReLU(inplace=True))
181
182
183


class R2Plus1dStem(nn.Sequential):
    """R(2+1)D stem: unlike the default stem, the convolution is factorized
    into a spatial 1x7x7 conv followed by a temporal 3x1x1 conv, each with
    BatchNorm + ReLU."""

    def __init__(self) -> None:
        spatial = nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
        temporal = nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
        super().__init__(
            spatial,
            nn.BatchNorm3d(45),
            nn.ReLU(inplace=True),
            temporal,
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )
195
196
197


class VideoResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
        layers: List[int],
        stem: Callable[..., nn.Module],
        num_classes: int = 400,
        zero_init_residual: bool = False,
    ) -> None:
        """Generic resnet video generator.

        Args:
            block (Type[Union[BasicBlock, Bottleneck]]): resnet building block
            conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator
                function for each layer
            layers (List[int]): number of blocks per layer
            stem (Callable[..., nn.Module]): module specifying the ResNet stem.
            num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
        """
        super().__init__()
        _log_api_usage_once(self)
        self.inplanes = 64

        self.stem = stem()

        # Four residual stages; only the first keeps stride 1.
        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # init weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    # Zero-init the last BN of each bottleneck so every
                    # residual branch starts as the identity. The BN is the
                    # second entry of the `conv3` Sequential; the previous
                    # `m.bn3` access (silenced by a type: ignore) raised
                    # AttributeError because Bottleneck defines no `bn3`.
                    nn.init.constant_(m.conv3[1].weight, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Run the network on a 5D video batch (N, C, T, H, W) and return
        class logits of shape (N, num_classes)."""
        x = self.stem(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten the layer to fc
        x = x.flatten(1)
        x = self.fc(x)

        return x

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]],
        planes: int,
        blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        """Build one residual stage of `blocks` blocks. Only the first block
        downsamples, and only when the stride or channel count changes."""
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            # The conv builder decides whether time is downsampled too.
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, conv_builder, stride, downsample)]

        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder))

        return nn.Sequential(*layers)


291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
def _video_resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
    layers: List[int],
    stem: Callable[..., nn.Module],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> VideoResNet:
    """Build a VideoResNet and optionally load pretrained weights.

    When ``weights`` is given, ``num_classes`` is forced to the number of
    categories in the checkpoint's metadata before the state dict is loaded.
    """
    pretrained = weights is not None
    if pretrained:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = VideoResNet(block, conv_makers, layers, stem, **kwargs)

    if pretrained:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
# Metadata shared by all Kinetics-400 pretrained weight entries below.
_COMMON_META = {
    "task": "video_classification",
    "size": (112, 112),
    "min_size": (1, 1),
    "categories": _KINETICS400_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
}


class R3D_18_Weights(WeightsEnum):
    """Pretrained weights for the 18-layer R3D (full 3D convolution) model."""

    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
        transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
        meta={
            **_COMMON_META,
            "architecture": "R3D",
            "num_params": 33371472,
            "acc@1": 52.75,
            "acc@5": 75.45,
        },
    )
    # Alias used when callers request the default weights.
    DEFAULT = KINETICS400_V1


class MC3_18_Weights(WeightsEnum):
    """Pretrained weights for the 18-layer Mixed-Convolution (MC3) model."""

    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
        transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
        meta={
            **_COMMON_META,
            "architecture": "MC3",
            "num_params": 11695440,
            "acc@1": 53.90,
            "acc@5": 76.29,
        },
    )
    # Alias used when callers request the default weights.
    DEFAULT = KINETICS400_V1


class R2Plus1D_18_Weights(WeightsEnum):
    """Pretrained weights for the 18-layer R(2+1)D (factorized conv) model."""

    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
        transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
        meta={
            **_COMMON_META,
            "architecture": "R(2+1)D",
            "num_params": 31505325,
            "acc@1": 57.50,
            "acc@5": 78.81,
        },
    )
    # Alias used when callers request the default weights.
    DEFAULT = KINETICS400_V1


@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Build the 18-layer Resnet3D (R3D-18) model from
    https://arxiv.org/abs/1711.11248, using full 3x3x3 convolutions in
    every stage.

    Args:
        weights (R3D_18_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        VideoResNet: R3D-18 network
    """
    weights = R3D_18_Weights.verify(weights)

    return _video_resnet(
        block=BasicBlock,
        conv_makers=[Conv3DSimple] * 4,
        layers=[2, 2, 2, 2],
        stem=BasicStem,
        weights=weights,
        progress=progress,
        **kwargs,
    )
388
389


390
391
@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Build the 18-layer Mixed Convolution (MC3) network from
    https://arxiv.org/abs/1711.11248: full 3D convolutions in the first
    stage, spatial-only (1x3x3) convolutions in the remaining three.

    Args:
        weights (MC3_18_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        VideoResNet: MC3 Network definition
    """
    weights = MC3_18_Weights.verify(weights)

    return _video_resnet(
        block=BasicBlock,
        conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3,  # type: ignore[list-item]
        layers=[2, 2, 2, 2],
        stem=BasicStem,
        weights=weights,
        progress=progress,
        **kwargs,
    )
413
414


415
416
@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Build the 18-layer deep R(2+1)D network from
    https://arxiv.org/abs/1711.11248, where every 3D convolution is
    factorized into a spatial and a temporal convolution.

    Args:
        weights (R2Plus1D_18_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        VideoResNet: R(2+1)D-18 network
    """
    weights = R2Plus1D_18_Weights.verify(weights)

    return _video_resnet(
        block=BasicBlock,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[2, 2, 2, 2],
        stem=R2Plus1dStem,
        weights=weights,
        progress=progress,
        **kwargs,
    )