Commit 60868748, authored by Caroline Chen and committed by Facebook GitHub Bot
Browse files

Move Hybrid Demucs pipeline to beta (#2673)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2673

Reviewed By: mthrok

Differential Revision: D39507612

Pulled By: carolineechen

fbshipit-source-id: 3a9ee53f72cabd6e3085c76867017be4a6ed7f53
parent 155dc298
......@@ -334,6 +334,22 @@ CONVTASNET_BASE_LIBRI2MIX
.. autodata:: CONVTASNET_BASE_LIBRI2MIX
:no-value:
HDEMUCS_HIGH_MUSDB_PLUS
~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: HDEMUCS_HIGH_MUSDB_PLUS
:no-value:
HDEMUCS_HIGH_MUSDB
~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: HDEMUCS_HIGH_MUSDB
:no-value:
References
----------
......
......@@ -25,25 +25,6 @@ EMFORMER_RNNT_BASE_TEDLIUM3
.. autodata:: EMFORMER_RNNT_BASE_TEDLIUM3
:no-value:
Source Separation
-----------------
HDEMUCS_HIGH_MUSDB_PLUS
~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: HDEMUCS_HIGH_MUSDB_PLUS
:no-value:
HDEMUCS_HIGH_MUSDB
~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: HDEMUCS_HIGH_MUSDB
:no-value:
References
----------
......
......@@ -54,7 +54,7 @@ from torchaudio.utils import download_asset
import matplotlib.pyplot as plt
try:
from torchaudio.prototype.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from mir_eval import separation
except ModuleNotFoundError:
......
......@@ -4,8 +4,7 @@ import sys
import pytest
import torch
import torchaudio
from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX
from torchaudio.prototype.pipelines import HDEMUCS_HIGH_MUSDB, HDEMUCS_HIGH_MUSDB_PLUS
from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX, HDEMUCS_HIGH_MUSDB, HDEMUCS_HIGH_MUSDB_PLUS
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "examples"))
......
from ._source_separation_pipeline import CONVTASNET_BASE_LIBRI2MIX, SourceSeparationBundle
from ._source_separation_pipeline import (
CONVTASNET_BASE_LIBRI2MIX,
HDEMUCS_HIGH_MUSDB,
HDEMUCS_HIGH_MUSDB_PLUS,
SourceSeparationBundle,
)
from ._tts import (
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
......@@ -71,4 +76,6 @@ __all__ = [
"EMFORMER_RNNT_BASE_LIBRISPEECH",
"SourceSeparationBundle",
"CONVTASNET_BASE_LIBRI2MIX",
"HDEMUCS_HIGH_MUSDB_PLUS",
"HDEMUCS_HIGH_MUSDB",
]
......@@ -5,7 +5,7 @@ from typing import Callable
import torch
import torchaudio
from torchaudio.models import conv_tasnet_base
from torchaudio.models import conv_tasnet_base, hdemucs_high
@dataclass
......@@ -76,3 +76,33 @@ CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained Source Separation pipeline wi
Please refer to :py:class:`SourceSeparationBundle` for usage instructions.
"""
# Hybrid Demucs bundle backed by the "hdemucs_high_trained.pt" weights.
# The model factory is ``hdemucs_high`` (imported from ``torchaudio.models``)
# pre-bound to the four sources listed below; the bundle sample rate is 44100.
HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle(
    _model_path="models/hdemucs_high_trained.pt",
    _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
    _sample_rate=44100,
)
# The cross-reference below must name the module ``hdemucs_high`` is actually
# imported from (``torchaudio.models``), not the former prototype location.
HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained *Hybrid Demucs* [:footcite:`defossez2021hybrid`] pipeline for music
source separation trained on MUSDB-HQ [:footcite:`MUSDB18HQ`] and additional internal training data.
The model is constructed by :py:func:`torchaudio.models.hdemucs_high`.
Training was performed in the original HDemucs repository `here <https://github.com/facebookresearch/demucs/>`__.
Please refer to :py:class:`SourceSeparationBundle` for usage instructions.
"""
# Hybrid Demucs bundle backed by the "hdemucs_high_musdbhq_only.pt" weights.
# The model factory is ``hdemucs_high`` (imported from ``torchaudio.models``)
# pre-bound to the four sources listed below; the bundle sample rate is 44100.
HDEMUCS_HIGH_MUSDB = SourceSeparationBundle(
    _model_path="models/hdemucs_high_musdbhq_only.pt",
    _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
    _sample_rate=44100,
)
# The cross-reference below must name the module ``hdemucs_high`` is actually
# imported from (``torchaudio.models``), not the former prototype location.
HDEMUCS_HIGH_MUSDB.__doc__ = """Pre-trained *Hybrid Demucs* [:footcite:`defossez2021hybrid`] pipeline for music
source separation trained on MUSDB-HQ [:footcite:`MUSDB18HQ`].
The model is constructed by :py:func:`torchaudio.models.hdemucs_high`.
Training was performed in the original HDemucs repository `here <https://github.com/facebookresearch/demucs/>`__.
Please refer to :py:class:`SourceSeparationBundle` for usage instructions.
"""
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
from .source_separation_pipeline import HDEMUCS_HIGH_MUSDB, HDEMUCS_HIGH_MUSDB_PLUS
__all__ = [
"EMFORMER_RNNT_BASE_MUSTC",
"EMFORMER_RNNT_BASE_TEDLIUM3",
"HDEMUCS_HIGH_MUSDB_PLUS",
"HDEMUCS_HIGH_MUSDB",
]
from functools import partial
from torchaudio.models import hdemucs_high
from torchaudio.pipelines import SourceSeparationBundle
# Hybrid Demucs bundle backed by the "hdemucs_high_trained.pt" weights.
# ``hdemucs_high`` is imported from ``torchaudio.models`` in this file, so the
# docstring cross-reference points there; the source names in the prose match
# the ``sources`` list passed to the factory.
HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle(
    _model_path="models/hdemucs_high_trained.pt",
    _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
    _sample_rate=44100,
)
HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained *Hybrid Demucs* [:footcite:`defossez2021hybrid`] pipeline for music
source separation. The underlying model is constructed by
:py:func:`torchaudio.models.hdemucs_high` and utilizes weights trained on MUSDB-HQ [:footcite:`MUSDB18HQ`]
and internal extra training data, all at the same sample rate of 44.1 kHz. The model separates mixture music into
“drums”, “bass”, “vocals”, and “other” sources. Training was performed in the original HDemucs repository
`here <https://github.com/facebookresearch/demucs/>`__.
"""
# Hybrid Demucs bundle backed by the "hdemucs_high_musdbhq_only.pt" weights.
# ``hdemucs_high`` is imported from ``torchaudio.models`` in this file, so the
# docstring cross-reference points there; the source names in the prose match
# the ``sources`` list passed to the factory.
HDEMUCS_HIGH_MUSDB = SourceSeparationBundle(
    _model_path="models/hdemucs_high_musdbhq_only.pt",
    _model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
    _sample_rate=44100,
)
HDEMUCS_HIGH_MUSDB.__doc__ = """Pre-trained *Hybrid Demucs* [:footcite:`defossez2021hybrid`] pipeline for music
source separation. The underlying model is constructed by
:py:func:`torchaudio.models.hdemucs_high` and utilizes weights trained on only
MUSDB-HQ [:footcite:`MUSDB18HQ`] at the same sample rate of 44.1 kHz. The model separates mixture music into
“drums”, “bass”, “vocals”, and “other” sources. Training was performed in the original HDemucs repository
`here <https://github.com/facebookresearch/demucs/>`__.
"""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment