test_audio_recognizer.py 914 Bytes
Newer Older
unknown's avatar
unknown committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmaction.models import build_recognizer
from ..base import generate_recognizer_demo_inputs, get_audio_recognizer_cfg


def test_audio_recognizer():
    """Smoke-test an audio recognizer built from a config file.

    The recognizer is first called with inputs and labels, and the result
    is checked to be a dict. It is then called once per sample with
    ``return_loss=False`` inside a ``torch.no_grad()`` context; those
    outputs are not inspected, so that part only verifies the call
    completes without raising.
    """
    cfg = get_audio_recognizer_cfg(
        'resnet/tsn_r18_64x1x1_100e_kinetics400_audio_feature.py')
    # NOTE(review): presumably this keeps the backbone from loading
    # pretrained weights during the test -- confirm against the backbone.
    cfg.model['backbone']['pretrained'] = None
    model = build_recognizer(cfg.model)

    shape = (1, 3, 1, 128, 80)
    inputs = generate_recognizer_demo_inputs(shape, model_type='audio')
    audios, gt_labels = inputs['imgs'], inputs['gt_labels']

    # Call with labels; the returned value must be a dict.
    losses = model(audios, gt_labels)
    assert isinstance(losses, dict)

    # Test forward test
    with torch.no_grad():
        for audio in audios:
            # Iterating drops the first axis; indexing with None puts a
            # leading axis back (assuming tensor-style indexing).
            model(audio[None, :], None, return_loss=False)