test_concat_dataset.py 1.16 KB
Newer Older
unknown's avatar
unknown committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np

from mmaction.datasets import ConcatDataset
from .base import BaseTestDataset


class TestConcatDataset(BaseTestDataset):
    """Tests for ``ConcatDataset`` built from dataset config dicts."""

    def test_concat_dataset(self):
        """Concatenate a plain and a repeated rawframe dataset.

        Checks the combined length, then checks that a sample taken from
        the plain part matches the corresponding sample taken from the
        repeated part.
        """
        dataset_cfg = dict(
            type='RawframeDataset',
            ann_file=self.frame_ann_file,
            pipeline=self.frame_pipeline,
            data_prefix=self.data_prefix)
        repeat_dataset_cfg = dict(
            type='RepeatDataset', times=2, dataset=dataset_cfg)

        concat_dataset = ConcatDataset(
            datasets=[dataset_cfg, repeat_dataset_cfg])

        # 6 == n + 2 * n, so the annotation file presumably lists n == 2
        # samples (TODO confirm against ``self.frame_ann_file``).
        assert len(concat_dataset) == 6
        # Index 0 is the first sample of the plain dataset. Index 4 is
        # offset 2 into the repeated dataset, which should wrap back to the
        # first sample, assuming RepeatDataset indexes modulo the original
        # length (TODO confirm).
        result_a = concat_dataset[0]
        result_b = concat_dataset[4]
        assert set(result_a) == set(result_b)
        for key in result_a:
            value_a, value_b = result_a[key], result_b[key]
            if isinstance(value_a, np.ndarray):
                # np.array_equal also checks shape; np.equal(...).all() would
                # broadcast and accept arrays of mismatched shape.
                assert np.array_equal(value_a, value_b), key
            elif isinstance(value_a, list):
                # zip() stops at the shorter list, so compare lengths first.
                assert len(value_a) == len(value_b), key
                assert all(
                    np.array_equal(a, b)
                    for a, b in zip(value_a, value_b)), key
            else:
                assert value_a == value_b, key