Commit 64c7e065 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Update CTC Hypothesis docs (#2117)

Summary:
Add documentation for the CTC decoder `Hypothesis` and include it in the docs

Pull Request resolved: https://github.com/pytorch/audio/pull/2117

Reviewed By: mthrok

Differential Revision: D33370381

Pulled By: carolineechen

fbshipit-source-id: cf6501a499e5303cda0410f733f0fab4e1c39aff
parent 9f14fa63
...@@ -16,6 +16,11 @@ KenLMLexiconDecoder ...@@ -16,6 +16,11 @@ KenLMLexiconDecoder
.. automethod:: idxs_to_tokens .. automethod:: idxs_to_tokens
Hypothesis
~~~~~~~~~~
.. autoclass:: Hypothesis
Factory Function Factory Function
---------------- ----------------
......
...@@ -2,7 +2,7 @@ import torchaudio ...@@ -2,7 +2,7 @@ import torchaudio
try: try:
torchaudio._extension._load_lib("libtorchaudio_decoder") torchaudio._extension._load_lib("libtorchaudio_decoder")
from .ctc_decoder import KenLMLexiconDecoder, kenlm_lexicon_decoder from .ctc_decoder import Hypothesis, KenLMLexiconDecoder, kenlm_lexicon_decoder
except ImportError as err: except ImportError as err:
raise ImportError( raise ImportError(
"flashlight decoder bindings are required to use this functionality. " "flashlight decoder bindings are required to use this functionality. "
...@@ -11,6 +11,7 @@ except ImportError as err: ...@@ -11,6 +11,7 @@ except ImportError as err:
__all__ = [ __all__ = [
"Hypothesis",
"KenLMLexiconDecoder", "KenLMLexiconDecoder",
"kenlm_lexicon_decoder", "kenlm_lexicon_decoder",
] ]
import itertools as it import itertools as it
from collections import namedtuple from typing import Dict, List, Optional, Union, NamedTuple
from typing import Dict
from typing import List, Optional, Union
import torch import torch
from torchaudio._torchaudio_decoder import ( from torchaudio._torchaudio_decoder import (
...@@ -17,10 +15,19 @@ from torchaudio._torchaudio_decoder import ( ...@@ -17,10 +15,19 @@ from torchaudio._torchaudio_decoder import (
) )
__all__ = ["KenLMLexiconDecoder", "kenlm_lexicon_decoder"] __all__ = ["Hypothesis", "KenLMLexiconDecoder", "kenlm_lexicon_decoder"]
Hypothesis = namedtuple("Hypothesis", ["tokens", "words", "score"]) class Hypothesis(NamedTuple):
r"""Represents a hypothesis generated by the CTC beam search decoder :py:class:`KenLMLexiconDecoder`.
:ivar torch.LongTensor tokens: Predicted sequence of token IDs
:ivar List[str] words: List of predicted words
:ivar float score: Score corresponding to hypothesis
"""
tokens: torch.LongTensor
words: List[str]
score: float
class KenLMLexiconDecoder: class KenLMLexiconDecoder:
...@@ -99,7 +106,10 @@ class KenLMLexiconDecoder: ...@@ -99,7 +106,10 @@ class KenLMLexiconDecoder:
return torch.LongTensor(list(idxs)) return torch.LongTensor(list(idxs))
def __call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> List[List[Hypothesis]]: def __call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> List[List[Hypothesis]]:
""" # Overriding the signature so that the return type is correct on Sphinx
"""__call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> \
List[List[torchaudio.prototype.ctc_decoder.Hypothesis]]
Args: Args:
emissions (torch.FloatTensor): tensor of shape `(batch, frame, num_tokens)` storing sequences of emissions (torch.FloatTensor): tensor of shape `(batch, frame, num_tokens)` storing sequences of
probability distribution over labels; output of acoustic model probability distribution over labels; output of acoustic model
...@@ -109,11 +119,6 @@ class KenLMLexiconDecoder: ...@@ -109,11 +119,6 @@ class KenLMLexiconDecoder:
Returns: Returns:
List[List[Hypothesis]]: List[List[Hypothesis]]:
List of sorted best hypotheses for each audio sequence in the batch. List of sorted best hypotheses for each audio sequence in the batch.
Each hypothesis is named tuple with the following fields:
tokens: torch.LongTensor of raw token IDs
score: hypothesis score
words: list of decoded words
""" """
assert emissions.dtype == torch.float32 assert emissions.dtype == torch.float32
...@@ -132,9 +137,9 @@ class KenLMLexiconDecoder: ...@@ -132,9 +137,9 @@ class KenLMLexiconDecoder:
hypos.append( hypos.append(
[ [
Hypothesis( Hypothesis(
self._get_tokens(result.tokens), # token ids tokens=self._get_tokens(result.tokens), # token ids
[self.word_dict.get_entry(x) for x in result.words if x >= 0], # words words=[self.word_dict.get_entry(x) for x in result.words if x >= 0], # words
result.score, # score score=result.score, # score
) )
for result in nbest_results for result in nbest_results
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment