Commit 83362580 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add SourceSeparationBundle to prototype (#2440)

Summary:
- Add `SourceSeparationBundle` class for the source separation pipeline.
- Add `CONVTASNET_BASE_LIBRI2MIX`, which is trained on the Libri2Mix dataset.
- Add an integration test with an example mixture audio and an expected scale-invariant signal-to-distortion ratio (Si-SDR) score. The test computes the Si-SDR score with the permutation-invariant training (PIT) criterion over all permutations of sources and uses the highest value as the final output. The test verifies that the score is equal to or larger than the expected value; a minimal sketch of the metric follows below.
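For reference, a minimal sketch of the per-pair Si-SDR computation the test relies on. The actual metric lives in `examples/source_separation/utils/metrics.py`; the `si_sdr` helper below is an illustrative stand-in, not the repository's API:

    import torch

    def si_sdr(estimate: torch.Tensor, reference: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        # Illustrative sketch for 1-D waveforms; not the repository's implementation.
        # Zero-mean both signals so the metric ignores DC offsets.
        estimate = estimate - estimate.mean()
        reference = reference - reference.mean()
        # Scale the reference by the projection of the estimate onto it (scale invariance).
        scale = torch.dot(estimate, reference) / (torch.dot(reference, reference) + eps)
        target = scale * reference
        residual = estimate - target
        # Ratio of target energy to residual energy, in dB.
        return 10 * torch.log10(target.pow(2).sum() / (residual.pow(2).sum() + eps))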

Pull Request resolved: https://github.com/pytorch/audio/pull/2440

Reviewed By: mthrok

Differential Revision: D37997646

Pulled By: nateanl

fbshipit-source-id: c951bcbbe8b7ed9553cb8793d6dc1ef90d5a29fe
parent 5c6e602c
import os
import pytest
import torch
import torchaudio
@@ -40,6 +42,11 @@ _FILES = {
    "fr": "20121212-0900-PLENARY-5-fr_20121212-11_37_04_10.flac",
    "it": "20170516-0900-PLENARY-16-it_20170516-18_56_31_1.flac",
}
_MIXTURE_FILE = "mixture_3729-6852-0037_8463-287645-0000.wav"
_CLEAN_FILES = [
    "s1_3729-6852-0037_8463-287645-0000.wav",
    "s2_3729-6852-0037_8463-287645-0000.wav",
]
@pytest.fixture
@@ -53,6 +60,21 @@ def sample_speech(tmp_path, lang):
    return path
@pytest.fixture
def mixture_source():
    path = torchaudio.utils.download_asset(os.path.join("test-assets", _MIXTURE_FILE))
    return path


@pytest.fixture
def clean_sources():
    paths = []
    for file in _CLEAN_FILES:
        path = torchaudio.utils.download_asset(os.path.join("test-assets", file))
        paths.append(path)
    return paths
def pytest_addoption(parser):
    parser.addoption(
        "--use-tmp-hub-dir",
......
import os
import sys
import torch
import torchaudio
from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "examples"))
from source_separation.utils.metrics import PIT, sdr
def test_source_separation_models(mixture_source, clean_sources):
    """Integration test for the source separation pipeline.

    Given a mixture waveform with dimensions `(batch, 1, time)`, the pre-trained pipeline generates
    separated sources as a Tensor with dimensions `(batch, num_sources, time)`.
    The test computes the scale-invariant signal-to-distortion ratio (Si-SDR) score in decibels (dB) with
    the permutation-invariant training (PIT) criterion. PIT computes Si-SDR scores between the estimated
    sources and the target sources for all permutations, then returns the highest value as the final
    output. The final Si-SDR score should be equal to or larger than the expected score.
    """
    BUNDLE = CONVTASNET_BASE_LIBRI2MIX
    EXPECTED_SCORE = 8.1373  # expected Si-SDR score.
    model = BUNDLE.get_model()
    mixture_waveform, sample_rate = torchaudio.load(mixture_source)
    assert sample_rate == BUNDLE.sample_rate, "The sample rate of the audio must match that of the bundle."
    clean_waveforms = []
    for source in clean_sources:
        clean_waveform, sample_rate = torchaudio.load(source)
        assert sample_rate == BUNDLE.sample_rate, "The sample rate of the audio must match that of the bundle."
        clean_waveforms.append(clean_waveform)
    mixture_waveform = mixture_waveform.reshape(1, 1, -1)
    estimated_sources = model(mixture_waveform)
    clean_waveforms = torch.cat(clean_waveforms).unsqueeze(0)
    _sdr_pit = PIT(utility_func=sdr)
    sdr_values = _sdr_pit(estimated_sources, clean_waveforms)
    assert sdr_values >= EXPECTED_SCORE
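For intuition, a hedged sketch of what the PIT criterion does. The actual `PIT` class comes from `examples/source_separation/utils/metrics.py`; `pit_best_score` and its `metric` argument are illustrative names, and `metric` could be a pairwise score such as the `si_sdr` sketch above:

    from itertools import permutations

    import torch

    def pit_best_score(estimates: torch.Tensor, references: torch.Tensor, metric) -> torch.Tensor:
        # estimates, references: tensors of shape (num_sources, time).
        best = None
        for perm in permutations(range(references.size(0))):
            # Score this source-to-source assignment and average over sources.
            score = torch.stack(
                [metric(estimates[i], references[j]) for i, j in enumerate(perm)]
            ).mean()
            # Keep the highest-scoring permutation as the final output.
            best = score if best is None else torch.maximum(best, score)
        return best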
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
from .source_separation_pipeline import CONVTASNET_BASE_LIBRI2MIX
__all__ = [
    "CONVTASNET_BASE_LIBRI2MIX",
    "EMFORMER_RNNT_BASE_MUSTC",
    "EMFORMER_RNNT_BASE_TEDLIUM3",
]
from dataclasses import dataclass
from functools import partial
from typing import Callable
import torch
import torchaudio
from torchaudio.prototype.models import conv_tasnet_base
@dataclass
class SourceSeparationBundle:
    """torchaudio.prototype.pipelines.SourceSeparationBundle()

    Dataclass that bundles components for performing source separation.

    Example
        >>> import torchaudio
        >>> from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX
        >>> import torch
        >>>
        >>> # Build the separation model.
        >>> model = CONVTASNET_BASE_LIBRI2MIX.get_model()
        100%|███████████████████████████████|19.1M/19.1M [00:04<00:00, 4.93MB/s]
        >>>
        >>> # Instantiate the test set of the Libri2Mix dataset.
        >>> dataset = torchaudio.datasets.LibriMix("/home/datasets/", subset="test")
        >>>
        >>> # Apply source separation to the mixture audio.
        >>> for i, data in enumerate(dataset):
        >>>     sample_rate, mixture, clean_sources = data
        >>>     # Make sure the shape of the input suits the model requirement.
        >>>     mixture = mixture.reshape(1, 1, -1)
        >>>     estimated_sources = model(mixture)
        >>>     # `si_snr_pit` stands in for a PIT Si-SNR metric (for demonstration).
        >>>     score = si_snr_pit(estimated_sources, clean_sources)
        >>>     print(f"Si-SNR score is : {score}.")
        >>>     break
        Si-SNR score is : 16.24.
    """
    _model_path: str
    _model_factory_func: Callable[[], torch.nn.Module]
    _sample_rate: int

    @property
    def sample_rate(self) -> int:
        """Sample rate (in cycles per second) of input waveforms.

        :type: int
        """
        return self._sample_rate

    def get_model(self) -> torch.nn.Module:
        # Construct the model, fetch the pre-trained weights, and load them.
        model = self._model_factory_func()
        path = torchaudio.utils.download_asset(self._model_path)
        state_dict = torch.load(path)
        model.load_state_dict(state_dict)
        model.eval()
        return model
CONVTASNET_BASE_LIBRI2MIX = SourceSeparationBundle(
    _model_path="models/conv_tasnet_base_libri2mix.pt",
    _model_factory_func=partial(conv_tasnet_base, num_sources=2),
    _sample_rate=8000,
)
CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained ConvTasNet pipeline for source separation.

The underlying model is constructed by :py:func:`torchaudio.prototype.models.conv_tasnet_base`
and utilizes weights trained on Libri2Mix using the training script ``lightning_train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/source_separation/>`__ with default arguments.
Please refer to :py:class:`SourceSeparationBundle` for usage instructions.
"""