Commit c89ab0c6 authored by moto, committed by Facebook GitHub Bot

Adopt `:autosummary:` in `torchaudio.models.decoder` module doc (#2684)

Summary:
* Adopt `:autosummary:` in the decoder module doc.
* Hide the constructor signature of `CTCDecoder`, since the `ctc_decoder` factory function is what client code is supposed to use (see the sketch below).
* Introduce a `children` property on `CTCDecoderLMState`, which otherwise does not show up in the doc.
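
For reference, the client-side flow that the docs now steer readers toward looks roughly like the sketch below. This is a hedged, minimal example, not part of this commit: the model name, beam settings, and the random emission tensor are illustrative. The point is that a `CTCDecoder` instance comes from the `ctc_decoder` factory (optionally fed by `download_pretrained_files`), not from calling `CTCDecoder(...)` directly.

```python
# Minimal sketch: obtain a CTCDecoder through the factory function, then decode.
import torch
from torchaudio.models.decoder import ctc_decoder, download_pretrained_files

# Fetch a pretrained lexicon / tokens / KenLM file set (model name is illustrative).
files = download_pretrained_files("librispeech-4-gram")

# Count the tokens so the fake emissions below have a matching last dimension.
with open(files.tokens) as f:
    num_tokens = len(f.read().split())

decoder = ctc_decoder(
    lexicon=files.lexicon,  # word -> spelling mapping
    tokens=files.tokens,    # token set of the acoustic model
    lm=files.lm,            # path to the KenLM binary
    nbest=3,                # number of hypotheses to return
    beam_size=50,           # beam width used during the search
)

# Emissions normally come from an acoustic model; random values keep the sketch self-contained.
emissions = torch.randn(1, 100, num_tokens).log_softmax(-1)  # (batch, frames, num_tokens)
results = decoder(emissions)                                 # nbest CTCHypothesis per batch item

best = results[0][0]
print(best.words)                           # predicted words (lexicon decoding)
print(decoder.idxs_to_tokens(best.tokens))  # token IDs converted back to token strings
print(best.score, best.timesteps)           # hypothesis score and per-token frame indices
```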

https://output.circle-artifacts.com/output/job/7aac5eb9-7d2d-4f63-bcdf-83a6f40b4e5a/artifacts/0/docs/models.decoder.html

<img width="748" alt="Screen Shot 2022-09-16 at 5 23 22 PM" src="https://user-images.githubusercontent.com/855818/190592409-0c2ec8a4-d2cf-4d76-a965-8a570faaeb1a.png">

https://output.circle-artifacts.com/output/job/7aac5eb9-7d2d-4f63-bcdf-83a6f40b4e5a/artifacts/0/docs/generated/torchaudio.models.decoder.CTCDecoder.html#torchaudio.models.decoder.CTCDecoder

<img width="723" alt="Screen Shot 2022-09-16 at 5 23 53 PM" src="https://user-images.githubusercontent.com/855818/190592501-3fad1e07-ae3e-44f5-93be-f33181025390.png">

Pull Request resolved: https://github.com/pytorch/audio/pull/2684

Reviewed By: carolineechen

Differential Revision: D39574272

Pulled By: mthrok

fbshipit-source-id: d977660bd46f5cf98c535adbf2735be896b28773
source/_templates/autosummary/ctc_decoder_class.rst (new template file):
{#
################################################################################
# autosummary template for CTCDecoder
# Since the class has multiple methods and support structures,
# we want to have them show up in the table of contents.
# The default class template does not do this, so we use a custom one here.
################################################################################
#}
{{ name | underline }}

{%- if name != "CTCDecoder" %}

.. autofunction:: {{fullname}}

{%- else %}

.. autoclass:: {{ fullname }}()

Methods
=======

{%- for item in methods %}

{{ item | underline("-") }}

.. container:: py attribute

   .. automethod:: {{[fullname, item] | join('.')}}

{%- endfor %}

Support Structures
==================

{%- for item in ["CTCDecoderLM", "CTCDecoderLMState", "CTCHypothesis"] %}

{{ item | underline("-") }}

.. autoclass:: torchaudio.models.decoder.{{item}}
   :members:

{%- endfor %}
{%- endif %}
models.decoder.rst (before):

.. role:: hidden
    :class: hidden-section

torchaudio.models.decoder
=========================

.. currentmodule:: torchaudio.models.decoder
.. py:module:: torchaudio.models.decoder

Decoder Class
-------------

CTCDecoder
~~~~~~~~~~
.. autoclass:: CTCDecoder
  .. automethod:: __call__
  .. automethod:: idxs_to_tokens

CTCDecoderLM
~~~~~~~~~~~~
.. autoclass:: CTCDecoderLM
  .. automethod:: start
  .. automethod:: score
  .. automethod:: finish

CTCDecoderLMState
~~~~~~~~~~~~~~~~~
.. autoclass:: CTCDecoderLMState
   :members: children
  .. automethod:: child
  .. automethod:: compare

CTCHypothesis
~~~~~~~~~~~~~
.. autoclass:: CTCHypothesis

Factory Function
----------------
ctc_decoder
~~~~~~~~~~~
.. autoclass:: ctc_decoder

Utility Function
----------------
download_pretrained_files
~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: download_pretrained_files

models.decoder.rst (after):

.. py:module:: torchaudio.models.decoder

torchaudio.models.decoder
=========================

.. currentmodule:: torchaudio.models.decoder

CTC Decoder
-----------

.. autosummary::
   :toctree: generated
   :nosignatures:
   :template: autosummary/ctc_decoder_class.rst

   CTCDecoder
   ctc_decoder
   download_pretrained_files

.. rubric:: Tutorials using CTC Decoder

.. minigallery:: torchaudio.models.decoder.CTCDecoder
Python source of the torchaudio.models.decoder CTC decoder module:

@@ -3,7 +3,7 @@ from __future__ import annotations
 import itertools as it
 from abc import abstractmethod
 from collections import namedtuple
-from typing import Dict, List, NamedTuple, Optional, Union
+from typing import Dict, List, NamedTuple, Optional, Tuple, Union

 import torch
 import torchaudio
@@ -96,33 +96,35 @@ def _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word):
 class CTCHypothesis(NamedTuple):
-    r"""Represents hypothesis generated by CTC beam search decoder :py:func:`CTCDecoder`.
+    r"""Represents hypothesis generated by CTC beam search decoder :class:`CTCDecoder`."""
+
+    tokens: torch.LongTensor
+    """Predicted sequence of token IDs. Shape `(L, )`, where `L` is the length of the output sequence"""
+
+    words: List[str]
+    """List of predicted words.

     Note:
-        The ``words`` field is only applicable if a lexicon is provided to the decoder. If
-        decoding without a lexicon, it will be blank. Please refer to ``tokens`` and
-        :py:func:`idxs_to_tokens <torchaudio.models.decoder.CTCDecoder.idxs_to_tokens>` instead.
-
-    :ivar torch.LongTensor tokens: Predicted sequence of token IDs. Shape `(L, )`, where
-        `L` is the length of the output sequence
-    :ivar List[str] words: List of predicted words
-    :ivar float score: Score corresponding to hypothesis
-    :ivar torch.IntTensor timesteps: Timesteps corresponding to the tokens. Shape `(L, )`,
-        where `L` is the length of the output sequence
+        This attribute is only applicable if a lexicon is provided to the decoder. If
+        decoding without a lexicon, it will be blank. Please refer to :attr:`tokens` and
+        :func:`~torchaudio.models.decoder.CTCDecoder.idxs_to_tokens` instead.
     """

-    tokens: torch.LongTensor
-    words: List[str]
     score: float
+    """Score corresponding to hypothesis"""
+
     timesteps: torch.IntTensor
+    """Timesteps corresponding to the tokens. Shape `(L, )`, where `L` is the length of the output sequence"""

 class CTCDecoderLMState(_LMState):
-    """Language model state.
+    """Language model state."""

-    :ivar Dict[int] children: Map of indices to LM states
-    """
+    @property
+    def children(self) -> Dict[int, CTCDecoderLMState]:
+        """Map of indices to LM states"""
+        return super().children

-    def child(self, usr_index: int):
+    def child(self, usr_index: int) -> CTCDecoderLMState:
         """Returns child corresponding to usr_index, or creates and returns a new state if input index
         is not found.
@@ -134,7 +136,7 @@ class CTCDecoderLMState(_LMState):
         """
         return super().child(usr_index)

-    def compare(self, state: CTCDecoderLMState):
+    def compare(self, state: CTCDecoderLMState) -> CTCDecoderLMState:
         """Compare two language model states.

         Args:
@@ -150,7 +152,7 @@ class CTCDecoderLM(_LM):
     """Language model base class for creating custom language models to use with the decoder."""

     @abstractmethod
-    def start(self, start_with_nothing: bool):
+    def start(self, start_with_nothing: bool) -> CTCDecoderLMState:
         """Initialize or reset the language model.

         Args:
@@ -162,7 +164,7 @@
         raise NotImplementedError

     @abstractmethod
-    def score(self, state: CTCDecoderLMState, usr_token_idx: int):
+    def score(self, state: CTCDecoderLMState, usr_token_idx: int) -> Tuple[CTCDecoderLMState, float]:
         """Evaluate the language model based on the current LM state and new word.

         Args:
@@ -170,14 +172,14 @@
             usr_token_idx (int): index of the word

         Returns:
-            Tuple[CTCDecoderLMState, float]
+            (CTCDecoderLMState, float)
                 CTCDecoderLMState: new LM state
                 float: score
         """
         raise NotImplementedError

     @abstractmethod
-    def finish(self, state: CTCDecoderLMState):
+    def finish(self, state: CTCDecoderLMState) -> Tuple[CTCDecoderLMState, float]:
         """Evaluate end for language model based on current LM state.

         Args:
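
The return annotations added above spell out the contract a custom LM has to follow. As a hedged illustration (not part of this diff), a minimal zero-scoring LM satisfying `start` / `score` / `finish` could look like this; it leans on `CTCDecoderLMState.child` to build the state tree:

```python
# Sketch: the smallest possible custom LM honoring the annotated signatures.
from typing import Tuple

from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState


class ZeroLM(CTCDecoderLM):
    """Toy LM that contributes a score of 0.0 everywhere (no LM influence)."""

    def start(self, start_with_nothing: bool) -> CTCDecoderLMState:
        # Fresh root state at the beginning of a decoding run.
        return CTCDecoderLMState()

    def score(self, state: CTCDecoderLMState, usr_token_idx: int) -> Tuple[CTCDecoderLMState, float]:
        # child() returns (or creates) the state reached via this token index.
        return state.child(usr_token_idx), 0.0

    def finish(self, state: CTCDecoderLMState) -> Tuple[CTCDecoderLMState, float]:
        # End-of-sequence score; also zero for this toy LM.
        return state, 0.0
```

An instance of such a class is what the `lm` argument of `ctc_decoder` accepts when a KenLM file path is not used.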
@@ -194,24 +196,12 @@ class CTCDecoder:
 class CTCDecoder:
-    """
-    .. devices:: CPU
-
-    CTC beam search decoder from *Flashlight* :cite:`kahn2022flashlight`.
+    """CTC beam search decoder from *Flashlight* :cite:`kahn2022flashlight`.
+
+    .. devices:: CPU

     Note:
-        To build the decoder, please use the factory function :py:func:`ctc_decoder`.
-
-    Args:
-        nbest (int): number of best decodings to return
-        lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder
-        word_dict (_Dictionary): dictionary of words
-        tokens_dict (_Dictionary): dictionary of tokens
-        lm (CTCDecoderLM): language model. If using a lexicon, only word level LMs are currently supported
-        decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions): parameters used for beam search decoding
-        blank_token (str): token corresponding to blank
-        sil_token (str): token corresponding to silence
-        unk_word (str): word corresponding to unknown
+        To build the decoder, please use the factory function :func:`ctc_decoder`.
     """

     def __init__(
@@ -226,6 +216,20 @@ class CTCDecoder:
         sil_token: str,
         unk_word: str,
     ) -> None:
+        """
+        Args:
+            nbest (int): number of best decodings to return
+            lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder
+            word_dict (_Dictionary): dictionary of words
+            tokens_dict (_Dictionary): dictionary of tokens
+            lm (CTCDecoderLM): language model. If using a lexicon, only word level LMs are currently supported
+            decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions):
+                parameters used for beam search decoding
+            blank_token (str): token corresponding to blank
+            sil_token (str): token corresponding to silence
+            unk_word (str): word corresponding to unknown
+        """
         self.nbest = nbest
         self.word_dict = word_dict
         self.tokens_dict = tokens_dict
@@ -348,8 +352,7 @@ def ctc_decoder(
     sil_token: str = "|",
     unk_word: str = "<unk>",
 ) -> CTCDecoder:
-    """
-    Builds CTC beam search decoder from *Flashlight* :cite:`kahn2022flashlight`.
+    """Builds an instance of :class:`CTCDecoder`.

     Args:
         lexicon (str or None): lexicon file containing the possible words and corresponding spellings.
@@ -455,20 +458,19 @@ def _get_filenames(model: str) -> _PretrainedFiles:
 def download_pretrained_files(model: str) -> _PretrainedFiles:
     """
-    Retrieves pretrained data files used for CTC decoder.
+    Retrieves pretrained data files used for :func:`ctc_decoder`.

     Args:
         model (str): pretrained language model to download.
-            Options: ["librispeech-3-gram", "librispeech-4-gram", "librispeech"]
+            Valid values are: ``"librispeech-3-gram"``, ``"librispeech-4-gram"`` and ``"librispeech"``.

     Returns:
         Object with the following attributes
-            lm:
-                path corresponding to downloaded language model, or `None` if the model is not associated with an lm
-            lexicon:
-                path corresponding to downloaded lexicon file
-            tokens:
-                path corresponding to downloaded tokens file
+
+            * ``lm``: path corresponding to downloaded language model,
+              or ``None`` if the model is not associated with an lm
+            * ``lexicon``: path corresponding to downloaded lexicon file
+            * ``tokens``: path corresponding to downloaded tokens file
     """
     files = _get_filenames(model)
 ...
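
To tie the reworked `Returns` section to actual usage, here is a small hedged sketch. The model names come from the docstring above; the expectation that the LM-less `"librispeech"` bundle yields `lm=None` is an inference from that docstring, not something stated in this diff.

```python
# Sketch: the attributes documented above on the object returned by download_pretrained_files.
from torchaudio.models.decoder import download_pretrained_files

files = download_pretrained_files("librispeech-4-gram")
print(files.lexicon)  # path to the downloaded lexicon file
print(files.tokens)   # path to the downloaded tokens file
print(files.lm)       # path to the downloaded language model

# "librispeech" is not associated with an LM, so its `lm` attribute is presumably None.
print(download_pretrained_files("librispeech").lm)
```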