source_separation_pipeline_test.py 1.81 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import sys

import torch
import torchaudio
from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX


sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "examples"))
from source_separation.utils.metrics import PIT, sdr


def test_source_separation_models(mixture_source, clean_sources):
    """Integration test for the source separation pipeline.

    Given the mixture waveform with dimensions `(batch, 1, time)`, the pre-trained pipeline generates
    the separated sources Tensor with dimensions `(batch, num_sources, time)`.
    The test computes the scale-invariant signal-to-distortion ratio (Si-SDR) score in decibels (dB) with
    the permutation invariant training (PIT) criterion. PIT computes Si-SDR scores between the estimated
    sources and the target sources for all permutations, then returns the highest value as the final output.
    The final Si-SDR score should be equal to or larger than the expected score.

    Args:
        mixture_source: Audio loadable by ``torchaudio.load`` holding the mixture
            (presumably a file path supplied by a pytest fixture; assumed to be mono).
        clean_sources: Iterable of audio loadable by ``torchaudio.load``, one entry per target
            source (presumably mono files of the same length as the mixture — confirm in conftest).
    """
    BUNDLE = CONVTASNET_BASE_LIBRI2MIX
    EXPECTED_SCORE = 8.1373  # expected Si-SDR score in dB.
    model = BUNDLE.get_model()

    mixture_waveform, sample_rate = torchaudio.load(mixture_source)
    assert sample_rate == BUNDLE.sample_rate, "The sample rate of audio must match that in the bundle."
    clean_waveforms = []
    for source in clean_sources:
        clean_waveform, sample_rate = torchaudio.load(source)
        assert sample_rate == BUNDLE.sample_rate, "The sample rate of audio must match that in the bundle."
        clean_waveforms.append(clean_waveform)

    # The model takes `(batch, 1, time)`. This reshape flattens every channel into a single row,
    # so it is only meaningful for a mono mixture.
    mixture_waveform = mixture_waveform.reshape(1, 1, -1)
    # Pure inference: disable autograd so no computation graph is built and kept alive.
    with torch.no_grad():
        estimated_sources = model(mixture_waveform)

    # Stack the per-source references into `(1, num_sources, time)` to match the estimates.
    clean_waveforms = torch.cat(clean_waveforms).unsqueeze(0)
    # Fail loudly here rather than via broadcasting inside the metric.
    assert estimated_sources.shape == clean_waveforms.shape, (
        f"Estimated sources {tuple(estimated_sources.shape)} and references "
        f"{tuple(clean_waveforms.shape)} must have the same shape."
    )

    _sdr_pit = PIT(utility_func=sdr)
    # The batch holds a single item; reduce to a Python float for a clear comparison and message.
    sdr_value = _sdr_pit(estimated_sources, clean_waveforms).item()
    assert sdr_value >= EXPECTED_SCORE, f"Si-SDR {sdr_value:.4f} dB is below the expected {EXPECTED_SCORE} dB."