"...text-generation-inference.git" did not exist on "3ed4c0f33fee281fbdc276e208574e22821818d9"
Commit 9f2bbf6c authored by Caroline Chen, committed by Facebook GitHub Bot
Browse files

Add Decoder LM Docs (#2658)

Summary:
Modifications to the CTC decoder LM docstrings, building on https://github.com/pytorch/audio/issues/2657

Pull Request resolved: https://github.com/pytorch/audio/pull/2658

Reviewed By: mthrok

Differential Revision: D39468921

Pulled By: carolineechen

fbshipit-source-id: c5497cc2fa22fb98a304d037e27c91bf68a9ad6a
parent 60868748
...@@ -376,7 +376,9 @@ def fix_aliases(): ...@@ -376,7 +376,9 @@ def fix_aliases():
fix_module_path(module, attribute) fix_module_path(module, attribute)
if importlib.util.find_spec("torchaudio.flashlight_lib_text_decoder") is not None: if importlib.util.find_spec("torchaudio.flashlight_lib_text_decoder") is not None:
fix_module_path("torchaudio.models.decoder", ["CTCHypothesis"])
for class_ in ["CTCHypothesis", "CTCDecoderLM", "CTCDecoderLMState"]:
fix_module_path("torchaudio.models.decoder", [class_])
def setup(app): def setup(app):
......
...@@ -21,6 +21,29 @@ CTCDecoder ...@@ -21,6 +21,29 @@ CTCDecoder
.. automethod:: idxs_to_tokens .. automethod:: idxs_to_tokens
CTCDecoderLM
~~~~~~~~~~~~
.. autoclass:: CTCDecoderLM
.. automethod:: start
.. automethod:: score
.. automethod:: finish
CTCDecoderLMState
~~~~~~~~~~~~~~~~~
.. autoclass:: CTCDecoderLMState
:members: children
.. automethod:: child
.. automethod:: compare
CTCHypothesis CTCHypothesis
~~~~~~~~~~~~~ ~~~~~~~~~~~~~
......
from __future__ import annotations
import itertools as it import itertools as it
from abc import abstractmethod
from collections import namedtuple from collections import namedtuple
from typing import Dict, List, NamedTuple, Optional, Union from typing import Dict, List, NamedTuple, Optional, Union
...@@ -17,8 +20,8 @@ try: ...@@ -17,8 +20,8 @@ try:
LexiconDecoderOptions as _LexiconDecoderOptions, LexiconDecoderOptions as _LexiconDecoderOptions,
LexiconFreeDecoder as _LexiconFreeDecoder, LexiconFreeDecoder as _LexiconFreeDecoder,
LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions, LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
LM as CTCDecoderLM, LM as _LM,
LMState as CTCDecoderLMState, LMState as _LMState,
SmearingMode as _SmearingMode, SmearingMode as _SmearingMode,
Trie as _Trie, Trie as _Trie,
ZeroLM as _ZeroLM, ZeroLM as _ZeroLM,
...@@ -37,8 +40,8 @@ except Exception: ...@@ -37,8 +40,8 @@ except Exception:
LexiconDecoderOptions as _LexiconDecoderOptions, LexiconDecoderOptions as _LexiconDecoderOptions,
LexiconFreeDecoder as _LexiconFreeDecoder, LexiconFreeDecoder as _LexiconFreeDecoder,
LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions, LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
LM as CTCDecoderLM, LM as _LM,
LMState as CTCDecoderLMState, LMState as _LMState,
SmearingMode as _SmearingMode, SmearingMode as _SmearingMode,
Trie as _Trie, Trie as _Trie,
ZeroLM as _ZeroLM, ZeroLM as _ZeroLM,
...@@ -113,6 +116,83 @@ class CTCHypothesis(NamedTuple): ...@@ -113,6 +116,83 @@ class CTCHypothesis(NamedTuple):
timesteps: torch.IntTensor timesteps: torch.IntTensor
class CTCDecoderLMState(_LMState):
    """Language model state.

    :ivar Dict[int, CTCDecoderLMState] children: Map of indices to LM states
    """

    def child(self, usr_index: int):
        """Returns child corresponding to usr_index, or creates and returns a new state if input index
        is not found.

        Args:
            usr_index (int): index corresponding to child state

        Returns:
            CTCDecoderLMState: child state corresponding to usr_index
        """
        return super().child(usr_index)

    def compare(self, state: CTCDecoderLMState):
        """Compare two language model states.

        Args:
            state (CTCDecoderLMState): LM state to compare against

        Returns:
            int: 0 if the states are the same, -1 if self is less, +1 if self is greater.
        """
        # Delegate to the base class, mirroring ``child``, so that the documented
        # integer result is actually returned instead of an implicit ``None``.
        # NOTE(review): assumes ``_LMState`` exposes ``compare`` -- confirm against the bindings.
        return super().compare(state)
class CTCDecoderLM(_LM):
    """Base class for writing a custom language model that can be used with the decoder.

    The three methods below are placeholders that raise ``NotImplementedError``;
    a concrete language model is expected to override each of them.
    """

    @abstractmethod
    def start(self, start_with_nothing: bool):
        """Initialize the language model, or reset it if it was already in use.

        Args:
            start_with_nothing (bool): whether or not the sentence should start with the sil token.

        Returns:
            CTCDecoderLMState: the starting state
        """
        raise NotImplementedError

    @abstractmethod
    def score(self, state: CTCDecoderLMState, usr_token_idx: int):
        """Score a new word given the current language model state.

        Args:
            state (CTCDecoderLMState): the current LM state
            usr_token_idx (int): index of the word being scored

        Returns:
            (CTCDecoderLMState, float)
                CTCDecoderLMState:
                    the new LM state
                float:
                    the score
        """
        raise NotImplementedError

    @abstractmethod
    def finish(self, state: CTCDecoderLMState):
        """Score the end of the sequence given the current language model state.

        Args:
            state (CTCDecoderLMState): the current LM state

        Returns:
            (CTCDecoderLMState, float)
                CTCDecoderLMState:
                    the new LM state
                float:
                    the score
        """
        raise NotImplementedError
class CTCDecoder: class CTCDecoder:
""" """
.. devices:: CPU .. devices:: CPU
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment