test_common.py 4.68 KB
Newer Older
unknown's avatar
unknown committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import pytest
import torch
import torch.nn as nn
from mmcv.utils import assert_params_all_zeros

from mmaction.models.common import (LFB, TAM, Conv2plus1d, ConvAudio,
                                    DividedSpatialAttentionWithNorm,
                                    DividedTemporalAttentionWithNorm,
                                    FFNWithNorm, SubBatchNorm3D)


def test_conv2plus1d():
    """Check Conv2plus1d argument validation, BN init and output shape."""
    # Kernel size, stride and padding must all have the same length.
    with pytest.raises(AssertionError):
        Conv2plus1d(3, 8, (2, 2))

    module = Conv2plus1d(3, 8, 2)
    module.init_weights()

    # After init_weights the spatial BN has all-ones weight, all-zeros bias.
    bn_weight = module.bn_s.weight
    bn_bias = module.bn_s.bias
    assert torch.equal(bn_weight, torch.ones_like(bn_weight))
    assert torch.equal(bn_bias, torch.zeros_like(bn_bias))

    # Per the expected shape, kernel size 2 shrinks each of T, H and W by one.
    inputs = torch.rand(1, 3, 8, 256, 256)
    assert module(inputs).shape == torch.Size([1, 8, 7, 255, 255])


def test_conv_audio():
    """Check ConvAudio output shapes for the default op and ``op='sum'``."""
    default_conv = ConvAudio(3, 8, 3)
    default_conv.init_weights()

    audio = torch.rand(1, 3, 8, 8)
    # With the default op the expected output has 16 channels (2 x 8).
    assert default_conv(audio).shape == torch.Size([1, 16, 8, 8])

    # With op='sum' the output channel count equals out_channels.
    sum_conv = ConvAudio(3, 8, 3, op='sum')
    assert sum_conv(audio).shape == torch.Size([1, 8, 8, 8])


def test_divided_temporal_attention_with_norm():
    """Check DividedTemporalAttentionWithNorm init state and output shape."""
    dims, frames = 768, 8
    attn = DividedTemporalAttentionWithNorm(
        embed_dims=dims, num_heads=12, num_frames=frames)
    assert isinstance(attn.norm, nn.LayerNorm)
    # temporal_fc is expected to start out with all-zero parameters.
    assert assert_params_all_zeros(attn.temporal_fc)

    # One extra token (presumably a cls token) plus frames x 14 x 14 patches.
    num_tokens = 1 + frames * 14 * 14
    tokens = torch.rand(1, num_tokens, dims)
    assert attn(tokens).shape == torch.Size([1, num_tokens, dims])


def test_divided_spatial_attention_with_norm():
    """Check DividedSpatialAttentionWithNorm submodules and output shape."""
    dims, frames = 512, 4
    attn = DividedSpatialAttentionWithNorm(
        embed_dims=dims, num_heads=8, num_frames=frames, dropout_layer=None)
    # Passing ``dropout_layer=None`` is expected to yield an nn.Identity.
    assert isinstance(attn.dropout_layer, nn.Identity)
    assert isinstance(attn.norm, nn.LayerNorm)

    num_tokens = 1 + frames * 14 * 14
    tokens = torch.rand(1, num_tokens, dims)
    assert attn(tokens).shape == torch.Size([1, num_tokens, dims])


def test_ffn_with_norm():
    """Check FFNWithNorm builds a LayerNorm and preserves the input shape."""
    dims = 256
    ffn = FFNWithNorm(
        embed_dims=dims,
        feedforward_channels=dims * 2,
        norm_cfg=dict(type='LN'))
    assert isinstance(ffn.norm, nn.LayerNorm)

    num_tokens = 1 + 4 * 14 * 14
    tokens = torch.rand(1, num_tokens, dims)
    assert ffn(tokens).shape == torch.Size([1, num_tokens, dims])


def test_TAM():
    """Test TAM argument validation and forward output shape."""
    with pytest.raises(AssertionError):
        # alpha must be a positive integer
        TAM(16, 8, alpha=0, beta=4)

    with pytest.raises(AssertionError):
        # beta must be a positive integer
        TAM(16, 8, alpha=2, beta=0)

    # Build the module and input outside ``pytest.raises`` so that only the
    # forward call is expected to fail. An AssertionError raised by the
    # constructor would otherwise make this check pass for the wrong reason.
    tam = TAM(16, 8)
    x = torch.rand(64, 8, 112, 112)
    with pytest.raises(AssertionError):
        # the channels number of x should be equal to self.in_channels of TAM
        tam(x)

    tam = TAM(16, 8)
    x = torch.rand(32, 16, 112, 112)
    output = tam(x)
    assert output.shape == torch.Size([32, 16, 112, 112])


def test_LFB():
    """Test LFB argument validation and feature loading on cpu and lmdb."""
    with pytest.raises(ValueError):
        # a prefix path that does not exist is rejected
        LFB(lfb_prefix_path='./_non_exist_path')

    lfb_prefix_path = osp.normpath(
        osp.join(osp.dirname(__file__), '../data/lfb'))

    with pytest.raises(AssertionError):
        # an int is not an accepted value for dataset_modes
        LFB(lfb_prefix_path=lfb_prefix_path, dataset_modes=100)

    with pytest.raises(ValueError):
        # 'ceph' is not an accepted device
        LFB(lfb_prefix_path=lfb_prefix_path, device='ceph')

    # load on cpu
    # The trailing comma in ``('unittest', )`` matters: ``('unittest')`` is
    # just the string 'unittest', not a one-element tuple.
    lfb_cpu = LFB(
        lfb_prefix_path=lfb_prefix_path,
        max_num_sampled_feat=5,
        window_size=60,
        lfb_channels=16,
        dataset_modes=('unittest', ),
        device='cpu')

    # Expected shape: (max_num_sampled_feat * window_size, lfb_channels).
    lt_feat_cpu = lfb_cpu['video_1,930']
    assert lt_feat_cpu.shape == (5 * 60, 16)
    # presumably the unittest fixture holds a single video — verify fixture
    assert len(lfb_cpu) == 1

    # load on lmdb
    lfb_lmdb = LFB(
        lfb_prefix_path=lfb_prefix_path,
        max_num_sampled_feat=3,
        window_size=30,
        lfb_channels=16,
        dataset_modes=('unittest', ),
        device='lmdb',
        lmdb_map_size=1e6)
    lt_feat_lmdb = lfb_lmdb['video_1,930']
    assert lt_feat_lmdb.shape == (3 * 30, 16)


def test_SubBatchNorm3D():
    """Check feature counts and affine flags of SubBatchNorm3D's inner BNs."""
    num_features = 4
    num_splits = 2
    module = SubBatchNorm3D(num_features, num_splits=num_splits)

    # ``bn`` sees num_features channels; ``split_bn`` is expected to see
    # num_features * num_splits channels.
    assert module.bn.num_features == num_features
    assert module.split_bn.num_features == num_features * num_splits

    # Neither inner batch-norm layer is affine.
    for inner_bn in (module.bn, module.split_bn):
        assert inner_bn.affine is False