Commit 30c7077b authored by moto, committed by Facebook GitHub Bot

Adopt `:autosummary:` in `torchaudio.models` module doc (#2690)

Summary:
* Introduce the mini-index on the `torchaudio.models` page.

https://output.circle-artifacts.com/output/job/25e59810-3866-4ece-b1b7-8a10c7a2286d/artifacts/0/docs/models.html

<img width="1042" alt="Screen Shot 2022-09-20 at 1 20 50 PM" src="https://user-images.githubusercontent.com/855818/191166816-83314ad1-8b67-475b-aa10-d4cc59126295.png">

<img width="1048" alt="Screen Shot 2022-09-20 at 1 20 58 PM" src="https://user-images.githubusercontent.com/855818/191166829-1ceb65e0-9506-4328-9a2f-8b75b4e54404.png">

Pull Request resolved: https://github.com/pytorch/audio/pull/2690

Reviewed By: carolineechen

Differential Revision: D39654948

Pulled By: mthrok

fbshipit-source-id: 703d1526617596f647c85a7148f41ca55fffdbc8
parent 4a65b050
..
   autogenerated from source/_templates/autosummary/model_class.rst
{%- set methods=["forward"] %}
{%- if name in ["Wav2Vec2Model"] %}
{{ methods.extend(["extract_features"]) }}
{%- elif name in ["Emformer", "RNNTBeamSearch", "WaveRNN", "Tacotron2", ] %}
{{ methods.extend(["infer"]) }}
{%- elif name == "RNNT" %}
{{ methods.extend(["transcribe_streaming", "transcribe", "predict", "join"]) }}
{%- endif %}
{{ name | underline }}
.. autoclass:: {{ fullname }}
{% for item in methods %}
{{item | underline("-") }}
.. container:: py attribute

   .. automethod:: {{[fullname, item] | join('.')}}
{%- endfor %}
{%- if name == "RNNTBeamSearch" %}
Support Structures
==================
Hypothesis
----------
.. container:: py attribute

   .. autodata:: torchaudio.models.Hypothesis
      :no-value:
{%- endif %}
.. role:: hidden
    :class: hidden-section

torchaudio.models
=================

.. py:module:: torchaudio.models
.. currentmodule:: torchaudio.models

The ``torchaudio.models`` subpackage contains definitions of models for addressing common audio tasks.

For pre-trained models, please refer to the :mod:`torchaudio.pipelines` module.

Model Definitions
-----------------

Model definitions are responsible for constructing computation graphs and executing them.

Some models have complex structure and variations.
For such models, `Factory Functions`_ are provided.

.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: autosummary/model_class.rst

    Conformer
    ConvTasNet
    DeepSpeech
    Emformer
    HDemucs
    HuBERTPretrainModel
    RNNT
    RNNTBeamSearch
    Tacotron2
    Wav2Letter
    Wav2Vec2Model
    WaveRNN

Factory Functions
-----------------

.. autosummary::
    :toctree: generated
    :nosignatures:

    conv_tasnet_base
    emformer_rnnt_model
    emformer_rnnt_base
    wav2vec2_model
    wav2vec2_base
    wav2vec2_large
    wav2vec2_large_lv60k
    hubert_base
    hubert_large
    hubert_xlarge
    hubert_pretrain_model
    hubert_pretrain_base
    hubert_pretrain_large
    hubert_pretrain_xlarge
    hdemucs_low
    hdemucs_medium
    hdemucs_high

Utility Functions
-----------------

.. autosummary::
    :toctree: generated
    :nosignatures:

    ~wav2vec2.utils.import_fairseq_model
    ~wav2vec2.utils.import_huggingface_model
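The split between model definitions and factory functions reads concretely as follows; a minimal sketch (hyperparameter values are illustrative only, not recommended settings):

>>> import torch
>>> import torchaudio
>>> # Model definition: the class constructor exposes every architectural knob.
>>> conformer = torchaudio.models.Conformer(
...     input_dim=80, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31,
... )
>>> # Factory function: bundles a known-good configuration of a model variant.
>>> separator = torchaudio.models.conv_tasnet_base(num_sources=2)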
@article{wavernn,
author = {Nal Kalchbrenner and
Erich Elsen and
Karen Simonyan and
Seb Noury and
Norman Casagrande and
Edward Lockhart and
Florian Stimberg and
A{\"{a}}ron van den Oord and
Sander Dieleman and
Koray Kavukcuoglu},
title = {Efficient Neural Audio Synthesis},
journal = {CoRR},
volume = {abs/1802.08435},
year = {2018},
url = {http://arxiv.org/abs/1802.08435},
eprinttype = {arXiv},
eprint = {1802.08435},
timestamp = {Mon, 13 Aug 2018 16:47:01 +0200},
biburl = {https://dblp.org/rec/journals/corr/abs-1802-08435.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
@misc{RESAMPLE,
author = {Julius O. Smith},
title = {Digital Audio Resampling Home Page "Theory of Ideal Bandlimited Interpolation" section},
......
......@@ -21,10 +21,12 @@ perform music separation
# 3. Collect output chunks and combine according to the way they have been
# overlapped.
#
# The Hybrid Demucs [`Défossez, 2021 <https://arxiv.org/abs/2111.03600>`__]
# model is a developed version of the
# `Demucs <https://github.com/facebookresearch/demucs>`__ model, a
# waveform based model which separates music into its
# respective sources, such as vocals, bass, and drums.
# Hybrid Demucs effectively uses spectrogram to learn
# through the frequency domain and also moves to time convolutions.
#
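For orientation, the chunking scheme sketched in the comments above can be written roughly as follows. This is an illustrative reimplementation, not the tutorial's exact code; the function name `separate_in_chunks` and the default `segment`/`overlap` values are assumptions, though the fade-based blending uses the real :class:`torchaudio.transforms.Fade` transform.

>>> import torch
>>> from torchaudio.transforms import Fade
>>>
>>> def separate_in_chunks(model, mix, sample_rate, segment=10.0, overlap=0.1):
...     # Hypothetical helper: mix (batch, channels, time) -> (batch, num_sources, channels, time)
...     batch, channels, length = mix.shape
...     chunk_len = int(sample_rate * segment * (1 + overlap))
...     overlap_frames = int(overlap * sample_rate)
...     fade = Fade(fade_in_len=0, fade_out_len=overlap_frames, fade_shape="linear")
...     output = torch.zeros(batch, len(model.sources), channels, length)
...     start, end = 0, chunk_len
...     while start < length - overlap_frames:
...         with torch.no_grad():
...             separated = model(mix[:, :, start:end])
...         output[:, :, :, start:end] += fade(separated)
...         if start == 0:
...             fade.fade_in_len = overlap_frames  # fade in from the second chunk on
...             start += chunk_len - overlap_frames
...         else:
...             start += chunk_len
...         end += chunk_len
...         if end >= length:
...             fade.fade_out_len = 0  # last chunk: no fade-out past the signal
...     return output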
......@@ -81,7 +83,7 @@ except ModuleNotFoundError:
#
# Pre-trained model weights and related pipeline components are bundled as
# :py:func:`torchaudio.pipelines.HDEMUCS_HIGH_MUSDB_PLUS`. This is a
# :py:class:`torchaudio.models.HDemucs` model trained on
# `MUSDB18-HQ <https://zenodo.org/record/3338373>`__ and additional
# internal extra training data.
# This specific model is suited for higher sample rates, around 44.1 kHz
......
......@@ -299,8 +299,14 @@ class _HDecLayer(torch.nn.Module):
class HDemucs(torch.nn.Module):
r"""
Hybrid Demucs model from *Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`.
r"""Hybrid Demucs model from
*Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`.
See Also:
* :func:`~torchaudio.models.hdemucs_low`,
:func:`~torchaudio.models.hdemucs_medium`,
:func:`~torchaudio.models.hdemucs_high`: factory functions.
* :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.
Args:
sources (List[str]): list of source names. List can contain the following source
......@@ -959,8 +965,7 @@ def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int =
def hdemucs_low(sources: List[str]) -> HDemucs:
r"""Builds low nfft (1024) version of HDemucs model. This version is suitable for lower sample rates, and bundles
parameters together to call valid nfft and depth values for a model structured for sample rates around 8 kHZ.
"""Builds low nfft (1024) version of :class:`HDemucs`, suitable for sample rates around 8 kHz.
Args:
sources (List[str]): See :py:class:`HDemucs`.
......@@ -974,8 +979,7 @@ def hdemucs_low(sources: List[str]) -> HDemucs:
def hdemucs_medium(sources: List[str]) -> HDemucs:
r"""Builds medium nfft (2048) version of HDemucs model. This version is suitable for medium sample rates,and bundles
parameters together to call valid nfft and depth values for a model structured for sample rates around 16-32 kHZ
r"""Builds medium nfft (2048) version of :class:`HDemucs`, suitable for sample rates of 16-32 kHz.
.. note::
......@@ -994,9 +998,7 @@ def hdemucs_medium(sources: List[str]) -> HDemucs:
def hdemucs_high(sources: List[str]) -> HDemucs:
r"""Builds high nfft (4096) version of HDemucs model. This version is suitable for high/standard music sample rates,
and bundles parameters together to call valid nfft and depth values for a model structured for sample rates around
44.1-48 kHZ
r"""Builds medium nfft (4096) version of :class:`HDemucs`, suitable for sample rates of 44.1-48 kHz.
Args:
sources (List[str]): See :py:class:`HDemucs`.
......
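A hedged usage sketch of the factory functions above (random noise stands in for a real stereo mix at a matching sample rate):

>>> import torch
>>> from torchaudio.models import hdemucs_high
>>> model = hdemucs_high(sources=["drums", "bass", "other", "vocals"])
>>> mix = torch.randn(1, 2, 44100)  # (batch, channels, time): 1 s of stereo at 44.1 kHz
>>> separated = model(mix)          # (batch, num_sources, channels, time)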
......@@ -213,7 +213,7 @@ class ConformerLayer(torch.nn.Module):
class Conformer(torch.nn.Module):
r"""Implements the Conformer architecture introduced in
r"""Conformer architecture introduced in
*Conformer: Convolution-augmented Transformer for Speech Recognition*
:cite:`gulati2020conformer`.
......
......@@ -160,10 +160,17 @@ class MaskGenerator(torch.nn.Module):
class ConvTasNet(torch.nn.Module):
"""Conv-TasNet: a fully-convolutional time-domain audio separation network
"""Conv-TasNet architecture introduced in
*Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
:cite:`Luo_2019`.
Note:
This implementation corresponds to the "non-causal" setting in the paper.
See Also:
* :func:`~torchaudio.models.conv_tasnet_base`: A factory function.
* :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.
Args:
num_sources (int, optional): The number of sources to split.
enc_kernel_size (int, optional): The convolution kernel size of the encoder/decoder, <L>.
......@@ -174,9 +181,6 @@ class ConvTasNet(torch.nn.Module):
msk_num_layers (int, optional): The number of layers in one conv block of the mask generator, <X>.
msk_num_stacks (int, optional): The number of conv blocks of the mask generator, <R>.
msk_activate (str, optional): The activation function of the mask output (Default: ``sigmoid``).
"""
def __init__(
......@@ -302,9 +306,7 @@ class ConvTasNet(torch.nn.Module):
def conv_tasnet_base(num_sources: int = 2) -> ConvTasNet:
r"""Builds the non-causal version of ConvTasNet in
*Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation*
:cite:`Luo_2019`.
r"""Builds non-causal version of :class:`~torchaudio.models.ConvTasNet`.
The parameter settings follow the ones with the highest Si-SNR metric score in the paper,
except the mask activation function is changed from "sigmoid" to "relu" for performance improvement.
......
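A short usage sketch of `conv_tasnet_base` (random mono input, matching the (batch, 1, frames) layout the model expects):

>>> import torch
>>> from torchaudio.models import conv_tasnet_base
>>> model = conv_tasnet_base(num_sources=2)
>>> mix = torch.randn(3, 1, 16000)  # (batch, channel == 1, frames)
>>> separated = model(mix)          # (batch, num_sources, frames)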
......@@ -26,9 +26,8 @@ class FullyConnected(torch.nn.Module):
class DeepSpeech(torch.nn.Module):
"""
DeepSpeech model architecture from *Deep Speech: Scaling up end-to-end speech recognition*
:cite:`hannun2014deep`.
"""DeepSpeech architecture introduced in
*Deep Speech: Scaling up end-to-end speech recognition* :cite:`hannun2014deep`.
Args:
n_feature: Number of input features
......
......@@ -804,10 +804,15 @@ class _EmformerImpl(torch.nn.Module):
class Emformer(_EmformerImpl):
r"""Implements the Emformer architecture introduced in
r"""Emformer architecture introduced in
*Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition*
:cite:`shi2021emformer`.
See Also:
* :func:`~torchaudio.models.emformer_rnnt_model`,
:func:`~torchaudio.models.emformer_rnnt_base`: factory functions.
* :class:`torchaudio.pipelines.RNNTBundle`: ASR pipelines with pretrained model.
Args:
input_dim (int): input dimension.
num_heads (int): number of attention heads in each Emformer layer.
......
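For reference, a usage sketch consistent with the example in the Emformer class documentation (shapes and hyperparameters are illustrative):

>>> import torch
>>> from torchaudio.models import Emformer
>>> emformer = Emformer(512, 8, 2048, 20, 4, right_context_length=1)
>>> inputs = torch.rand(128, 400, 512)       # (batch, num_frames, input_dim)
>>> lengths = torch.randint(1, 200, (128,))  # (batch,)
>>> output, output_lengths = emformer(inputs, lengths)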
......@@ -456,7 +456,11 @@ class RNNT(torch.nn.Module):
Recurrent neural network transducer (RNN-T) model.
Note:
To build the model, please use one of the factory functions,
:py:func:`emformer_rnnt_model` or :py:func:`emformer_rnnt_base`.
See Also:
:class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pre-trained models.
Args:
transcriber (torch.nn.Module): transcription network.
......@@ -706,7 +710,7 @@ def emformer_rnnt_model(
lstm_layer_norm_epsilon: float,
lstm_dropout: float,
) -> RNNT:
r"""Builds Emformer-based recurrent neural network transducer (RNN-T) model.
r"""Builds Emformer-based :class:`~torchaudio.models.RNNT`.
Note:
For non-streaming inference, the expectation is for `transcribe` to be called on input
......@@ -779,7 +783,7 @@ def emformer_rnnt_model(
def emformer_rnnt_base(num_symbols: int) -> RNNT:
r"""Builds basic version of Emformer RNN-T model.
r"""Builds basic version of Emformer-based :class:`~torchaudio.models.RNNT`.
Args:
num_symbols (int): The size of target token lexicon.
......
......@@ -75,6 +75,9 @@ def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
class RNNTBeamSearch(torch.nn.Module):
r"""Beam search decoder for RNN-T model.
See Also:
* :class:`torchaudio.pipelines.RNNTBundle`: ASR pipeline with pretrained model.
Args:
model (RNNT): RNN-T model to use.
blank (int): index of blank token in vocabulary.
......
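A hedged sketch of wiring the decoder to a factory-built model; `features` and `feature_lengths` are random placeholders for real acoustic features from an upstream feature extractor:

>>> import torch
>>> from torchaudio.models import emformer_rnnt_base, RNNTBeamSearch
>>> model = emformer_rnnt_base(num_symbols=4097)
>>> decoder = RNNTBeamSearch(model, blank=4096)  # blank taken as the last symbol index here
>>> features = torch.randn(1, 400, 80)           # placeholder (batch, frames, feature_dim)
>>> feature_lengths = torch.tensor([400])
>>> hypotheses = decoder(features, feature_lengths, beam_width=10)  # List[Hypothesis]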
......@@ -867,12 +867,12 @@ class _Decoder(nn.Module):
class Tacotron2(nn.Module):
r"""Tacotron2 model based on the implementation from
`Nvidia <https://github.com/NVIDIA/DeepLearningExamples/>`_.
r"""Tacotron2 model from *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
:cite:`shen2018natural` based on the implementation from
`Nvidia Deep Learning Examples <https://github.com/NVIDIA/DeepLearningExamples/>`_.
The original implementation was introduced in
*Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
:cite:`shen2018natural`.
See Also:
* :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.
Args:
mask_padding (bool, optional): Use mask padding (Default: ``False``).
......
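A minimal inference sketch for Tacotron2 (token IDs are random placeholders; a real pipeline would encode text to symbols, e.g. via :class:`torchaudio.pipelines.Tacotron2TTSBundle`):

>>> import torch
>>> from torchaudio.models import Tacotron2
>>> model = Tacotron2().eval()
>>> tokens = torch.randint(0, 148, (1, 50))  # placeholder symbol IDs
>>> token_lengths = torch.tensor([50])
>>> with torch.no_grad():
...     mel, mel_lengths, alignments = model.infer(tokens, token_lengths)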
......@@ -9,7 +9,8 @@ class Wav2Letter(nn.Module):
r"""Wav2Letter model architecture from *Wav2Letter: an End-to-End ConvNet-based Speech
Recognition System* :cite:`collobert2016wav2letter`.
:math:`\text{padding} = \frac{\text{ceil}(\text{kernel} - \text{stride})}{2}`
See Also:
* `Training example <https://github.com/pytorch/audio/tree/release/0.12/examples/pipeline_wav2letter>`__
Args:
num_classes (int, optional): Number of classes to be classified. (Default: ``40``)
......@@ -19,7 +20,7 @@ class Wav2Letter(nn.Module):
"""
def __init__(self, num_classes: int = 40, input_type: str = "waveform", num_features: int = 1) -> None:
super().__init__()
acoustic_num_features = 250 if input_type == "waveform" else num_features
acoustic_model = nn.Sequential(
......
......@@ -8,10 +8,17 @@ from . import components
class Wav2Vec2Model(Module):
"""Encoder model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`.
"""Acoustic model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`.
Note:
To build the model, please use one of the factory functions:
:py:func:`wav2vec2_model`, :py:func:`wav2vec2_base`, :py:func:`wav2vec2_large`,
:py:func:`wav2vec2_large_lv60k`, :py:func:`hubert_base`, :py:func:`hubert_large`,
and :py:func:`hubert_xlarge`.
See Also:
* :class:`torchaudio.pipelines.Wav2Vec2Bundle`: Pretrained models (without fine-tuning)
* :class:`torchaudio.pipelines.Wav2Vec2ASRBundle`: ASR pipelines with pretrained models.
Args:
feature_extractor (torch.nn.Module):
......@@ -116,11 +123,18 @@ class Wav2Vec2Model(Module):
class HuBERTPretrainModel(Module):
"""HuBERT pre-train model for training from scratch.
"""HuBERTPretrainModel()
HuBERT model used for pretraining in *HuBERT* :cite:`hsu2021hubert`.
Note:
To build the model, please use one of the factory functions,
:py:func:`hubert_pretrain_base`, :py:func:`hubert_pretrain_large`
or :py:func:`hubert_pretrain_xlarge`.
See Also:
`HuBERT Pre-training and Fine-tuning Examples
<https://github.com/pytorch/audio/tree/release/0.12/examples/hubert>`__
Args:
feature_extractor (torch.nn.Module):
......@@ -235,7 +249,7 @@ def wav2vec2_model(
encoder_layer_drop: float,
aux_num_out: Optional[int],
) -> Wav2Vec2Model:
"""Build a custom Wav2Vec2Model
"""Builds custom :class:`~torchaudio.models.Wav2Vec2Model`.
Note:
The "feature extractor" below corresponds to
......@@ -391,7 +405,7 @@ def wav2vec2_base(
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Build Wav2Vec2Model with "base" architecture from *wav2vec 2.0* :cite:`baevski2020wav2vec`
"""Builds "base" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`
Args:
encoder_projection_dropout (float):
......@@ -439,7 +453,7 @@ def wav2vec2_large(
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Build Wav2Vec2Model with "large" architecture from *wav2vec 2.0* :cite:`baevski2020wav2vec`
"""Builds "large" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`
Args:
encoder_projection_dropout (float):
......@@ -487,7 +501,7 @@ def wav2vec2_large_lv60k(
encoder_layer_drop: float = 0.1,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Build Wav2Vec2Model with "large lv-60k" architecture from *wav2vec 2.0* :cite:`baevski2020wav2vec`
"""Builds "large lv-60k" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec`
Args:
encoder_projection_dropout (float):
......@@ -535,7 +549,7 @@ def hubert_base(
encoder_layer_drop: float = 0.05,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Build HuBERT model with "base" architecture from *HuBERT* :cite:`hsu2021hubert`
"""Builds "base" :class:`HuBERT <torchaudio.models.Wav2Vec2Model>` from *HuBERT* :cite:`hsu2021hubert`
Args:
encoder_projection_dropout (float):
......@@ -583,7 +597,7 @@ def hubert_large(
encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Build HuBERT model with "large" architecture from *HuBERT* :cite:`hsu2021hubert`
"""Builds "large" :class:`HuBERT <torchaudio.models.Wav2Vec2Model>` from *HuBERT* :cite:`hsu2021hubert`
Args:
encoder_projection_dropout (float):
......@@ -631,7 +645,7 @@ def hubert_xlarge(
encoder_layer_drop: float = 0.0,
aux_num_out: Optional[int] = None,
) -> Wav2Vec2Model:
"""Build HuBERT model with "extra large" architecture from *HuBERT* :cite:`hsu2021hubert`
"""Builds "extra large" :class:`HuBERT <torchaudio.models.Wav2Vec2Model>` from *HuBERT* :cite:`hsu2021hubert`
Args:
encoder_projection_dropout (float):
......@@ -705,7 +719,7 @@ def hubert_pretrain_model(
final_dim: int,
feature_grad_mult: Optional[float],
) -> HuBERTPretrainModel:
"""Build a custom HuBERTPretrainModel for training from scratch
"""Builds custom :class:`HuBERTPretrainModel` for training from scratch
Note:
The "feature extractor" below corresponds to
......@@ -973,7 +987,7 @@ def hubert_pretrain_base(
feature_grad_mult: Optional[float] = 0.1,
num_classes: int = 100,
) -> HuBERTPretrainModel:
"""Build HuBERTPretrainModel model with "base" architecture from *HuBERT* :cite:`hsu2021hubert`
"""Builds "base" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.
Args:
encoder_projection_dropout (float):
......@@ -1048,7 +1062,7 @@ def hubert_pretrain_large(
mask_channel_length: int = 10,
feature_grad_mult: Optional[float] = None,
) -> HuBERTPretrainModel:
"""Build HuBERTPretrainModel model for pre-training with "large" architecture from *HuBERT* :cite:`hsu2021hubert`
"""Builds "large" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.
Args:
encoder_projection_dropout (float):
......@@ -1121,7 +1135,7 @@ def hubert_pretrain_xlarge(
mask_channel_length: int = 10,
feature_grad_mult: Optional[float] = None,
) -> HuBERTPretrainModel:
"""Build HuBERTPretrainModel model for pre-training with "extra large" architecture from *HuBERT* :cite:`hsu2021hubert`
"""Builds "extra large" :class:`HuBERTPretrainModel` from *HuBERT* :cite:`hsu2021hubert` for pretraining.
Args:
encoder_projection_dropout (float):
......
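To make the factory pattern above concrete, a hedged sketch with a randomly initialized "base" model (the waveform is random; real input would be 16 kHz mono audio):

>>> import torch
>>> from torchaudio.models import wav2vec2_base
>>> model = wav2vec2_base().eval()
>>> waveforms = torch.randn(1, 16000)  # (batch, time): one second at 16 kHz
>>> with torch.no_grad():
...     features, _ = model.extract_features(waveforms)
>>> # `features` is a list with one tensor per transformer layer.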
......@@ -125,7 +125,8 @@ def _convert_state_dict(state_dict):
def import_fairseq_model(original: Module) -> Wav2Vec2Model:
"""Build Wav2Vec2Model from the corresponding model object of `fairseq`_.
"""Builds :class:`Wav2Vec2Model` from the corresponding model object of
`fairseq <https://github.com/pytorch/fairseq>`_.
Args:
original (torch.nn.Module):
......@@ -171,8 +172,6 @@ def import_fairseq_model(original: Module) -> Wav2Vec2Model:
>>> mask = torch.zeros_like(waveform)
>>> reference = original(waveform, mask)['encoder_out'].transpose(0, 1)
>>> torch.testing.assert_allclose(emission, reference)
"""
class_ = original.__class__.__name__
if class_ == "Wav2Vec2Model":
......
......@@ -48,7 +48,8 @@ def _build(config, original):
def import_huggingface_model(original: Module) -> Wav2Vec2Model:
"""Build Wav2Vec2Model from the corresponding model object of Hugging Face's `Transformers`_.
"""Builds :class:`Wav2Vec2Model` from the corresponding model object of
`Transformers <https://huggingface.co/transformers/>`_.
Args:
original (torch.nn.Module): An instance of ``Wav2Vec2ForCTC`` from ``transformers``.
......@@ -64,8 +65,6 @@ def import_huggingface_model(original: Module) -> Wav2Vec2Model:
>>>
>>> waveforms, _ = torchaudio.load("audio.wav")
>>> logits, _ = model(waveforms)
"""
_LG.info("Importing model.")
_LG.info("Loading model configuration.")
......
......@@ -197,12 +197,17 @@ class UpsampleNetwork(nn.Module):
class WaveRNN(nn.Module):
r"""WaveRNN model based on the implementation from `fatchord <https://github.com/fatchord/WaveRNN>`_.
r"""WaveRNN model from *Efficient Neural Audio Synthesis* :cite:`wavernn`
based on the implementation from `fatchord/WaveRNN <https://github.com/fatchord/WaveRNN>`_.
The original implementation was introduced in *Efficient Neural Audio Synthesis*
:cite:`kalchbrenner2018efficient`. The input channels of waveform and spectrogram have to be 1.
The product of `upsample_scales` must equal `hop_length`.
See Also:
* `Training example <https://github.com/pytorch/audio/tree/release/0.12/examples/pipeline_wavernn>`__
* :class:`torchaudio.pipelines.Tacotron2TTSBundle`: TTS pipeline with pretrained model.
Args:
upsample_scales: the list of upsample scales.
n_classes: the number of output classes.
......
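The `upsample_scales`/`hop_length` constraint mentioned above can be checked concretely; a sketch consistent with the documented input-shape rule (values are illustrative):

>>> import torch
>>> from torchaudio.models import WaveRNN
>>> # 5 * 5 * 8 == 200 == hop_length, satisfying the product constraint.
>>> model = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200)
>>> specgram = torch.randn(1, 1, 128, 10)             # (batch, channel, n_freq, n_time)
>>> waveform = torch.randn(1, 1, (10 - 5 + 1) * 200)  # kernel_size defaults to 5
>>> output = model(waveform, specgram)                # (batch, channel, time, n_classes)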