test_compose.py 2.38 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
from mmcv.utils import assert_keys_equal, digit_version

from mmaction.datasets.pipelines import Compose, ImageToTensor

# Optional dependency probe: ``test_compose_support_torchvision`` below is
# skipped unless torchvision >= 0.8.0 can be imported.
try:
    import torchvision
except ImportError:  # also covers ModuleNotFoundError, which subclasses it
    torchvision_ok = False
else:
    torchvision_ok = (
        digit_version(torchvision.__version__) >= digit_version('0.8.0'))


def test_compose():
    """Exercise Compose: bad input, config-built pipelines, None and repr."""
    # A plain string is neither a callable nor a config dict.
    with pytest.raises(TypeError):
        # transform must be callable or a dict
        Compose('LoadImage')

    expected_keys = ['img', 'img_metas']

    # Pipeline assembled from config dicts, run on a random HWC image.
    sample = dict(
        img=np.random.randn(256, 256, 3),
        abandoned_key=None,
        img_name='test_image.png')
    pipeline = Compose([
        dict(type='Collect', keys=['img'], meta_keys=['img_name']),
        dict(type='ImageToTensor', keys=['img']),
    ])
    output = pipeline(sample)
    assert assert_keys_equal(output.keys(), expected_keys)
    assert assert_keys_equal(output['img_metas'].data.keys(), ['img_name'])

    # A None input must come back out of the pipeline as None.
    to_tensor = ImageToTensor(keys=[])
    pipeline = Compose([to_tensor])
    assert pipeline(None) is None

    # repr lists each transform on its own indented line.
    expected_repr = f'{type(pipeline).__name__}(\n    {to_tensor}\n)'
    assert repr(pipeline) == expected_repr


@pytest.mark.skipif(
    not torchvision_ok, reason='torchvision >= 0.8.0 is required')
def test_compose_support_torchvision():
    """Check that Compose accepts a ``torchvision.*`` transform from config.

    A torchvision Grayscale transform is chained with FormatShape, Collect
    and ToTensor. The test then verifies the collected output keys and the
    meta keys.
    """
    target_keys = ['imgs', 'img_metas']
    clip_len = 8

    # test Compose given a data pipeline
    # Build independent frames: ``[frame] * clip_len`` would alias a single
    # ndarray clip_len times, so every "frame" would be the same object.
    imgs = [np.random.randn(256, 256, 3) for _ in range(clip_len)]
    # NOTE(review): FormatShape('NCTHW') presumably regroups the frames by
    # num_clips x clip_len, so clip_len is tied to the frame count above.
    results = dict(
        imgs=imgs,
        abandoned_key=None,
        img_name='test_image.png',
        clip_len=clip_len,
        num_clips=1)
    test_pipeline = [
        dict(type='torchvision.Grayscale', num_output_channels=3),
        dict(type='FormatShape', input_format='NCTHW'),
        dict(type='Collect', keys=['imgs'], meta_keys=['img_name']),
        dict(type='ToTensor', keys=['imgs'])
    ]
    compose = Compose(test_pipeline)
    compose_results = compose(results)
    assert assert_keys_equal(compose_results.keys(), target_keys)
    assert assert_keys_equal(compose_results['img_metas'].data.keys(),
                             ['img_name'])