r3d.py 1.13 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch.nn as nn

from ._utils import Conv3DSimple
from .video_stems import get_default_stem
from .video_trunk import VideoTrunkBuilder, BasicBlock, Bottleneck

# Public API of this module: only the R3D-18 constructor is exported; the
# generic ``_r3d`` builder below is kept private.
__all__ = ["r3d_18"]


def _r3d(model_depth, use_pool1=False, **kwargs):
    """Build an R3D network trunk, following
    https://arxiv.org/abs/1711.11248

    Args:
        model_depth (int): resnet trunk depth. Depths below 50 are built from
            ``BasicBlock``; every other depth is built from ``Bottleneck``.
        use_pool1 (bool, optional): Add pooling layer to the stem (forwarded
            to ``get_default_stem``). Defaults to False
        **kwargs: extra keyword arguments passed through unchanged to
            ``VideoTrunkBuilder``.

    Returns:
        nn.Module: R3D network trunk
    """
    # The same plain 3D convolution maker is supplied four times.
    makers = [Conv3DSimple for _ in range(4)]

    # Pick the residual block type from the requested depth.
    residual_block = BasicBlock if model_depth < 50 else Bottleneck

    stem = get_default_stem(use_pool1=use_pool1)
    return VideoTrunkBuilder(
        block=residual_block,
        conv_makers=makers,
        model_depth=model_depth,
        stem=stem,
        **kwargs
    )


def r3d_18(use_pool1=False, **kwargs):
    """Build the 18-layer Resnet3D (R3D-18) model described in
    https://arxiv.org/abs/1711.11248

    Args:
        use_pool1 (bool, optional): Include pooling in resnet stem. Defaults to False.
        **kwargs: extra keyword arguments forwarded unchanged to ``_r3d``.

    Returns:
        nn.Module: R3D-18 network
    """
    depth = 18  # fixed trunk depth for this variant
    network = _r3d(depth, use_pool1=use_pool1, **kwargs)
    return network