Commit 6cee56ab authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add documents for SourceSeparationBundle (#2559)

Summary:
- Add documentation page for `SourceSeparationBundle` and `CONVTASNET_BASE_LIBRI2MIX`.
- Add citation of Libri2Mix dataset in the bundle documentation.
- url in integration test should use slash instead of `os.path.join` as it will fail on Windows. Change it to f-string.

Pull Request resolved: https://github.com/pytorch/audio/pull/2559

Reviewed By: carolineechen

Differential Revision: D38036116

Pulled By: nateanl

fbshipit-source-id: 736732805191113955badfec3955e2e24e8f4836
parent c18a103b
...@@ -24,3 +24,28 @@ EMFORMER_RNNT_BASE_TEDLIUM3 ...@@ -24,3 +24,28 @@ EMFORMER_RNNT_BASE_TEDLIUM3
.. autodata:: EMFORMER_RNNT_BASE_TEDLIUM3 .. autodata:: EMFORMER_RNNT_BASE_TEDLIUM3
:no-value: :no-value:
Source Separation
-----------------
SourceSeparationBundle
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: SourceSeparationBundle
.. automethod:: get_model
.. automethod:: sample_rate
CONVTASNET_BASE_LIBRI2MIX
~~~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: CONVTASNET_BASE_LIBRI2MIX
:no-value:
References
----------
.. footbibliography::
import os
import pytest import pytest
import torch import torch
import torchaudio import torchaudio
...@@ -62,7 +60,7 @@ def sample_speech(tmp_path, lang): ...@@ -62,7 +60,7 @@ def sample_speech(tmp_path, lang):
@pytest.fixture @pytest.fixture
def mixture_source(): def mixture_source():
path = torchaudio.utils.download_asset(os.path.join("test-assets", f"{_MIXTURE_FILE}")) path = torchaudio.utils.download_asset(f"test-assets/{_MIXTURE_FILE}")
return path return path
...@@ -70,7 +68,7 @@ def mixture_source(): ...@@ -70,7 +68,7 @@ def mixture_source():
def clean_sources(): def clean_sources():
paths = [] paths = []
for file in _CLEAN_FILES: for file in _CLEAN_FILES:
path = torchaudio.utils.download_asset(os.path.join("test-assets", f"{file}")) path = torchaudio.utils.download_asset(f"test-assets/{file}")
paths.append(path) paths.append(path)
return paths return paths
......
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3 from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
from .source_separation_pipeline import CONVTASNET_BASE_LIBRI2MIX from .source_separation_pipeline import CONVTASNET_BASE_LIBRI2MIX, SourceSeparationBundle
__all__ = [ __all__ = [
"CONVTASNET_BASE_LIBRI2MIX", "CONVTASNET_BASE_LIBRI2MIX",
"EMFORMER_RNNT_BASE_MUSTC", "EMFORMER_RNNT_BASE_MUSTC",
"EMFORMER_RNNT_BASE_TEDLIUM3", "EMFORMER_RNNT_BASE_TEDLIUM3",
"SourceSeparationBundle",
] ]
...@@ -45,12 +45,14 @@ class SourceSeparationBundle: ...@@ -45,12 +45,14 @@ class SourceSeparationBundle:
@property @property
def sample_rate(self) -> int: def sample_rate(self) -> int:
"""Sample rate (in cycles per second) of input waveforms. """Sample rate of the audio that the model is trained on.
:type: int :type: int
""" """
return self._sample_rate return self._sample_rate
def get_model(self) -> torch.nn.Module: def get_model(self) -> torch.nn.Module:
"""Construct the model and load the pretrained weight."""
model = self._model_factory_func() model = self._model_factory_func()
path = torchaudio.utils.download_asset(self._model_path) path = torchaudio.utils.download_asset(self._model_path)
state_dict = torch.load(path) state_dict = torch.load(path)
...@@ -64,9 +66,12 @@ CONVTASNET_BASE_LIBRI2MIX = SourceSeparationBundle( ...@@ -64,9 +66,12 @@ CONVTASNET_BASE_LIBRI2MIX = SourceSeparationBundle(
_model_factory_func=partial(conv_tasnet_base, num_sources=2), _model_factory_func=partial(conv_tasnet_base, num_sources=2),
_sample_rate=8000, _sample_rate=8000,
) )
CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained ConvTasNet pipeline for source separation. CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained *ConvTasNet* [:footcite:`Luo_2019`] pipeline for source separation.
The underlying model is constructed by :py:func:`torchaudio.prototype.models.conv_tasnet_base` The underlying model is constructed by :py:func:`torchaudio.prototype.models.conv_tasnet_base`
and utilizes weights trained on Libri2Mix using training script ``lightning_train.py`` and utilizes weights trained on *Libri2Mix dataset* [:footcite:`cosentino2020librimix`] using training script
`here <https://github.com/pytorch/audio/tree/main/examples/source_separation/>`__ with default arguments. ``lightning_train.py`` `here <https://github.com/pytorch/audio/tree/release/0.12/examples/source_separation/>`__
with default arguments.
Please refer to :py:class:`SourceSeparationBundle` for usage instructions. Please refer to :py:class:`SourceSeparationBundle` for usage instructions.
""" """
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment