Commit 64c7e065 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Update CTC Hypothesis docs (#2117)

Summary:
Add documentation for the CTC decoder `Hypothesis` and include it in the docs.

Pull Request resolved: https://github.com/pytorch/audio/pull/2117

Reviewed By: mthrok

Differential Revision: D33370381

Pulled By: carolineechen

fbshipit-source-id: cf6501a499e5303cda0410f733f0fab4e1c39aff
parent 9f14fa63
......@@ -16,6 +16,11 @@ KenLMLexiconDecoder
.. automethod:: idxs_to_tokens
Hypothesis
~~~~~~~~~~
.. autoclass:: Hypothesis
Factory Function
----------------
......
......@@ -2,7 +2,7 @@ import torchaudio
try:
torchaudio._extension._load_lib("libtorchaudio_decoder")
from .ctc_decoder import KenLMLexiconDecoder, kenlm_lexicon_decoder
from .ctc_decoder import Hypothesis, KenLMLexiconDecoder, kenlm_lexicon_decoder
except ImportError as err:
raise ImportError(
"flashlight decoder bindings are required to use this functionality. "
......@@ -11,6 +11,7 @@ except ImportError as err:
__all__ = [
"Hypothesis",
"KenLMLexiconDecoder",
"kenlm_lexicon_decoder",
]
import itertools as it
from collections import namedtuple
from typing import Dict
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union, NamedTuple
import torch
from torchaudio._torchaudio_decoder import (
......@@ -17,10 +15,19 @@ from torchaudio._torchaudio_decoder import (
)
__all__ = ["KenLMLexiconDecoder", "kenlm_lexicon_decoder"]
__all__ = ["Hypothesis", "KenLMLexiconDecoder", "kenlm_lexicon_decoder"]
Hypothesis = namedtuple("Hypothesis", ["tokens", "words", "score"])
class Hypothesis(NamedTuple):
    r"""Represents a hypothesis generated by the CTC beam search decoder :py:class:`KenLMLexiconDecoder`.

    :ivar torch.LongTensor tokens: Predicted sequence of token IDs
    :ivar List[str] words: List of predicted words
    :ivar float score: Score corresponding to hypothesis
    """

    tokens: torch.LongTensor
    words: List[str]
    score: float
class KenLMLexiconDecoder:
......@@ -99,7 +106,10 @@ class KenLMLexiconDecoder:
return torch.LongTensor(list(idxs))
def __call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> List[List[Hypothesis]]:
"""
# Overriding the signature so that the return type is correct on Sphinx
"""__call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> \
List[List[torchaudio.prototype.ctc_decoder.Hypothesis]]
Args:
emissions (torch.FloatTensor): tensor of shape `(batch, frame, num_tokens)` storing sequences of
probability distribution over labels; output of acoustic model
......@@ -109,11 +119,6 @@ class KenLMLexiconDecoder:
Returns:
List[List[Hypothesis]]:
List of sorted best hypotheses for each audio sequence in the batch.
Each hypothesis is named tuple with the following fields:
tokens: torch.LongTensor of raw token IDs
score: hypothesis score
words: list of decoded words
"""
assert emissions.dtype == torch.float32
......@@ -132,9 +137,9 @@ class KenLMLexiconDecoder:
hypos.append(
[
Hypothesis(
self._get_tokens(result.tokens), # token ids
[self.word_dict.get_entry(x) for x in result.words if x >= 0], # words
result.score, # score
tokens=self._get_tokens(result.tokens), # token ids
words=[self.word_dict.get_entry(x) for x in result.words if x >= 0], # words
score=result.score, # score
)
for result in nbest_results
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment