Commit 6ecc11c2 authored by Sean Kim's avatar Sean Kim Committed by Facebook GitHub Bot
Browse files

Add HDEMUCS_HIGH_MUSDB (#2601)

Summary:
Add new model pretrained weights and tests

Pull Request resolved: https://github.com/pytorch/audio/pull/2601

Reviewed By: carolineechen, nateanl

Differential Revision: D38396673

Pulled By: skim0514

fbshipit-source-id: e06f97d28508543bc18e671344386a947bc870c1
parent 946b180a
......@@ -53,6 +53,14 @@ HDEMUCS_HIGH_MUSDB_PLUS
.. autodata:: HDEMUCS_HIGH_MUSDB_PLUS
:no-value:
HDEMUCS_HIGH_MUSDB
~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: HDEMUCS_HIGH_MUSDB
:no-value:
References
----------
......
......@@ -4,7 +4,7 @@ import sys
import pytest
import torch
import torchaudio
from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX, HDEMUCS_HIGH_MUSDB_PLUS
from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX, HDEMUCS_HIGH_MUSDB, HDEMUCS_HIGH_MUSDB_PLUS
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "examples"))
......@@ -16,6 +16,7 @@ from source_separation.utils.metrics import sdr
[
[CONVTASNET_BASE_LIBRI2MIX, "speech_separation", 1, 8.1373],
[HDEMUCS_HIGH_MUSDB_PLUS, "music_separation", 2, 8.7480],
[HDEMUCS_HIGH_MUSDB, "music_separation", 2, 8.0697],
],
)
def test_source_separation_models(bundle, task, channel, expected_score, mixture_source, clean_sources):
......
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
from .source_separation_pipeline import CONVTASNET_BASE_LIBRI2MIX, HDEMUCS_HIGH_MUSDB_PLUS, SourceSeparationBundle
from .source_separation_pipeline import (
CONVTASNET_BASE_LIBRI2MIX,
HDEMUCS_HIGH_MUSDB,
HDEMUCS_HIGH_MUSDB_PLUS,
SourceSeparationBundle,
)
__all__ = [
"CONVTASNET_BASE_LIBRI2MIX",
......@@ -7,4 +12,5 @@ __all__ = [
"EMFORMER_RNNT_BASE_TEDLIUM3",
"SourceSeparationBundle",
"HDEMUCS_HIGH_MUSDB_PLUS",
"HDEMUCS_HIGH_MUSDB",
]
......@@ -83,8 +83,21 @@ HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle(
)
HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained *Hybrid Demucs* [:footcite:`defossez2021hybrid`] pipeline for music
source separation. The underlying model is constructed by
:py:func:`torchaudio.prototyoe.models.hdemucs_high` and utilizes weights trained on MUSDB-HQ [:footcite:`MUSDB18HQ`]
:py:func:`torchaudio.prototype.models.hdemucs_high` and utilizes weights trained on MUSDB-HQ [:footcite:`MUSDB18HQ`]
and internal extra training data, all at the same sample rate of 44.1 kHZ. The model separates mixture music into
“drums”, “base”, “vocals”, and “other” sources. Training was performed in the original HDemucs repository
`here <https://github.com/facebookresearch/demucs/>`__.
"""
HDEMUCS_HIGH_MUSDB = SourceSeparationBundle(
_model_path="models/hdemucs_high_musdbhq_only.pt",
_model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
_sample_rate=44100,
)
HDEMUCS_HIGH_MUSDB.__doc__ = """Pre-trained *Hybrid Demucs* [:footcite:`defossez2021hybrid`] pipeline for music
source separation. The underlying model is constructed by
:py:func:`torchaudio.prototype.models.hdemucs_high` and utilizes weights trained on only
MUSDB-HQ [:footcite:`MUSDB18HQ`] at the same sample rate of 44.1 kHZ. The model separates mixture music into
“drums”, “base”, “vocals”, and “other” sources. Training was performed in the original HDemucs repository
`here <https://github.com/facebookresearch/demucs/>`__.
"""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment