resnet.py 13.3 KB
Newer Older
1
from functools import partial
2
from typing import Tuple, Optional, Callable, List, Sequence, Type, Any, Union
3

4
5
6
import torch.nn as nn
from torch import Tensor

7
from ...transforms._presets import VideoClassification
8
from ...utils import _log_api_usage_once
9
10
11
from .._api import WeightsEnum, Weights
from .._meta import _KINETICS400_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
12
13


14
15
16
17
18
19
20
21
22
__all__ = [
    "VideoResNet",
    "R3D_18_Weights",
    "MC3_18_Weights",
    "R2Plus1D_18_Weights",
    "r3d_18",
    "mc3_18",
    "r2plus1d_18",
]
23
24
25


class Conv3DSimple(nn.Conv3d):
    """Plain full 3x3x3 convolution over time and space."""

    def __init__(
        self,
        in_planes: int,
        out_planes: int,
        midplanes: Optional[int] = None,
        stride: int = 1,
        padding: int = 1,
    ) -> None:
        # `midplanes` is unused here; it exists so every conv builder shares one call signature.
        super().__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(3, 3, 3),
            stride=stride,
            padding=padding,
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        # Shortcut downsampling is isotropic: same stride along time, height and width.
        return (stride,) * 3
42
43
44


class Conv2Plus1D(nn.Sequential):
    """(2+1)D factorized convolution: a 1x3x3 spatial conv, then a 3x1x1 temporal conv."""

    def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None:
        spatial = nn.Conv3d(
            in_planes,
            midplanes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False,
        )
        temporal = nn.Conv3d(
            midplanes,
            out_planes,
            kernel_size=(3, 1, 1),
            stride=(stride, 1, 1),
            padding=(padding, 0, 0),
            bias=False,
        )
        # BN + ReLU sit between the two factors of the decomposed convolution.
        super().__init__(spatial, nn.BatchNorm3d(midplanes), nn.ReLU(inplace=True), temporal)

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        # Shortcut downsampling is isotropic across time and space.
        return (stride,) * 3
65
66
67


class Conv3DNoTemporal(nn.Conv3d):
    """Spatial-only 3D convolution (1x3x3): no mixing along the temporal axis."""

    def __init__(
        self,
        in_planes: int,
        out_planes: int,
        midplanes: Optional[int] = None,
        stride: int = 1,
        padding: int = 1,
    ) -> None:
        # `midplanes` is unused here; kept so all conv builders share the same signature.
        super().__init__(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=(1, 3, 3),
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False,
        )

    @staticmethod
    def get_downsample_stride(stride: int) -> Tuple[int, int, int]:
        # The shortcut never strides over time for this builder.
        return (1, stride, stride)
84
85
86
87
88
89


class BasicBlock(nn.Module):
    """Residual basic block: two conv stages plus an identity/projection shortcut."""

    expansion = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Intermediate channel count handed to the builder; only factorized
        # builders (e.g. Conv2Plus1D) actually use it.
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        self.conv1 = nn.Sequential(
            conv_builder(inplanes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True),
        )
        # No trailing ReLU here: activation is applied after the residual add.
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes),
            nn.BatchNorm3d(planes),
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        # Projection shortcut when the main path changes resolution/channels.
        identity = x if self.downsample is None else self.downsample(x)

        out = self.conv2(self.conv1(x))
        out += identity
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1x1 reduce -> builder conv -> 1x1x1 expand (x expansion)."""

    expansion = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        conv_builder: Callable[..., nn.Module],
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
    ) -> None:
        super().__init__()
        # Intermediate channel count handed to the builder; only factorized
        # builders (e.g. Conv2Plus1D) actually use it.
        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)

        # 1x1x1 channel reduction
        self.conv1 = nn.Sequential(
            nn.Conv3d(inplanes, planes, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True),
        )
        # Spatio-temporal convolution supplied by the builder
        self.conv2 = nn.Sequential(
            conv_builder(planes, planes, midplanes, stride),
            nn.BatchNorm3d(planes),
            nn.ReLU(inplace=True),
        )
        # 1x1x1 channel expansion; no ReLU before the residual add
        self.conv3 = nn.Sequential(
            nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False),
            nn.BatchNorm3d(planes * self.expansion),
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        # Projection shortcut when the main path changes resolution/channels.
        identity = x if self.downsample is None else self.downsample(x)

        out = x
        for stage in (self.conv1, self.conv2, self.conv3):
            out = stage(out)

        out += identity
        return self.relu(out)


class BasicStem(nn.Sequential):
    """Default stem: one 3x7x7 conv (spatial stride 2), then batch-norm and ReLU."""

    def __init__(self) -> None:
        conv = nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
        super().__init__(conv, nn.BatchNorm3d(64), nn.ReLU(inplace=True))
181
182
183


class R2Plus1dStem(nn.Sequential):
    """R(2+1)D stem: the first layer is factorized into a 1x7x7 spatial conv
    followed by a 3x1x1 temporal conv, unlike the default single-conv stem."""

    def __init__(self) -> None:
        spatial = nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
        temporal = nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
        super().__init__(
            spatial,
            nn.BatchNorm3d(45),
            nn.ReLU(inplace=True),
            temporal,
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
        )
195
196
197


class VideoResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
        layers: List[int],
        stem: Callable[..., nn.Module],
        num_classes: int = 400,
        zero_init_residual: bool = False,
    ) -> None:
        """Generic resnet video generator.

        Args:
            block (Type[Union[BasicBlock, Bottleneck]]): resnet building block
            conv_makers (List[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]]): generator
                function for each layer
            layers (List[int]): number of blocks per layer
            stem (Callable[..., nn.Module]): module specifying the ResNet stem.
            num_classes (int, optional): Dimension of the final FC layer. Defaults to 400.
            zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False.
        """
        super().__init__()
        _log_api_usage_once(self)
        self.inplanes = 64

        self.stem = stem()

        self.layer1 = self._make_layer(block, conv_makers[0], 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, conv_makers[1], 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, conv_makers[2], 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, conv_makers[3], 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # init weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    # Bug fix: Bottleneck keeps its final BatchNorm3d inside the
                    # `conv3` Sequential (index 1) and has no `bn3` attribute, so
                    # the previous `m.bn3.weight` raised AttributeError whenever
                    # zero_init_residual=True was used with Bottleneck blocks.
                    nn.init.constant_(m.conv3[1].weight, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Run a batch of clips through the network and return class logits.

        Input is (N, C, T, H, W); C must match what the stem expects
        (3 for the stems defined in this file). Output is (N, num_classes).
        """
        x = self.stem(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten the layer to fc
        x = x.flatten(1)
        x = self.fc(x)

        return x

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]],
        planes: int,
        blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        """Stack `blocks` residual blocks; only the first block downsamples."""
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut so the residual add matches in stride and channels.
            ds_stride = conv_builder.get_downsample_stride(stride)
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, conv_builder, stride, downsample))

        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_builder))

        return nn.Sequential(*layers)


291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
def _video_resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    conv_makers: Sequence[Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]]],
    layers: List[int],
    stem: Callable[..., nn.Module],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> VideoResNet:
    """Build a VideoResNet and, when `weights` is given, load the pretrained state dict."""
    pretrained = weights is not None

    if pretrained:
        # A checkpoint fixes the classifier width to its dataset's category count.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = VideoResNet(block, conv_makers, layers, stem, **kwargs)

    if pretrained:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
# Metadata shared by every Kinetics-400 pretrained weight entry below;
# each Weights.meta dict spreads this in and adds model-specific fields.
_COMMON_META = {
    "min_size": (1, 1),
    "categories": _KINETICS400_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/video_classification",
}


