Commit 4d535e88 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Move SourceSeparationBundle and pre-trained ConvTasNet pipeline into Beta (#2669)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2669

Reviewed By: carolineechen, mthrok

Differential Revision: D39433560

Pulled By: nateanl

fbshipit-source-id: 5b652b31c00badb37b27a32ac25b422a5bcc74cb
parent 697e15ab
@@ -314,6 +314,26 @@ TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH
.. autodata:: TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH
   :no-value:

Source Separation
-----------------

SourceSeparationBundle
~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SourceSeparationBundle
   :members: sample_rate

   .. automethod:: get_model

CONVTASNET_BASE_LIBRI2MIX
~~~~~~~~~~~~~~~~~~~~~~~~~

.. container:: py attribute

   .. autodata:: CONVTASNET_BASE_LIBRI2MIX
      :no-value:
References
----------
......
@@ -28,23 +28,6 @@ EMFORMER_RNNT_BASE_TEDLIUM3
Source Separation
-----------------

SourceSeparationBundle
~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: SourceSeparationBundle
   :members: sample_rate

   .. automethod:: get_model

CONVTASNET_BASE_LIBRI2MIX
~~~~~~~~~~~~~~~~~~~~~~~~~

.. container:: py attribute

   .. autodata:: CONVTASNET_BASE_LIBRI2MIX
      :no-value:

HDEMUCS_HIGH_MUSDB_PLUS
~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -4,7 +4,8 @@ import sys
import pytest
import torch
import torchaudio
from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX, HDEMUCS_HIGH_MUSDB, HDEMUCS_HIGH_MUSDB_PLUS
from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX
from torchaudio.prototype.pipelines import HDEMUCS_HIGH_MUSDB, HDEMUCS_HIGH_MUSDB_PLUS
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "examples"))
......
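For reference, the import-path change exercised by the test above, as a minimal sketch (only the ConvTasNet bundle moves into beta; the Hybrid Demucs bundles stay in prototype):

# Before this commit (prototype namespace):
#   from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX
# After this commit (beta namespace):
from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX
from torchaudio.prototype.pipelines import HDEMUCS_HIGH_MUSDB, HDEMUCS_HIGH_MUSDB_PLUS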
from ._hdemucs import HDemucs, hdemucs_high, hdemucs_low, hdemucs_medium
from .conformer import Conformer
from .conv_tasnet import ConvTasNet
from .conv_tasnet import conv_tasnet_base, ConvTasNet
from .deepspeech import DeepSpeech
from .emformer import Emformer
from .rnnt import emformer_rnnt_base, emformer_rnnt_model, RNNT
@@ -29,6 +29,7 @@ __all__ = [
"Wav2Letter",
"WaveRNN",
"ConvTasNet",
"conv_tasnet_base",
"DeepSpeech",
"Wav2Vec2Model",
"HuBERTPretrainModel",
......
@@ -299,3 +299,31 @@ class ConvTasNet(torch.nn.Module):
if num_pads > 0:
output = output[..., :-num_pads] # B, S, L
return output
def conv_tasnet_base(num_sources: int = 2) -> ConvTasNet:
r"""Builds the non-causal version of ConvTasNet in
*Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
[:footcite:`Luo_2019`].
The parameter settings follow the ones with the highest Si-SNR metric score in the paper,
except that the mask activation function is changed from "sigmoid" to "relu" for performance improvement.
Args:
num_sources (int, optional): Number of sources in the output.
(Default: 2)
Returns:
ConvTasNet:
ConvTasNet model.
"""
return ConvTasNet(
num_sources=num_sources,
enc_kernel_size=16,
enc_num_feats=512,
msk_kernel_size=3,
msk_num_feats=128,
msk_num_hidden_feats=512,
msk_num_layers=8,
msk_num_stacks=3,
msk_activate="relu",
)
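A quick smoke test of the builder above; this is a sketch assuming the standard ConvTasNet forward contract of (batch, channel=1, frames) in and (batch, num_sources, frames) out, with an illustrative dummy tensor:

import torch
from torchaudio.models import conv_tasnet_base

# Build the default two-source separator defined above.
model = conv_tasnet_base(num_sources=2)
model.eval()

# One second of dummy mixture audio at 8 kHz (illustrative values only).
mixture = torch.rand(1, 1, 8000)
with torch.inference_mode():
    sources = model(mixture)
print(sources.shape)  # expected: torch.Size([1, 2, 8000])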
from ._source_separation_pipeline import CONVTASNET_BASE_LIBRI2MIX, SourceSeparationBundle
from ._tts import (
TACOTRON2_GRIFFINLIM_CHAR_LJSPEECH,
TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH,
@@ -68,4 +69,6 @@ __all__ = [
"TACOTRON2_WAVERNN_PHONE_LJSPEECH",
"RNNTBundle",
"EMFORMER_RNNT_BASE_LIBRISPEECH",
"SourceSeparationBundle",
"CONVTASNET_BASE_LIBRI2MIX",
]
from dataclasses import dataclass
from functools import partial
from typing import Callable
import torch
import torchaudio
from torchaudio.models import conv_tasnet_base
@dataclass
class SourceSeparationBundle:
"""torchaudio.pipelines.SourceSeparationBundle()
Dataclass that bundles components for performing source separation.
Example
>>> import torchaudio
>>> from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX
>>> import torch
>>>
>>> # Build the separation model.
>>> model = CONVTASNET_BASE_LIBRI2MIX.get_model()
>>> 100%|███████████████████████████████|19.1M/19.1M [00:04<00:00, 4.93MB/s]
>>>
>>> # Instantiate the test set of Libri2Mix dataset.
>>> dataset = torchaudio.datasets.LibriMix("/home/datasets/", subset="test")
>>>
>>> # Apply source separation on mixture audio.
>>> for i, data in enumerate(dataset):
>>> sample_rate, mixture, clean_sources = data
>>> # Make sure the shape of input suits the model requirement.
>>> mixture = mixture.reshape(1, 1, -1)
>>> estimated_sources = model(mixture)
>>> score = si_snr_pit(estimated_sources, clean_sources) # for demonstration
>>> print(f"Si-SNR score is : {score}.)
>>> break
>>> Si-SNR score is : 16.24.
>>>
"""
_model_path: str
_model_factory_func: Callable[[], torch.nn.Module]
_sample_rate: int
@property
def sample_rate(self) -> int:
"""Sample rate of the audio that the model is trained on.
:type: int
"""
return self._sample_rate
def get_model(self) -> torch.nn.Module:
"""Construct the model and load the pretrained weight."""
model = self._model_factory_func()
path = torchaudio.utils.download_asset(self._model_path)
state_dict = torch.load(path)
model.load_state_dict(state_dict)
model.eval()
return model
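The class docstring above calls ``si_snr_pit`` "for demonstration" without defining it. A minimal sketch of scale-invariant SNR with permutation-invariant (PIT) scoring follows; it assumes (batch, sources, time) tensors and is not the exact metric code used to report numbers:

import itertools

import torch

def si_snr(estimate, reference, eps=1e-8):
    # Scale-invariant SNR in dB over the last (time) dimension; zero-mean both signals first.
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    reference = reference - reference.mean(dim=-1, keepdim=True)
    scale = (estimate * reference).sum(-1, keepdim=True) / (reference.pow(2).sum(-1, keepdim=True) + eps)
    projection = scale * reference
    noise = estimate - projection
    return 10 * torch.log10(projection.pow(2).sum(-1) / (noise.pow(2).sum(-1) + eps) + eps)

def si_snr_pit(estimated_sources, clean_sources):
    # Best mean Si-SNR over all permutations of the estimated sources.
    num_sources = clean_sources.shape[-2]
    best = None
    for perm in itertools.permutations(range(num_sources)):
        score = si_snr(estimated_sources[:, list(perm), :], clean_sources).mean()
        best = score if best is None else torch.maximum(best, score)
    return best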
CONVTASNET_BASE_LIBRI2MIX = SourceSeparationBundle(
_model_path="models/conv_tasnet_base_libri2mix.pt",
_model_factory_func=partial(conv_tasnet_base, num_sources=2),
_sample_rate=8000,
)
CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained Source Separation pipeline with *ConvTasNet* [:footcite:`Luo_2019`] trained on
*Libri2Mix dataset* [:footcite:`cosentino2020librimix`].
The source separation model is constructed by :py:func:`torchaudio.models.conv_tasnet_base`
and is trained using the training script ``lightning_train.py``
`here <https://github.com/pytorch/audio/tree/release/0.12/examples/source_separation/>`__
with default arguments.
Please refer to :py:class:`SourceSeparationBundle` for usage instructions.
"""
from .conv_emformer import ConvEmformer
from .conv_tasnet import conv_tasnet_base
from .rnnt import conformer_rnnt_base, conformer_rnnt_model
__all__ = [
"conformer_rnnt_base",
"conformer_rnnt_model",
"conv_tasnet_base",
"ConvEmformer",
]
from torchaudio.models import ConvTasNet
def conv_tasnet_base(num_sources: int = 2) -> ConvTasNet:
r"""Builds the non-causal version of ConvTasNet in
*Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
[:footcite:`Luo_2019`].
The parameter settings follow the ones with the highest Si-SNR metric score in the paper,
except that the mask activation function is changed from "sigmoid" to "relu" for performance improvement.
Args:
num_sources (int, optional): Number of sources in the output.
(Default: 2)
Returns:
ConvTasNet:
ConvTasNet model.
"""
return ConvTasNet(
num_sources=num_sources,
enc_kernel_size=16,
enc_num_feats=512,
msk_kernel_size=3,
msk_num_feats=128,
msk_num_hidden_feats=512,
msk_num_layers=8,
msk_num_stacks=3,
msk_activate="relu",
)
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
from .source_separation_pipeline import (
CONVTASNET_BASE_LIBRI2MIX,
HDEMUCS_HIGH_MUSDB,
HDEMUCS_HIGH_MUSDB_PLUS,
SourceSeparationBundle,
)
from .source_separation_pipeline import HDEMUCS_HIGH_MUSDB, HDEMUCS_HIGH_MUSDB_PLUS
__all__ = [
"CONVTASNET_BASE_LIBRI2MIX",
"EMFORMER_RNNT_BASE_MUSTC",
"EMFORMER_RNNT_BASE_TEDLIUM3",
"SourceSeparationBundle",
"HDEMUCS_HIGH_MUSDB_PLUS",
"HDEMUCS_HIGH_MUSDB",
]
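Note that ``SourceSeparationBundle`` remains in the prototype ``__all__`` and is re-imported from ``torchaudio.pipelines`` in the file below, so, assuming the re-export is kept, the old prototype name should still resolve to the same class:

from torchaudio.pipelines import SourceSeparationBundle as beta_cls
from torchaudio.prototype.pipelines import SourceSeparationBundle as proto_cls

# Back-compat check (an assumption based on the import and __all__ lines in this diff).
assert beta_cls is proto_cls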
from dataclasses import dataclass
from functools import partial
from typing import Callable
import torch
import torchaudio
from torchaudio.models import hdemucs_high
from torchaudio.prototype.models import conv_tasnet_base
@dataclass
class SourceSeparationBundle:
"""torchaudio.prototype.pipelines.SourceSeparationBundle()
Dataclass that bundles components for performing source separation.
Example
>>> import torchaudio
>>> from torchaudio.prototype.pipelines import CONVTASNET_BASE_LIBRI2MIX
>>> import torch
>>>
>>> # Build the separation model.
>>> model = CONVTASNET_BASE_LIBRI2MIX.get_model()
>>> 100%|███████████████████████████████|19.1M/19.1M [00:04<00:00, 4.93MB/s]
>>>
>>> # Instantiate the test set of Libri2Mix dataset.
>>> dataset = torchaudio.datasets.LibriMix("/home/datasets/", subset="test")
>>>
>>> # Apply source separation on mixture audio.
>>> for i, data in enumerate(dataset):
>>> sample_rate, mixture, clean_sources = data
>>> # Make sure the shape of input suits the model requirement.
>>> mixture = mixture.reshape(1, 1, -1)
>>> estimated_sources = model(mixture)
>>> score = si_snr_pit(estimated_sources, clean_sources) # for demonstration
>>> print(f"Si-SNR score is : {score}.)
>>> break
>>> Si-SNR score is : 16.24.
>>>
"""
_model_path: str
_model_factory_func: Callable[[], torch.nn.Module]
_sample_rate: int
@property
def sample_rate(self) -> int:
"""Sample rate of the audio that the model is trained on.
:type: int
"""
return self._sample_rate
def get_model(self) -> torch.nn.Module:
"""Construct the model and load the pretrained weight."""
model = self._model_factory_func()
path = torchaudio.utils.download_asset(self._model_path)
state_dict = torch.load(path)
model.load_state_dict(state_dict)
model.eval()
return model
CONVTASNET_BASE_LIBRI2MIX = SourceSeparationBundle(
_model_path="models/conv_tasnet_base_libri2mix.pt",
_model_factory_func=partial(conv_tasnet_base, num_sources=2),
_sample_rate=8000,
)
CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained *ConvTasNet* [:footcite:`Luo_2019`] pipeline for source separation.
The underlying model is constructed by :py:func:`torchaudio.prototype.models.conv_tasnet_base`
and utilizes weights trained on *Libri2Mix dataset* [:footcite:`cosentino2020librimix`] using training script
``lightning_train.py`` `here <https://github.com/pytorch/audio/tree/release/0.12/examples/source_separation/>`__
with default arguments.
Please refer to :py:class:`SourceSeparationBundle` for usage instructions.
"""
from torchaudio.pipelines import SourceSeparationBundle
HDEMUCS_HIGH_MUSDB_PLUS = SourceSeparationBundle(
_model_path="models/hdemucs_high_trained.pt",
@@ -90,6 +18,7 @@ HDEMUCS_HIGH_MUSDB_PLUS.__doc__ = """Pre-trained *Hybrid Demucs* [:footcite:`def
`here <https://github.com/facebookresearch/demucs/>`__.
"""
HDEMUCS_HIGH_MUSDB = SourceSeparationBundle(
_model_path="models/hdemucs_high_musdbhq_only.pt",
_model_factory_func=partial(hdemucs_high, sources=["drums", "bass", "other", "vocals"]),
......
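For contrast, a hedged sketch of running one of the Hybrid Demucs bundles that remain in prototype; the 44.1 kHz rate, stereo input, and four-source output layout are assumptions based on the ``sources`` list above:

import torch
from torchaudio.prototype.pipelines import HDEMUCS_HIGH_MUSDB_PLUS

model = HDEMUCS_HIGH_MUSDB_PLUS.get_model()
mix = torch.rand(1, 2, 44100)  # one second of dummy stereo audio (rate assumed 44.1 kHz)
with torch.inference_mode():
    out = model(mix)  # assumed (batch, 4, channels, frames): drums, bass, other, vocals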