Commit 9f2bbf6c authored by Caroline Chen, committed by Facebook GitHub Bot
Browse files

Add Decoder LM Docs (#2658)

Summary:
Modifications to the CTC decoder LM docstrings, on top of https://github.com/pytorch/audio/issues/2657

Pull Request resolved: https://github.com/pytorch/audio/pull/2658

Reviewed By: mthrok

Differential Revision: D39468921

Pulled By: carolineechen

fbshipit-source-id: c5497cc2fa22fb98a304d037e27c91bf68a9ad6a
parent 60868748
......@@ -376,7 +376,9 @@ def fix_aliases():
fix_module_path(module, attribute)
if importlib.util.find_spec("torchaudio.flashlight_lib_text_decoder") is not None:
fix_module_path("torchaudio.models.decoder", ["CTCHypothesis"])
for class_ in ["CTCHypothesis", "CTCDecoderLM", "CTCDecoderLMState"]:
fix_module_path("torchaudio.models.decoder", [class_])
def setup(app):
......
......@@ -21,6 +21,29 @@ CTCDecoder
.. automethod:: idxs_to_tokens
CTCDecoderLM
~~~~~~~~~~~~
.. autoclass:: CTCDecoderLM
.. automethod:: start
.. automethod:: score
.. automethod:: finish
CTCDecoderLMState
~~~~~~~~~~~~~~~~~
.. autoclass:: CTCDecoderLMState
:members: children
.. automethod:: child
.. automethod:: compare
CTCHypothesis
~~~~~~~~~~~~~
......
from __future__ import annotations
import itertools as it
from abc import abstractmethod
from collections import namedtuple
from typing import Dict, List, NamedTuple, Optional, Union
......@@ -17,8 +20,8 @@ try:
LexiconDecoderOptions as _LexiconDecoderOptions,
LexiconFreeDecoder as _LexiconFreeDecoder,
LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
LM as CTCDecoderLM,
LMState as CTCDecoderLMState,
LM as _LM,
LMState as _LMState,
SmearingMode as _SmearingMode,
Trie as _Trie,
ZeroLM as _ZeroLM,
......@@ -37,8 +40,8 @@ except Exception:
LexiconDecoderOptions as _LexiconDecoderOptions,
LexiconFreeDecoder as _LexiconFreeDecoder,
LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
LM as CTCDecoderLM,
LMState as CTCDecoderLMState,
LM as _LM,
LMState as _LMState,
SmearingMode as _SmearingMode,
Trie as _Trie,
ZeroLM as _ZeroLM,
......@@ -113,6 +116,83 @@ class CTCHypothesis(NamedTuple):
timesteps: torch.IntTensor
class CTCDecoderLMState(_LMState):
    """Language model state.

    :ivar Dict[int, CTCDecoderLMState] children: Map of indices to LM states
    """

    def child(self, usr_index: int) -> CTCDecoderLMState:
        """Returns child corresponding to usr_index, or creates and returns a new state if input index
        is not found.

        Args:
            usr_index (int): index corresponding to child state

        Returns:
            CTCDecoderLMState: child state corresponding to usr_index
        """
        return super().child(usr_index)

    def compare(self, state: CTCDecoderLMState) -> int:
        """Compare two language model states.

        Args:
            state (CTCDecoderLMState): LM state to compare against

        Returns:
            int: 0 if the states are the same, -1 if self is less, +1 if self is greater.
        """
        # Forward to the base class, mirroring ``child``. An empty body would
        # return ``None`` and contradict the ``int`` result documented above.
        # NOTE(review): assumes ``_LMState`` exposes ``compare`` the same way it
        # exposes ``child`` -- confirm against the underlying binding.
        return super().compare(state)
class CTCDecoderLM(_LM):
    """Base class to derive from when supplying a custom language model to the decoder."""

    @abstractmethod
    def start(self, start_with_nothing: bool):
        """Set up the language model, or reset it to its initial condition.

        Args:
            start_with_nothing (bool): whether or not the sentence starts with the sil token.

        Returns:
            CTCDecoderLMState: starting state
        """
        raise NotImplementedError()

    @abstractmethod
    def score(self, state: CTCDecoderLMState, usr_token_idx: int):
        """Score a new word given the current language model state.

        Args:
            state (CTCDecoderLMState): current LM state
            usr_token_idx (int): index of the word

        Returns:
            (CTCDecoderLMState, float)
                CTCDecoderLMState:
                    new LM state
                float:
                    score
        """
        raise NotImplementedError()

    @abstractmethod
    def finish(self, state: CTCDecoderLMState):
        """Score the end of the sequence given the current language model state.

        Args:
            state (CTCDecoderLMState): current LM state

        Returns:
            (CTCDecoderLMState, float)
                CTCDecoderLMState:
                    new LM state
                float:
                    score
        """
        raise NotImplementedError()
class CTCDecoder:
"""
.. devices:: CPU
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment