Commit 30c7077b authored by moto, committed by Facebook GitHub Bot

Adopt `:autosummary:` in `torchaudio.models` module doc (#2690)

Summary:
* Introduce the mini-index at `torchaudio.models` page.

https://output.circle-artifacts.com/output/job/25e59810-3866-4ece-b1b7-8a10c7a2286d/artifacts/0/docs/models.html

<img width="1042" alt="Screen Shot 2022-09-20 at 1 20 50 PM" src="https://user-images.githubusercontent.com/855818/191166816-83314ad1-8b67-475b-aa10-d4cc59126295.png">

<img width="1048" alt="Screen Shot 2022-09-20 at 1 20 58 PM" src="https://user-images.githubusercontent.com/855818/191166829-1ceb65e0-9506-4328-9a2f-8b75b4e54404.png">

Pull Request resolved: https://github.com/pytorch/audio/pull/2690

Reviewed By: carolineechen

Differential Revision: D39654948

Pulled By: mthrok

fbshipit-source-id: 703d1526617596f647c85a7148f41ca55fffdbc8
parent 4a65b050
..
   autogenerated from source/_templates/autosummary/model_class.rst

{%- set methods=["forward"] %}
{%- if name in ["Wav2Vec2Model"] %}
{{ methods.extend(["extract_features"]) }}
{%- elif name in ["Emformer", "RNNTBeamSearch", "WaveRNN", "Tacotron2", ] %}
{{ methods.extend(["infer"]) }}
{%- elif name == "RNNT" %}
{{ methods.extend(["transcribe_streaming", "transcribe", "predict", "join"]) }}
{%- endif %}

{{ name | underline }}

.. autoclass:: {{ fullname }}

{% for item in methods %}
{{ item | underline("-") }}

.. container:: py attribute

   .. automethod:: {{ [fullname, item] | join('.') }}

{%- endfor %}
{%- if name == "RNNTBeamSearch" %}

Support Structures
==================

Hypothesis
----------

.. container:: py attribute

   .. autodata:: torchaudio.models.Hypothesis
      :no-value:
{%- endif %}
models.rst (old):

.. role:: hidden
    :class: hidden-section

torchaudio.models
=================

.. currentmodule:: torchaudio.models

The models subpackage contains definitions of models for addressing common audio tasks.

Conformer
~~~~~~~~~
.. autoclass:: Conformer

   .. automethod:: forward

ConvTasNet
~~~~~~~~~~
Model
-----
ConvTasNet
^^^^^^^^^^
.. autoclass:: ConvTasNet

   .. automethod:: forward

Factory Functions
-----------------
conv_tasnet_base
^^^^^^^^^^^^^^^^
.. autofunction:: conv_tasnet_base

DeepSpeech
~~~~~~~~~~
.. autoclass:: DeepSpeech

   .. automethod:: forward

Emformer
~~~~~~~~
.. autoclass:: Emformer

   .. automethod:: forward
   .. automethod:: infer

Hybrid Demucs
~~~~~~~~~~~~~
Model
-----
HDemucs
^^^^^^^
.. autoclass:: HDemucs

   .. automethod:: forward

Factory Functions
-----------------
hdemucs_low
^^^^^^^^^^^
.. autofunction:: hdemucs_low

hdemucs_medium
^^^^^^^^^^^^^^
.. autofunction:: hdemucs_medium

hdemucs_high
^^^^^^^^^^^^
.. autofunction:: hdemucs_high

RNN-T
~~~~~
Model
-----
RNNT
^^^^
.. autoclass:: RNNT

   .. automethod:: forward
   .. automethod:: transcribe_streaming
   .. automethod:: transcribe
   .. automethod:: predict
   .. automethod:: join

Factory Functions
-----------------
emformer_rnnt_model
^^^^^^^^^^^^^^^^^^^
.. autofunction:: emformer_rnnt_model

emformer_rnnt_base
^^^^^^^^^^^^^^^^^^
.. autofunction:: emformer_rnnt_base

Decoder
-------
RNNTBeamSearch
^^^^^^^^^^^^^^
.. autoclass:: RNNTBeamSearch

   .. automethod:: forward
   .. automethod:: infer

Hypothesis
^^^^^^^^^^
.. container:: py attribute

   .. autodata:: Hypothesis
      :no-value:

Tacotron2
~~~~~~~~~
.. autoclass:: Tacotron2

   .. automethod:: forward
   .. automethod:: infer

Wav2Letter
~~~~~~~~~~
.. autoclass:: Wav2Letter

   .. automethod:: forward

Wav2Vec2.0 / HuBERT
~~~~~~~~~~~~~~~~~~~
Model
-----
Wav2Vec2Model
^^^^^^^^^^^^^
.. autoclass:: Wav2Vec2Model

   .. automethod:: extract_features
   .. automethod:: forward

HuBERTPretrainModel
^^^^^^^^^^^^^^^^^^^
.. autoclass:: HuBERTPretrainModel

   .. automethod:: forward

Factory Functions
-----------------
wav2vec2_model
^^^^^^^^^^^^^^
.. autofunction:: wav2vec2_model

wav2vec2_base
^^^^^^^^^^^^^
.. autofunction:: wav2vec2_base

wav2vec2_large
^^^^^^^^^^^^^^
.. autofunction:: wav2vec2_large

wav2vec2_large_lv60k
^^^^^^^^^^^^^^^^^^^^
.. autofunction:: wav2vec2_large_lv60k

hubert_base
^^^^^^^^^^^
.. autofunction:: hubert_base

hubert_large
^^^^^^^^^^^^
.. autofunction:: hubert_large

hubert_xlarge
^^^^^^^^^^^^^
.. autofunction:: hubert_xlarge

hubert_pretrain_model
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: hubert_pretrain_model

hubert_pretrain_base
^^^^^^^^^^^^^^^^^^^^
.. autofunction:: hubert_pretrain_base

hubert_pretrain_large
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: hubert_pretrain_large

hubert_pretrain_xlarge
^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: hubert_pretrain_xlarge

Utility Functions
-----------------
.. currentmodule:: torchaudio.models.wav2vec2.utils

import_huggingface_model
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: import_huggingface_model

import_fairseq_model
^^^^^^^^^^^^^^^^^^^^
.. autofunction:: import_fairseq_model

.. currentmodule:: torchaudio.models

WaveRNN
~~~~~~~
.. autoclass:: WaveRNN

   .. automethod:: forward
   .. automethod:: infer

models.rst (new):

.. py:module:: torchaudio.models

torchaudio.models
=================

.. currentmodule:: torchaudio.models

The ``torchaudio.models`` subpackage contains definitions of models for addressing common audio tasks.

For pre-trained models, please refer to :mod:`torchaudio.pipelines` module.

Model Definitions
-----------------

Model definitions are responsible for constructing computation graphs and executing them.

Some models have complex structure and variations.
For such models, `Factory Functions`_ are provided.

.. autosummary::
   :toctree: generated
   :nosignatures:
   :template: autosummary/model_class.rst

   Conformer
   ConvTasNet
   DeepSpeech
   Emformer
   HDemucs
   HuBERTPretrainModel
   RNNT
   RNNTBeamSearch
   Tacotron2
   Wav2Letter
   Wav2Vec2Model
   WaveRNN

Factory Functions
-----------------

.. autosummary::
   :toctree: generated
   :nosignatures:

   conv_tasnet_base
   emformer_rnnt_model
   emformer_rnnt_base
   wav2vec2_model
   wav2vec2_base
   wav2vec2_large
   wav2vec2_large_lv60k
   hubert_base
   hubert_large
   hubert_xlarge
   hubert_pretrain_model
   hubert_pretrain_base
   hubert_pretrain_large
   hubert_pretrain_xlarge
   hdemucs_low
   hdemucs_medium
   hdemucs_high

Utility Functions
-----------------

.. autosummary::
   :toctree: generated
   :nosignatures:

   ~wav2vec2.utils.import_fairseq_model
   ~wav2vec2.utils.import_huggingface_model
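As an aside (not part of the diff): the split between model definitions and factory functions in the new page can be illustrated with a minimal sketch using `conv_tasnet_base`, whose signature appears in the Conv-TasNet hunk further down. The dummy input shape and mono-channel layout are assumptions of this sketch.

```python
import torch
import torchaudio

# Factory function listed under "Factory Functions" in the new models.rst;
# conv_tasnet_base(num_sources=2) matches the signature shown in the
# Conv-TasNet hunk below.
model = torchaudio.models.conv_tasnet_base(num_sources=2)
model.eval()

# Dummy mono mixture of shape (batch, channel, time); the single-channel
# requirement is an assumption for this sketch.
mixture = torch.randn(1, 1, 16000)
with torch.inference_mode():
    separated = model(mixture)  # expected: (batch, num_sources, time)
```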
@article{wavernn,
author = {Nal Kalchbrenner and
Erich Elsen and
Karen Simonyan and
Seb Noury and
Norman Casagrande and
Edward Lockhart and
Florian Stimberg and
A{\"{a}}ron van den Oord and
Sander Dieleman and
Koray Kavukcuoglu},
title = {Efficient Neural Audio Synthesis},
journal = {CoRR},
volume = {abs/1802.08435},
year = {2018},
url = {http://arxiv.org/abs/1802.08435},
eprinttype = {arXiv},
eprint = {1802.08435},
timestamp = {Mon, 13 Aug 2018 16:47:01 +0200},
biburl = {https://dblp.org/rec/journals/corr/abs-1802-08435.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
@misc{RESAMPLE,
author = {Julius O. Smith},
title = {Digital Audio Resampling Home Page "Theory of Ideal Bandlimited Interpolation" section},
...
@@ -21,10 +21,12 @@ perform music separation
 # 3. Collect output chunks and combine according to the way they have been
 #    overlapped.
 #
-# The `Hybrid Demucs <https://arxiv.org/pdf/2111.03600.pdf>`__ model is a developed version of the
+# The Hybrid Demucs [`Défossez, 2021 <https://arxiv.org/abs/2111.03600>`__]
+# model is a developed version of the
 # `Demucs <https://github.com/facebookresearch/demucs>`__ model, a
 # waveform based model which separates music into its
-# respective sources, such as vocals, bass, and drums. Hybrid Demucs effectively uses spectrogram to learn
+# respective sources, such as vocals, bass, and drums.
+# Hybrid Demucs effectively uses spectrogram to learn
 # through the frequency domain and also moves to time convolutions.
 #
@@ -81,7 +83,7 @@ except ModuleNotFoundError:
 #
 # Pre-trained model weights and related pipeline components are bundled as
 # :py:func:`torchaudio.pipelines.HDEMUCS_HIGH_MUSDB_PLUS`. This is a
-# HDemucs model trained on
+# :py:class:`torchaudio.models.HDemucs` model trained on
 # `MUSDB18-HQ <https://zenodo.org/record/3338373>`__ and additional
 # internal extra training data.
 # This specific model is suited for higher sample rates, around 44.1 kHZ
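The tutorial hunk above points readers to `torchaudio.pipelines.HDEMUCS_HIGH_MUSDB_PLUS`. A minimal sketch of pulling the bundled model, assuming the usual source-separation bundle interface (`get_model()` and `sample_rate`); the input shape and stereo layout are illustrative assumptions.

```python
import torch
import torchaudio

bundle = torchaudio.pipelines.HDEMUCS_HIGH_MUSDB_PLUS
model = bundle.get_model()        # an HDemucs instance, per the tutorial text
sample_rate = bundle.sample_rate  # around 44.1 kHz

# Five seconds of dummy stereo audio: (batch, channel, time);
# stereo input is an assumption of this sketch.
mixture = torch.randn(1, 2, sample_rate * 5)
with torch.inference_mode():
    sources = model(mixture)      # assumed layout: (batch, num_sources, channel, time)
```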
...
@@ -299,8 +299,14 @@ class _HDecLayer(torch.nn.Module):
 class HDemucs(torch.nn.Module):
-    r"""
-    Hybrid Demucs model from *Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`.
+    r"""Hybrid Demucs model from
+    *Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`.
+
+    See Also:
+        * :func:`~torchaudio.models.hdemucs_low`,
+          :func:`~torchaudio.models.hdemucs_medium`,
+          :func:`~torchaudio.models.hdemucs_high`: factory functions.
+        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.

     Args:
         sources (List[str]): list of source names. List can contain the following source
@@ -959,8 +965,7 @@ def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int =
 def hdemucs_low(sources: List[str]) -> HDemucs:
-    r"""Builds low nfft (1024) version of HDemucs model. This version is suitable for lower sample rates, and bundles
-    parameters together to call valid nfft and depth values for a model structured for sample rates around 8 kHz.
+    """Builds low nfft (1024) version of :class:`HDemucs`, suitable for sample rates around 8 kHz.

     Args:
         sources (List[str]): See :py:func:`HDemucs`.
@@ -974,8 +979,7 @@ def hdemucs_low(sources: List[str]) -> HDemucs:
 def hdemucs_medium(sources: List[str]) -> HDemucs:
-    r"""Builds medium nfft (2048) version of HDemucs model. This version is suitable for medium sample rates, and bundles
-    parameters together to call valid nfft and depth values for a model structured for sample rates around 16-32 kHz.
+    r"""Builds medium nfft (2048) version of :class:`HDemucs`, suitable for sample rates of 16-32 kHz.

     .. note::
@@ -994,9 +998,7 @@ def hdemucs_medium(sources: List[str]) -> HDemucs:
 def hdemucs_high(sources: List[str]) -> HDemucs:
-    r"""Builds high nfft (4096) version of HDemucs model. This version is suitable for high/standard music sample rates,
-    and bundles parameters together to call valid nfft and depth values for a model structured for sample rates around
-    44.1-48 kHz.
+    r"""Builds high nfft (4096) version of :class:`HDemucs`, suitable for sample rates of 44.1-48 kHz.

     Args:
         sources (List[str]): See :py:func:`HDemucs`.
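The new `See Also` block names the three HDemucs factory functions. A short sketch of building a model through one of them; only the `sources` argument is required per the docstrings above, and the particular source names below are an assumed example.

```python
from torchaudio.models import hdemucs_high

# hdemucs_high targets 44.1-48 kHz material, hdemucs_medium 16-32 kHz, and
# hdemucs_low around 8 kHz, per the docstrings above. The source names here
# are an assumed, illustrative ordering.
model = hdemucs_high(sources=["drums", "bass", "other", "vocals"])
```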
...
@@ -213,7 +213,7 @@ class ConformerLayer(torch.nn.Module):
 class Conformer(torch.nn.Module):
-    r"""Implements the Conformer architecture introduced in
+    r"""Conformer architecture introduced in
     *Conformer: Convolution-augmented Transformer for Speech Recognition*
     :cite:`gulati2020conformer`.
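The docstring change above is cosmetic, but for orientation here is a hedged sketch of instantiating `Conformer`; the constructor arguments and input layout are assumptions based on the class's documented interface, not part of this diff.

```python
import torch
from torchaudio.models import Conformer

conformer = Conformer(
    input_dim=80,                   # feature dimension (illustrative)
    num_heads=4,
    ffn_dim=128,
    num_layers=4,
    depthwise_conv_kernel_size=31,
)
lengths = torch.randint(1, 400, (10,))           # valid frames per sequence
frames = torch.rand(10, int(lengths.max()), 80)  # (batch, num_frames, input_dim)
output, output_lengths = conformer(frames, lengths)
```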
...
@@ -160,10 +160,17 @@ class MaskGenerator(torch.nn.Module):
 class ConvTasNet(torch.nn.Module):
-    """Conv-TasNet: a fully-convolutional time-domain audio separation network
+    """Conv-TasNet architecture introduced in
     *Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
     :cite:`Luo_2019`.
+
+    Note:
+        This implementation corresponds to the "non-causal" setting in the paper.
+
+    See Also:
+        * :func:`~torchaudio.models.conv_tasnet_base`: A factory function.
+        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.

     Args:
         num_sources (int, optional): The number of sources to split.
         enc_kernel_size (int, optional): The convolution kernel size of the encoder/decoder, <L>.
@@ -174,9 +181,6 @@ class ConvTasNet(torch.nn.Module):
         msk_num_layers (int, optional): The number of layers in one conv block of the mask generator, <X>.
         msk_num_stacks (int, optional): The number of conv blocks of the mask generator, <R>.
         msk_activate (str, optional): The activation function of the mask output (Default: ``sigmoid``).
-
-    Note:
-        This implementation corresponds to the "non-causal" setting in the paper.
     """

     def __init__(
@@ -302,9 +306,7 @@ class ConvTasNet(torch.nn.Module):
 def conv_tasnet_base(num_sources: int = 2) -> ConvTasNet:
-    r"""Builds the non-causal version of ConvTasNet in
-    *Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
-    :cite:`Luo_2019`.
+    r"""Builds non-causal version of :class:`~torchaudio.models.ConvTasNet`.

     The parameter settings follow the ones with the highest Si-SNR metric score in the paper,
     except the mask activation function is changed from "sigmoid" to "relu" for performance improvement.
...
@@ -26,9 +26,8 @@ class FullyConnected(torch.nn.Module):
 class DeepSpeech(torch.nn.Module):
-    """
-    DeepSpeech model architecture from *Deep Speech: Scaling up end-to-end speech recognition*
-    :cite:`hannun2014deep`.
+    """DeepSpeech architecture introduced in
+    *Deep Speech: Scaling up end-to-end speech recognition* :cite:`hannun2014deep`.

     Args:
         n_feature: Number of input features
...
@@ -804,10 +804,15 @@ class _EmformerImpl(torch.nn.Module):
 class Emformer(_EmformerImpl):
-    r"""Implements the Emformer architecture introduced in
+    r"""Emformer architecture introduced in
     *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition*
     :cite:`shi2021emformer`.
+
+    See Also:
+        * :func:`~torchaudio.models.emformer_rnnt_model`,
+          :func:`~torchaudio.models.emformer_rnnt_base`: factory functions.
+        * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipelines with pretrained model.

     Args:
         input_dim (int): input dimension.
         num_heads (int): number of attention heads in each Emformer layer.
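A hedged sketch of instantiating `Emformer`. Only `input_dim` and `num_heads` appear in the hunk above; the remaining positional arguments (assumed to be `ffn_dim`, `num_layers`, `segment_length`) and the input layout are assumptions for illustration.

```python
import torch
from torchaudio.models import Emformer

# Positional arguments assumed: (input_dim, num_heads, ffn_dim, num_layers,
# segment_length); only the first two are confirmed by the hunk above.
emformer = Emformer(512, 8, 2048, 20, 4)
frames = torch.rand(16, 400, 512)        # (batch, num_frames, input_dim)
lengths = torch.randint(1, 400, (16,))
output, output_lengths = emformer(frames, lengths)
```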
...
@@ -456,7 +456,11 @@ class RNNT(torch.nn.Module):
     Recurrent neural network transducer (RNN-T) model.

     Note:
-        To build the model, please use one of the factory functions.
+        To build the model, please use one of the factory functions,
+        :py:func:`emformer_rnnt_model` or :py:func:`emformer_rnnt_base`.
+
+    See Also:
+        :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pre-trained models.

     Args:
         transcriber (torch.nn.Module): transcription network.
@@ -706,7 +710,7 @@ def emformer_rnnt_model(
     lstm_layer_norm_epsilon: float,
     lstm_dropout: float,
 ) -> RNNT:
-    r"""Builds Emformer-based recurrent neural network transducer (RNN-T) model.
+    r"""Builds Emformer-based :class:`~torchaudio.models.RNNT`.

     Note:
         For non-streaming inference, the expectation is for `transcribe` to be called on input
@@ -779,7 +783,7 @@ def emformer_rnnt_model(
 def emformer_rnnt_base(num_symbols: int) -> RNNT:
-    r"""Builds basic version of Emformer RNN-T model.
+    r"""Builds basic version of Emformer-based :class:`~torchaudio.models.RNNT`.

     Args:
         num_symbols (int): The size of target token lexicon.
...
@@ -75,6 +75,9 @@ def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
 class RNNTBeamSearch(torch.nn.Module):
     r"""Beam search decoder for RNN-T model.

+    See Also:
+        * :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pretrained model.
+
     Args:
         model (RNNT): RNN-T model to use.
         blank (int): index of blank token in vocabulary.
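Putting the RNN-T hunks together: the transducer is built with a factory function and then wrapped by the beam search decoder. The vocabulary size below is an illustrative assumption; `blank` is the blank-token index, per the Args list above.

```python
from torchaudio.models import emformer_rnnt_base, RNNTBeamSearch

# 4097 symbols (4096 tokens + blank) is an assumed, illustrative size.
model = emformer_rnnt_base(num_symbols=4097)
decoder = RNNTBeamSearch(model, blank=4096)
# decoder(features, length, beam_width) or decoder.infer(...) then run the
# search over acoustic features and return Hypothesis tuples.
```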
...
@@ -867,12 +867,12 @@ class _Decoder(nn.Module):
 class Tacotron2(nn.Module):
-    r"""Tacotron2 model based on the implementation from
-    `Nvidia <https://github.com/NVIDIA/DeepLearningExamples/>`_.
-
-    The original implementation was introduced in
-    *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
-    :cite:`shen2018natural`.
+    r"""Tacotron2 model from *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
+    :cite:`shen2018natural` based on the implementation from
+    `Nvidia Deep Learning Examples <https://github.com/NVIDIA/DeepLearningExamples/>`_.
+
+    See Also:
+        * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.

     Args:
         mask_padding (bool, optional): Use mask padding (Default: ``False``).
...
@@ -9,7 +9,8 @@ class Wav2Letter(nn.Module):
     r"""Wav2Letter model architecture from *Wav2Letter: an End-to-End ConvNet-based Speech
     Recognition System* :cite:`collobert2016wav2letter`.

-    :math:`\text{padding} = \frac{\text{ceil}(\text{kernel} - \text{stride})}{2}`
+    See Also:
+        * `Training example <https://github.com/pytorch/audio/tree/release/0.12/examples/pipeline_wav2letter>`__

     Args:
         num_classes (int, optional): Number of classes to be classified. (Default: ``40``)
@@ -19,7 +20,7 @@ class Wav2Letter(nn.Module):
     """

     def __init__(self, num_classes: int = 40, input_type: str = "waveform", num_features: int = 1) -> None:
-        super(Wav2Letter, self).__init__()
+        super().__init__()
         acoustic_num_features = 250 if input_type == "waveform" else num_features
         acoustic_model = nn.Sequential(
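The `__init__` signature shown in the second hunk is enough to sketch basic usage; the input and output layouts noted in the comments are assumptions, not stated in the diff.

```python
import torch
from torchaudio.models import Wav2Letter

# Defaults taken from the __init__ shown above.
model = Wav2Letter(num_classes=40, input_type="waveform", num_features=1)

waveform = torch.randn(2, 1, 16000)  # assumed layout: (batch, num_features, time)
logits = model(waveform)             # assumed layout: (batch, num_classes, reduced_time)
```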
...
@@ -8,10 +8,17 @@ from . import components
 class Wav2Vec2Model(Module):
-    """Encoder model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`.
+    """Acoustic model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`.

     Note:
         To build the model, please use one of the factory functions.
+        :py:func:`wav2vec2_model`, :py:func:`wav2vec2_base`, :py:func:`wav2vec2_large`,
+        :py:func:`wav2vec2_large_lv60k`, :py:func:`hubert_base`, :py:func:`hubert_large`,
+        and :py:func:`hubert_xlarge`.
+
+    See Also:
+        * :class:`torchaudio.pipelines.Wav2Vec2Bundle`: Pretrained models (without fine-tuning)
+        * :class:`torchaudio.pipelines.Wav2Vec2ASRBundle`: ASR pipelines with pretrained models.

     Args:
         feature_extractor (torch.nn.Module):
@@ -116,11 +123,18 @@ class Wav2Vec2Model(Module):
 class HuBERTPretrainModel(Module):
-    """HuBERT pre-train model for training from scratch.
+    """HuBERTPretrainModel()
+
+    HuBERT model used for pretraining in *HuBERT* :cite:`hsu2021hubert`.

     Note:
-        To build the model, please use one of the factory functions in
-        `[hubert_pretrain_base, hubert_pretrain_large, hubert_pretrain_xlarge]`.
+        To build the model, please use one of the factory functions,
+        :py:func:`hubert_pretrain_base`, :py:func:`hubert_pretrain_large`
+        or :py:func:`hubert_pretrain_xlarge`.
+
+    See Also:
+        `HuBERT Pre-training and Fine-tuning Examples
+        <https://github.com/pytorch/audio/tree/release/0.12/examples/hubert>`__

     Args:
         feature_extractor (torch.nn.Module):
@@ -235,7 +249,7 @@ def wav2vec2_model(
     encoder_layer_drop: float,
     aux_num_out: Optional[int],
 ) -> Wav2Vec2Model:
-    """Build a custom Wav2Vec2Model
+    """Builds custom :class:`~torchaudio.models.Wav2Vec2Model`.

     Note:
         The "feature extractor" below corresponds to
@@ -391,7 +405,7 @@ def wav2vec2_base(
     encoder_layer_drop: float = 0.1,
     aux_num_out: Optional[int] = None,
 ) -> Wav2Vec2Model:
-    """Build Wav2Vec2Model with "base" architecture from *wav2vec 2.0* :cite:`baevski2020wav2vec`
+    """Builds "base" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`

     Args:
         encoder_projection_dropout (float):
@@ -439,7 +453,7 @@ def wav2vec2_large(
     encoder_layer_drop: float = 0.1,
     aux_num_out: Optional[int] = None,
 ) -> Wav2Vec2Model:
-    """Build Wav2Vec2Model with "large" architecture from *wav2vec 2.0* :cite:`baevski2020wav2vec`
+    """Builds "large" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`

     Args:
         encoder_projection_dropout (float):
@@ -487,7 +501,7 @@ def wav2vec2_large_lv60k(
     encoder_layer_drop: float = 0.1,
     aux_num_out: Optional[int] = None,
 ) -> Wav2Vec2Model:
-    """Build Wav2Vec2Model with "large lv-60k" architecture from *wav2vec 2.0* :cite:`baevski2020wav2vec`
+    """Builds "large lv-60k" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`

     Args:
         encoder_projection_dropout (float):
@@ -535,7 +549,7 @@ def hubert_base(
     encoder_layer_drop: float = 0.05,
     aux_num_out: Optional[int] = None,
 ) -> Wav2Vec2Model:
-    """Build HuBERT model with "base" architecture from *HuBERT* :cite:`hsu2021hubert`
+    """Builds "base" :class:`HuBERT <torchaudio.models.Wav2Vec2Model>` from *HuBERT* :cite:`hsu2021hubert`

     Args:
         encoder_projection_dropout (float):
@@ -583,7 +597,7 @@ def hubert_large(
     encoder_layer_drop: float = 0.0,
     aux_num_out: Optional[int] = None,
 ) -> Wav2Vec2Model:
-    """Build HuBERT model with "large" architecture from *HuBERT* :cite:`hsu2021hubert`
+    """Builds "large" :class:`HuBERT <torchaudio.models.Wav2Vec2Model>` from *HuBERT* :cite:`hsu2021hubert`

     Args:
         encoder_projection_dropout (float):
@@ -631,7 +645,7 @@ def hubert_xlarge(
     encoder_layer_drop: float = 0.0,
     aux_num_out: Optional[int] = None,
 ) -> Wav2Vec2Model:
-    """Build HuBERT model with "extra large" architecture from *HuBERT* :cite:`hsu2021hubert`
+    """Builds "extra large" :class:`HuBERT <torchaudio.models.Wav2Vec2Model>` from *HuBERT* :cite:`hsu2021hubert`

     Args:
         encoder_projection_dropout (float):
@@ -705,7 +719,7 @@ def hubert_pretrain_model(
     final_dim: int,
     feature_grad_mult: Optional[float],
 ) -> HuBERTPretrainModel:
-    """Build a custom HuBERTPretrainModel for training from scratch
+    """Builds custom :class:`HuBERTPretrainModel` for training from scratch

     Note:
         The "feature extractor" below corresponds to
@@ -973,7 +987,7 @@ def hubert_pretrain_base(
     feature_grad_mult: Optional[float] = 0.1,
     num_classes: int = 100,
 ) -> HuBERTPretrainModel:
-    """Build HuBERTPretrainModel model with "base" architecture from *HuBERT* :cite:`hsu2021hubert`
+    """Builds "base" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.

     Args:
         encoder_projection_dropout (float):
@@ -1048,7 +1062,7 @@ def hubert_pretrain_large(
     mask_channel_length: int = 10,
     feature_grad_mult: Optional[float] = None,
 ) -> HuBERTPretrainModel:
-    """Build HuBERTPretrainModel model for pre-training with "large" architecture from *HuBERT* :cite:`hsu2021hubert`
+    """Builds "large" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.

     Args:
         encoder_projection_dropout (float):
@@ -1121,7 +1135,7 @@ def hubert_pretrain_xlarge(
     mask_channel_length: int = 10,
     feature_grad_mult: Optional[float] = None,
 ) -> HuBERTPretrainModel:
-    """Build HuBERTPretrainModel model for pre-training with "extra large" architecture from *HuBERT* :cite:`hsu2021hubert`
+    """Builds "extra large" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.

     Args:
         encoder_projection_dropout (float):
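All of the wav2vec 2.0 / HuBERT factory functions touched above return a `Wav2Vec2Model`. A short sketch with `wav2vec2_base`, whose arguments all have defaults per the hunk; the dummy waveform length is arbitrary and the model is randomly initialized.

```python
import torch
from torchaudio.models import wav2vec2_base

# Randomly initialized "base" model; for pre-trained weights the new models.rst
# points to torchaudio.pipelines instead.
model = wav2vec2_base()
model.eval()

waveforms = torch.randn(1, 16000)  # (batch, time)
with torch.inference_mode():
    features, lengths = model(waveforms)                    # final encoder output
    layer_features, _ = model.extract_features(waveforms)   # list of per-layer outputs
```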
...
@@ -125,7 +125,8 @@ def _convert_state_dict(state_dict):
 def import_fairseq_model(original: Module) -> Wav2Vec2Model:
-    """Build Wav2Vec2Model from the corresponding model object of `fairseq`_.
+    """Builds :class:`Wav2Vec2Model` from the corresponding model object of
+    `fairseq <https://github.com/pytorch/fairseq>`_.

     Args:
         original (torch.nn.Module):
@@ -171,8 +172,6 @@ def import_fairseq_model(original: Module) -> Wav2Vec2Model:
         >>> mask = torch.zeros_like(waveform)
         >>> reference = original(waveform, mask)['encoder_out'].transpose(0, 1)
         >>> torch.testing.assert_allclose(emission, reference)
-
-    .. _fairseq: https://github.com/pytorch/fairseq
     """

     class_ = original.__class__.__name__
     if class_ == "Wav2Vec2Model":
...
@@ -48,7 +48,8 @@ def _build(config, original):
 def import_huggingface_model(original: Module) -> Wav2Vec2Model:
-    """Build Wav2Vec2Model from the corresponding model object of Hugging Face's `Transformers`_.
+    """Builds :class:`Wav2Vec2Model` from the corresponding model object of
+    `Transformers <https://huggingface.co/transformers/>`_.

     Args:
         original (torch.nn.Module): An instance of ``Wav2Vec2ForCTC`` from ``transformers``.
@@ -64,8 +65,6 @@ def import_huggingface_model(original: Module) -> Wav2Vec2Model:
         >>>
         >>> waveforms, _ = torchaudio.load("audio.wav")
         >>> logits, _ = model(waveforms)
-
-    .. _Transformers: https://huggingface.co/transformers/
     """

     _LG.info("Importing model.")
     _LG.info("Loading model configuration.")
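A hedged sketch of the import path described above; the checkpoint name is an assumption for illustration, since the function only requires a `Wav2Vec2ForCTC` instance per the Args section.

```python
import torch
import torchaudio
from torchaudio.models.wav2vec2.utils import import_huggingface_model
from transformers import Wav2Vec2ForCTC

original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")  # assumed checkpoint
model = import_huggingface_model(original).eval()

waveforms, _ = torchaudio.load("audio.wav")  # any local audio file
with torch.inference_mode():
    logits, _ = model(waveforms)
```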
...
@@ -197,12 +197,17 @@ class UpsampleNetwork(nn.Module):
 class WaveRNN(nn.Module):
-    r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_.
+    r"""WaveRNN model from *Efficient Neural Audio Synthesis* :cite:`wavernn`
+    based on the implementation from `fatchord/WaveRNN <https://github.com/fatchord/WaveRNN>`_.

     The original implementation was introduced in *Efficient Neural Audio Synthesis*
     :cite:`kalchbrenner2018efficient`. The input channels of waveform and spectrogram have to be 1.
     The product of `upsample_scales` must equal `hop_length`.

+    See Also:
+        * `Training example <https://github.com/pytorch/audio/tree/release/0.12/examples/pipeline_wavernn>`__
+        * :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.
+
     Args:
         upsample_scales: the list of upsample scales.
         n_classes: the number of output classes.
...
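The `Tacotron2TTSBundle` cross-reference added above ties WaveRNN into the TTS pipeline. A sketch, assuming the bundle name `TACOTRON2_WAVERNN_PHONE_LJSPEECH` and its three-part interface (text processor, Tacotron2, vocoder), none of which are stated in this diff.

```python
import torch
import torchaudio

bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH  # assumed bundle name
processor = bundle.get_text_processor()
tacotron2 = bundle.get_tacotron2()
vocoder = bundle.get_vocoder()  # WaveRNN-based vocoder

with torch.inference_mode():
    tokens, token_lengths = processor("Hello world!")
    spec, spec_lengths, _ = tacotron2.infer(tokens, token_lengths)
    waveforms, lengths = vocoder(spec, spec_lengths)
```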