r2plus1d.py 1.19 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch.nn as nn

from ._utils import Conv2Plus1D
from .video_stems import get_r2plus1d_stem
from .video_trunk import VideoTrunkBuilder, BasicBlock, Bottleneck


# Public names exported by a star-import of this module; the private
# ``_r2plus1d`` helper is intentionally not listed.
__all__ = ["r2plus1d_18"]


def _r2plus1d(model_depth, use_pool1=False, **kwargs):
    """Build an R(2+1)D network (see https://arxiv.org/abs/1711.11248).

    Args:
        model_depth (int): Depth of the model - standard resnet depths apply.
            Depths below 50 are built from ``BasicBlock``; every other depth
            is built from ``Bottleneck``.
        use_pool1 (bool, optional): Should we use the pooling layer? The flag
            is handed to ``get_r2plus1d_stem``. Defaults to False.
        **kwargs: Any further keyword arguments, passed unchanged to
            ``VideoTrunkBuilder``.

    Returns:
        nn.Module: An R(2+1)D video backbone (the constructed
        ``VideoTrunkBuilder`` object).
    """
    # One Conv2Plus1D conv maker for each of the four entries the builder receives.
    conv_makers = [Conv2Plus1D for _ in range(4)]

    # Shallower depths use the basic residual block, the rest use bottlenecks.
    residual_block = BasicBlock if model_depth < 50 else Bottleneck

    return VideoTrunkBuilder(
        block=residual_block,
        conv_makers=conv_makers,
        model_depth=model_depth,
        stem=get_r2plus1d_stem(use_pool1),
        **kwargs
    )


def r2plus1d_18(use_pool1=False, **kwargs):
    """Build the 18 layer deep R(2+1)D network from
    https://arxiv.org/abs/1711.11248

    Args:
        use_pool1 (bool, optional): Include pooling in the resnet stem.
            Defaults to False.
        **kwargs: Extra keyword arguments forwarded untouched to ``_r2plus1d``.

    Returns:
        nn.Module: R(2+1)D-18 network
    """
    depth = 18  # fixed depth for this variant
    return _r2plus1d(depth, use_pool1=use_pool1, **kwargs)