"""Integration tests for the prototype source separation pipeline bundles."""
import os
import sys

import pytest
import torch
import torchaudio
from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX, HDEMUCS_HIGH_MUSDB, HDEMUCS_HIGH_MUSDB_PLUS


# ``sdr`` is imported from the ``examples/`` tree two directories above this file; that directory is
# appended to ``sys.path`` here so the import below resolves.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "examples"))
from source_separation.utils.metrics import sdr  # noqa: E402


def _load_at_rate(path, sample_rate):
    """Load an audio file and check that it is sampled at ``sample_rate``.

    Args:
        path: Path of the audio file, passed straight to ``torchaudio.load``.
        sample_rate (int): Sample rate the file is required to have.

    Returns:
        torch.Tensor: The waveform returned by ``torchaudio.load``.
    """
    waveform, actual_rate = torchaudio.load(path)
    assert actual_rate == sample_rate, "The sample rate of audio must match that in the bundle."
    return waveform


@pytest.mark.parametrize(
    "bundle,task,channel,expected_score",
    [
        [CONVTASNET_BASE_LIBRI2MIX, "speech_separation", 1, 8.1373],
        [HDEMUCS_HIGH_MUSDB_PLUS, "music_separation", 2, 8.7480],
        [HDEMUCS_HIGH_MUSDB, "music_separation", 2, 8.0697],
    ],
)
def test_source_separation_models(bundle, task, channel, expected_score, mixture_source, clean_sources):
    """Integration test for the source separation pipeline.

    Given the mixture waveform with dimensions ``(batch, channel, time)``, the pre-trained pipeline generates
    the separated sources Tensor with dimensions ``(batch, num_sources, time)``.
    The test computes the scale-invariant signal-to-distortion ratio (Si-SDR) score in decibels (dB),
    averaged with ``.mean()``, and requires it to be equal to or larger than ``expected_score``.

    Args:
        bundle: Pipeline bundle under test; provides ``get_model()`` and ``sample_rate``.
        task (str): Task name. Not read in this body -- presumably consumed by the
            ``mixture_source`` / ``clean_sources`` fixtures (TODO confirm in conftest).
        channel (int): Number of channels the mixture file must have.
        expected_score (float): Minimum acceptable mean SDR, in dB.
        mixture_source: Fixture providing the path of the mixture audio file.
        clean_sources: Fixture providing the paths of the reference source audio files.
    """
    model = bundle.get_model()

    mixture_waveform = _load_at_rate(mixture_source, bundle.sample_rate)
    # reshape(1, channel, -1) below would silently fold extra channels into the time axis
    # (or split a mono signal in two) if the file's channel count differed from `channel`.
    num_channels = mixture_waveform.shape[0]
    assert num_channels == channel, f"Expected a {channel}-channel mixture, got {num_channels} channel(s)."
    mixture_waveform = mixture_waveform.reshape(1, channel, -1)

    # Concatenate the references along dim 0 and add a leading batch dimension.
    clean_waveforms = torch.cat([_load_at_rate(path, bundle.sample_rate) for path in clean_sources]).unsqueeze(0)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        estimated_sources = model(mixture_waveform)

    # Collapse every axis between batch and time so the estimate lines up with the concatenated
    # references (assumes the model emits sources in the same order as `clean_sources` -- TODO confirm).
    estimated_sources = estimated_sources.reshape(1, -1, clean_waveforms.shape[-1])
    sdr_value = sdr(estimated_sources, clean_waveforms).mean().item()
    assert sdr_value >= expected_score, f"Mean SDR {sdr_value:.4f} dB is below the expected {expected_score} dB."