# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
from mmcv.utils import assert_keys_equal, digit_version

from mmaction.datasets.pipelines import Compose, ImageToTensor

# The torchvision-backed test below only runs when torchvision is importable
# and its version is >= 0.8.0; otherwise it is skipped.
try:
    import torchvision
    torchvision_ok = digit_version(
        torchvision.__version__) >= digit_version('0.8.0')
except ImportError:  # also covers ModuleNotFoundError, a subclass
    torchvision_ok = False


def test_compose():
    """Test ``Compose`` construction, execution, ``None`` passthrough and repr.

    Checks that:
    * a transform that is neither callable nor a dict (a bare string)
      raises ``TypeError``;
    * a ``Collect`` + ``ImageToTensor`` pipeline leaves only the collected
      key plus ``img_metas``, whose data holds exactly the requested meta key;
    * feeding ``None`` through the pipeline yields ``None``;
    * ``repr`` of the pipeline embeds the repr of each wrapped transform.
    """
    with pytest.raises(TypeError):
        # transform must be callable or a dict
        Compose('LoadImage')

    target_keys = ['img', 'img_metas']

    # test Compose given a data pipeline
    img = np.random.randn(256, 256, 3)
    results = dict(img=img, abandoned_key=None, img_name='test_image.png')
    test_pipeline = [
        dict(type='Collect', keys=['img'], meta_keys=['img_name']),
        dict(type='ImageToTensor', keys=['img'])
    ]
    compose = Compose(test_pipeline)
    compose_results = compose(results)
    # 'abandoned_key' is not listed in Collect.keys, so it must be dropped.
    assert assert_keys_equal(compose_results.keys(), target_keys)
    assert assert_keys_equal(compose_results['img_metas'].data.keys(),
                             ['img_name'])

    # test Compose when forward data is None
    results = None
    image_to_tensor = ImageToTensor(keys=[])
    test_pipeline = [image_to_tensor]
    compose = Compose(test_pipeline)
    compose_results = compose(results)
    assert compose_results is None

    # Expected layout: class name, then one indented line per transform.
    assert repr(compose) == compose.__class__.__name__ + \
        f'(\n    {image_to_tensor}\n)'


@pytest.mark.skipif(
    not torchvision_ok, reason='torchvision >= 0.8.0 is required')
def test_compose_support_torchvision():
    """Test a ``Compose`` pipeline that mixes a torchvision transform in.

    The ``'torchvision.'``-prefixed ``type`` names a torchvision transform
    (presumably resolved by ``Compose`` into a wrapper; confirm in the
    pipeline registry). The remaining mmaction steps format, collect and
    tensorize the clip. The output must contain only ``imgs`` and
    ``img_metas``, and the metas must hold exactly ``img_name``.
    """
    target_keys = ['imgs', 'img_metas']

    # test Compose given a data pipeline
    # NOTE: list multiplication repeats the *same* array object 8 times,
    # which is sufficient here since no step is expected to mutate frames
    # in place.
    imgs = [np.random.randn(256, 256, 3)] * 8
    results = dict(
        imgs=imgs,
        abandoned_key=None,
        img_name='test_image.png',
        clip_len=8,
        num_clips=1)
    test_pipeline = [
        dict(type='torchvision.Grayscale', num_output_channels=3),
        dict(type='FormatShape', input_format='NCTHW'),
        dict(type='Collect', keys=['imgs'], meta_keys=['img_name']),
        dict(type='ToTensor', keys=['imgs'])
    ]
    compose = Compose(test_pipeline)
    compose_results = compose(results)
    assert assert_keys_equal(compose_results.keys(), target_keys)
    assert assert_keys_equal(compose_results['img_metas'].data.keys(),
                             ['img_name'])