"git@developer.sourcefind.cn:norm/vllm.git" did not exist on "e5464ee484450c2671dd0226516c99c60ce70d9d"
Commit 6cee56ab authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add documents for SourceSeparationBundle (#2559)

Summary:
- Add documentation page for `SourceSeparationBundle` and `CONVTASNET_BASE_LIBRI2MIX`.
- Add citation of Libri2Mix dataset in the bundle documentation.
- url in integration test should use slash instead of `os.path.join` as it will fail on Windows. Change it to f-string.

Pull Request resolved: https://github.com/pytorch/audio/pull/2559

Reviewed By: carolineechen

Differential Revision: D38036116

Pulled By: nateanl

fbshipit-source-id: 736732805191113955badfec3955e2e24e8f4836
parent c18a103b
......@@ -24,3 +24,28 @@ EMFORMER_RNNT_BASE_TEDLIUM3
.. autodata:: EMFORMER_RNNT_BASE_TEDLIUM3
:no-value:
Source Separation
-----------------
SourceSeparationBundle
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: SourceSeparationBundle
.. automethod:: get_model
.. automethod:: sample_rate
CONVTASNET_BASE_LIBRI2MIX
~~~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: CONVTASNET_BASE_LIBRI2MIX
:no-value:
References
----------
.. footbibliography::
import os
import pytest
import torch
import torchaudio
......@@ -62,7 +60,7 @@ def sample_speech(tmp_path, lang):
@pytest.fixture
def mixture_source():
path = torchaudio.utils.download_asset(os.path.join("test-assets", f"{_MIXTURE_FILE}"))
path = torchaudio.utils.download_asset(f"test-assets/{_MIXTURE_FILE}")
return path
......@@ -70,7 +68,7 @@ def mixture_source():
def clean_sources():
paths = []
for file in _CLEAN_FILES:
path = torchaudio.utils.download_asset(os.path.join("test-assets", f"{file}"))
path = torchaudio.utils.download_asset(f"test-assets/{file}")
paths.append(path)
return paths
......
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
from .source_separation_pipeline import CONVTASNET_BASE_LIBRI2MIX
from .source_separation_pipeline import CONVTASNET_BASE_LIBRI2MIX, SourceSeparationBundle
__all__ = [
"CONVTASNET_BASE_LIBRI2MIX",
"EMFORMER_RNNT_BASE_MUSTC",
"EMFORMER_RNNT_BASE_TEDLIUM3",
"SourceSeparationBundle",
]
......@@ -45,12 +45,14 @@ class SourceSeparationBundle:
@property
def sample_rate(self) -> int:
"""Sample rate (in cycles per second) of input waveforms.
"""Sample rate of the audio that the model is trained on.
:type: int
"""
return self._sample_rate
def get_model(self) -> torch.nn.Module:
"""Construct the model and load the pretrained weight."""
model = self._model_factory_func()
path = torchaudio.utils.download_asset(self._model_path)
state_dict = torch.load(path)
......@@ -64,9 +66,12 @@ CONVTASNET_BASE_LIBRI2MIX = SourceSeparationBundle(
_model_factory_func=partial(conv_tasnet_base, num_sources=2),
_sample_rate=8000,
)
CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained ConvTasNet pipeline for source separation.
CONVTASNET_BASE_LIBRI2MIX.__doc__ = """Pre-trained *ConvTasNet* [:footcite:`Luo_2019`] pipeline for source separation.
The underlying model is constructed by :py:func:`torchaudio.prototyoe.models.conv_tasnet_base`
and utilizes weights trained on Libri2Mix using training script ``lightning_train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/source_separation/>`__ with default arguments.
and utilizes weights trained on *Libri2Mix dataset* [:footcite:`cosentino2020librimix`] using training script
``lightning_train.py`` `here <https://github.com/pytorch/audio/tree/release/0.12/examples/source_separation/>`__
with default arguments.
Please refer to :py:class:`SourceSeparationBundle` for usage instructions.
"""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment