test_audio_dataset.py 2.59 KB
Newer Older
unknown's avatar
unknown committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import numpy as np
import pytest
from mmcv.utils import assert_dict_has_keys

from mmaction.datasets import AudioDataset
from .base import BaseTestDataset


class TestAudioDataset(BaseTestDataset):
    """Tests for ``AudioDataset``: annotation parsing, pipeline output and
    evaluation.

    The annotation file, pipeline and data prefix come from the
    ``audio_ann_file``, ``audio_pipeline`` and ``data_prefix`` attributes,
    presumably prepared by ``BaseTestDataset`` -- confirm against ``.base``.
    """

    def _make_audio_dataset(self, **kwargs):
        """Build an ``AudioDataset`` from the shared fixtures.

        Extra keyword arguments (e.g. ``test_mode``) are forwarded unchanged
        to the ``AudioDataset`` constructor.
        """
        return AudioDataset(
            self.audio_ann_file,
            self.audio_pipeline,
            data_prefix=self.data_prefix,
            **kwargs)

    def test_audio_dataset(self):
        """The annotations parse into two identical info dicts whose audio
        path is joined onto ``data_prefix``."""
        dataset = self._make_audio_dataset()
        expected_info = dict(
            audio_path=osp.join(self.data_prefix, 'test.wav'),
            total_frames=100,
            label=127)
        assert dataset.video_infos == [expected_info, expected_info]

    def test_audio_pipeline(self):
        """Indexing the dataset yields every expected key, both outside and
        inside test mode."""
        required_keys = [
            'audio_path', 'label', 'start_index', 'modality', 'audios_shape',
            'length', 'sample_rate', 'total_frames'
        ]
        # Check non-test mode first, then test mode.
        for test_mode in (False, True):
            dataset = self._make_audio_dataset(test_mode=test_mode)
            sample = dataset[0]
            assert assert_dict_has_keys(sample, required_keys)

    def test_audio_evaluate(self):
        """``evaluate`` rejects malformed calls and reports the requested
        metric names for well-formed results."""
        dataset = self._make_audio_dataset()
        num_samples = len(dataset)

        # Each entry: (expected exception, positional args, keyword args)
        # for one invalid call to ``evaluate``.
        bad_calls = [
            # results must be a list
            (TypeError, ('0.5', ), {}),
            # the number of results must equal the dataset length
            (AssertionError, ([0] * 5, ), {}),
            # topk must be an int or a tuple of ints
            (TypeError, ([0] * num_samples, ),
             dict(metric_options=dict(top_k_accuracy=dict(topk=1.)))),
            # unsupported metric
            (KeyError, ([0] * num_samples, ), dict(metrics='iou')),
        ]
        for expected_exc, args, kwargs in bad_calls:
            with pytest.raises(expected_exc):
                dataset.evaluate(*args, **kwargs)

        # Evaluate top_k_accuracy and mean_class_accuracy together and
        # check only the names of the reported metrics.
        scores = [np.array([0.1, 0.5, 0.4])] * 2
        metrics = dataset.evaluate(
            scores, metrics=['top_k_accuracy', 'mean_class_accuracy'])
        assert set(metrics) == {'top1_acc', 'top5_acc', 'mean_class_accuracy'}