class R3D_18_Weights(WeightsEnum):
    # Kinetics-400 checkpoint for R3D-18; recipe link comes from _COMMON_META.
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
        transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
        meta={
            **_COMMON_META,
            "num_params": 33371472,
            # acc@1/acc@5: presumably Kinetics-400 validation accuracy — TODO confirm against recipe.
            "acc@1": 52.75,
            "acc@5": 75.45,
        },
    )
    DEFAULT = KINETICS400_V1


class MC3_18_Weights(WeightsEnum):
    # Kinetics-400 checkpoint for MC3-18; recipe link comes from _COMMON_META.
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
        transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
        meta={
            **_COMMON_META,
            "num_params": 11695440,
            # acc@1/acc@5: presumably Kinetics-400 validation accuracy — TODO confirm against recipe.
            "acc@1": 53.90,
            "acc@5": 76.29,
        },
    )
    DEFAULT = KINETICS400_V1


class R2Plus1D_18_Weights(WeightsEnum):
    # Kinetics-400 checkpoint for R(2+1)D-18; recipe link comes from _COMMON_META.
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
        transforms=partial(VideoClassification, crop_size=(112, 112), resize_size=(128, 171)),
        meta={
            **_COMMON_META,
            "num_params": 31505325,
            # acc@1/acc@5: presumably Kinetics-400 validation accuracy — TODO confirm against recipe.
            "acc@1": 57.50,
            "acc@5": 78.81,
        },
    )
    DEFAULT = KINETICS400_V1


@handle_legacy_interface(weights=("pretrained", R3D_18_Weights.KINETICS400_V1))
def r3d_18(*, weights: Optional[R3D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Build the 18-layer ResNet3D model described in
    https://arxiv.org/abs/1711.11248

    Args:
        weights (R3D_18_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        VideoResNet: R3D-18 network
    """
    weights = R3D_18_Weights.verify(weights)

    # Full 3D convolutions in all four stages, with the default stem.
    return _video_resnet(
        block=BasicBlock,
        conv_makers=[Conv3DSimple] * 4,
        layers=[2, 2, 2, 2],
        stem=BasicStem,
        weights=weights,
        progress=progress,
        **kwargs,
    )
383
384


385
386
@handle_legacy_interface(weights=("pretrained", MC3_18_Weights.KINETICS400_V1))
def mc3_18(*, weights: Optional[MC3_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Build the 18-layer Mixed Convolution network described in
    https://arxiv.org/abs/1711.11248

    Args:
        weights (MC3_18_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        VideoResNet: MC3 Network definition
    """
    weights = MC3_18_Weights.verify(weights)

    # "Mixed convolutions": full 3D convs in the first stage only, spatial-only convs after.
    conv_makers = [Conv3DSimple] + [Conv3DNoTemporal] * 3  # type: ignore[list-item]

    return _video_resnet(
        block=BasicBlock,
        conv_makers=conv_makers,
        layers=[2, 2, 2, 2],
        stem=BasicStem,
        weights=weights,
        progress=progress,
        **kwargs,
    )
408
409


410
411
@handle_legacy_interface(weights=("pretrained", R2Plus1D_18_Weights.KINETICS400_V1))
def r2plus1d_18(*, weights: Optional[R2Plus1D_18_Weights] = None, progress: bool = True, **kwargs: Any) -> VideoResNet:
    """Build the 18-layer deep R(2+1)D network described in
    https://arxiv.org/abs/1711.11248

    Args:
        weights (R2Plus1D_18_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        VideoResNet: R(2+1)D-18 network
    """
    weights = R2Plus1D_18_Weights.verify(weights)

    # Factorized (2+1)D convolutions in every stage, with the matching factorized stem.
    return _video_resnet(
        block=BasicBlock,
        conv_makers=[Conv2Plus1D] * 4,
        layers=[2, 2, 2, 2],
        stem=R2Plus1dStem,
        weights=weights,
        progress=progress,
        **kwargs,
    )