import os
import sys

import pytest
import torch
import torchaudio
from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX
from torchaudio.prototype.pipelines import HDEMUCS_HIGH_MUSDB, HDEMUCS_HIGH_MUSDB_PLUS


sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "examples"))
from source_separation.utils.metrics import sdr


@pytest.mark.parametrize(
    "bundle,task,channel,expected_score",
    [
        [CONVTASNET_BASE_LIBRI2MIX, "speech_separation", 1, 8.1373],
        [HDEMUCS_HIGH_MUSDB_PLUS, "music_separation", 2, 8.7480],
        [HDEMUCS_HIGH_MUSDB, "music_separation", 2, 8.0697],
    ],
)
def test_source_separation_models(bundle, task, channel, expected_score, mixture_source, clean_sources):
    """Integration test for the source separation pipeline.

    Given the mixture waveform with dimensions `(batch, channel, time)`, the pre-trained pipeline generates
    the separated sources Tensor with dimensions `(batch, num_sources, time)`.
    The test computes the scale-invariant signal-to-distortion ratio (Si-SDR) score in decibel (dB).

    Si-SDR score should be equal to or larger than the expected score.

    Args:
        bundle: Pre-trained pipeline bundle. Provides the model via ``get_model()`` and the
            ``sample_rate`` every loaded audio file must match.
        task: Task name (``"speech_separation"`` or ``"music_separation"``). Not used in this body;
            presumably consumed by the ``mixture_source`` / ``clean_sources`` fixtures -- confirm in conftest.
        channel: Number of channels the mixture is reshaped to before being fed to the model.
        expected_score: Minimum acceptable mean score over all compared rows.
        mixture_source: Audio source passed to ``torchaudio.load`` for the mixture.
        clean_sources: Iterable of audio sources passed to ``torchaudio.load`` for the references.
    """
    model = bundle.get_model()

    mixture_waveform, sample_rate = torchaudio.load(mixture_source)
    assert sample_rate == bundle.sample_rate, "The sample rate of audio must match that in the bundle."

    clean_waveforms = []
    for source in clean_sources:
        clean_waveform, sample_rate = torchaudio.load(source)
        assert sample_rate == bundle.sample_rate, "The sample rate of audio must match that in the bundle."
        clean_waveforms.append(clean_waveform)

    # Shape the mixture as `(batch=1, channel, time)` for the model.
    mixture_waveform = mixture_waveform.reshape(1, channel, -1)
    # Inference only: disable autograd so the forward pass does not record a graph (saves memory).
    with torch.no_grad():
        estimated_sources = model(mixture_waveform)

    # Stack the references along dim 0 and add a leading batch dim -> (1, rows, time).
    clean_waveforms = torch.cat(clean_waveforms).unsqueeze(0)
    # Flatten the estimate into the same (1, rows, time) layout so rows are compared one-to-one.
    # NOTE(review): assumes the model's source/channel order matches the order of `clean_sources`.
    estimated_sources = estimated_sources.reshape(1, -1, clean_waveforms.shape[-1])

    sdr_score = sdr(estimated_sources, clean_waveforms).mean().item()
    assert sdr_score >= expected_score, f"Mean SDR {sdr_score:.4f} dB is below the expected {expected_score:.4f} dB."