Commit 4c4da32c authored by Sean Kim's avatar Sean Kim Committed by Facebook GitHub Bot
Browse files

New Pipeline edits for HDemucs (#2565)

Summary:
Created a new branch and brought in the commits due to rebasing issues, resolved the conflicts on the new branch, and closed the old branch.

Pull Request resolved: https://github.com/pytorch/audio/pull/2565

Reviewed By: nateanl, mthrok

Differential Revision: D38131189

Pulled By: skim0514

fbshipit-source-id: 96531480cf50562944abb28d70879f21b4609f15
parent 45f512f6
......@@ -45,6 +45,14 @@ CONVTASNET_BASE_LIBRI2MIX
.. autodata:: CONVTASNET_BASE_LIBRI2MIX
:no-value:
HDEMUCS_HIGH_MUSDB_PLUS
~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: HDEMUCS_HIGH_MUSDB_PLUS
:no-value:
References
----------
......
......@@ -43,11 +43,22 @@ _FILES = {
"fr": "20121212-0900-PLENARY-5-fr_20121212-11_37_04_10.flac",
"it": "20170516-0900-PLENARY-16-it_20170516-18_56_31_1.flac",
}
_MIXTURE_FILE = "mixture_3729-6852-0037_8463-287645-0000.wav"
_CLEAN_FILES = [
"s1_3729-6852-0037_8463-287645-0000.wav",
"s2_3729-6852-0037_8463-287645-0000.wav",
]
_MIXTURE_FILES = {
"speech_separation": "mixture_3729-6852-0037_8463-287645-0000.wav",
"music_separation": "al_james_mixture_shorter.wav",
}
_CLEAN_FILES = {
"speech_separation": [
"s1_3729-6852-0037_8463-287645-0000.wav",
"s2_3729-6852-0037_8463-287645-0000.wav",
],
"music_separation": [
"al_james_drums_shorter.wav",
"al_james_bass_shorter.wav",
"al_james_other_shorter.wav",
"al_james_vocals_shorter.wav",
],
}
@pytest.fixture
......@@ -60,15 +71,19 @@ def sample_speech(lang):
@pytest.fixture
def mixture_source(task):
    """Download and return the local path of the mixture audio for ``task``.

    ``task`` is expected to be an indirect parametrization value
    ("speech_separation" or "music_separation").

    Raises:
        NotImplementedError: if ``task`` has no registered mixture asset.
    """
    if task not in _MIXTURE_FILES:
        raise NotImplementedError(f"Unexpected task: {task}")
    path = torchaudio.utils.download_asset(f"test-assets/{_MIXTURE_FILES[task]}")
    return path
@pytest.fixture
def clean_sources(task):
    """Download and return local paths of the ground-truth sources for ``task``.

    The returned list preserves the order declared in ``_CLEAN_FILES`` so it
    lines up with the source order produced by the corresponding pipeline.

    Raises:
        NotImplementedError: if ``task`` has no registered clean-source assets.
    """
    if task not in _CLEAN_FILES:
        raise NotImplementedError(f"Unexpected task: {task}")
    paths = []
    for file in _CLEAN_FILES[task]:
        path = torchaudio.utils.download_asset(f"test-assets/{file}")
        paths.append(path)
    return paths
......
import os
import sys
import pytest
import torch
import torchaudio
from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX
from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX, HDEMUCS_HIGH_MUSDB_PLUS
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "examples"))
from source_separation.utils.metrics import PIT, sdr
from source_separation.utils.metrics import sdr
@pytest.mark.parametrize(
    "bundle,task,channel,expected_score",
    [
        [CONVTASNET_BASE_LIBRI2MIX, "speech_separation", 1, 8.1373],
        [HDEMUCS_HIGH_MUSDB_PLUS, "music_separation", 2, 8.7480],
    ],
)
def test_source_separation_models(bundle, task, channel, expected_score, mixture_source, clean_sources):
    """Integration test for the source separation pipelines.

    Given the mixture waveform with dimensions `(batch, channel, time)`, the pre-trained pipeline generates
    the separated sources Tensor with dimensions `(batch, num_sources, time)`.
    The test computes the scale-invariant signal-to-distortion ratio (Si-SDR) score in decibel (dB).
    The Si-SDR score should be equal to or larger than the expected score.
    """
    model = bundle.get_model()
    mixture_waveform, sample_rate = torchaudio.load(mixture_source)
    assert sample_rate == bundle.sample_rate, "The sample rate of audio must match that in the bundle."
    clean_waveforms = []
    for source in clean_sources:
        clean_waveform, sample_rate = torchaudio.load(source)
        assert sample_rate == bundle.sample_rate, "The sample rate of audio must match that in the bundle."
        clean_waveforms.append(clean_waveform)
    # Add a batch dimension; `channel` is 1 for speech and 2 for music mixtures.
    mixture_waveform = mixture_waveform.reshape(1, channel, -1)
    estimated_sources = model(mixture_waveform)
    # Stack the reference sources into `(1, num_sources * channel, time)`.
    clean_waveforms = torch.cat(clean_waveforms).unsqueeze(0)
    # Flatten per-source channels so estimates align with the stacked references.
    estimated_sources = estimated_sources.reshape(1, -1, clean_waveforms.shape[-1])
    sdr_values = sdr(estimated_sources, clean_waveforms).mean()
    assert sdr_values >= expected_score
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
from .source_separation_pipeline import CONVTASNET_BASE_LIBRI2MIX, HDEMUCS_HIGH_MUSDB_PLUS, SourceSeparationBundle

# Public API of the prototype pipelines package, kept alphabetically sorted.
__all__ = [
    "CONVTASNET_BASE_LIBRI2MIX",
    "EMFORMER_RNNT_BASE_MUSTC",
    "EMFORMER_RNNT_BASE_TEDLIUM3",
    "HDEMUCS_HIGH_MUSDB_PLUS",
    "SourceSeparationBundle",
]
......@@ -5,7 +5,7 @@ from typing import Callable
import torch
import torchaudio
from torchaudio.prototype.models import conv_tasnet_base
from torchaudio.prototype.models import conv_tasnet_base, hdemucs_high
@dataclass
......@@ -75,3 +75,16 @@ CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained *ConvTasNet* [:footcite:`Luo_
Please refer to :py:class:`SourceSeparationBundle` for usage instructions.
"""
HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle(
    _model_path="models/hdemucs_high_trained.pt",
    _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"], sample_rate=44100),
    _sample_rate=44100,
)
HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained *Hybrid Demucs* [:footcite:`defossez2021hybrid`] pipeline for music
source separation. The underlying model is constructed by
:py:func:`torchaudio.prototype.models.hdemucs_high` and utilizes weights trained on MUSDB-HQ [:footcite:`MUSDB18HQ`]
and internal extra training data, all at the same sample rate of 44.1 kHz. The model separates mixture music into
“drums”, “bass”, “vocals”, and “other” sources. Training was performed in the original HDemucs repository
`here <https://github.com/facebookresearch/demucs/>`__.
"""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment