Commit d43ce015 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Add CTC decoder timesteps (#2184)

Summary:
add timesteps field to CTC decoder hypotheses, corresponding to the time step of occurrences of non-blank tokens

Pull Request resolved: https://github.com/pytorch/audio/pull/2184

Reviewed By: mthrok

Differential Revision: D33905530

Pulled By: carolineechen

fbshipit-source-id: c575d25655fcf252754ee3c2447949a4c059461a
parent bb73934f
...@@ -64,6 +64,37 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase): ...@@ -64,6 +64,37 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
results = decoder(emissions) results = decoder(emissions)
self.assertEqual(len(results), emissions.shape[0]) self.assertEqual(len(results), emissions.shape[0])
def test_timesteps_shape(self):
    """Every decoded token must be paired with exactly one timestep."""
    emissions = self._get_emissions()
    decoder = self._get_decoder()
    hyps = decoder(emissions)
    for batch_idx in range(emissions.shape[0]):
        best_hyp = hyps[batch_idx][0]
        self.assertEqual(best_hyp.tokens.shape, best_hyp.timesteps.shape)
def test_get_timesteps(self):
    """_get_timesteps should return the first frame of each collapsed non-blank token."""
    raw_token_ids = torch.tensor([2, 2, 0, 3, 3, 3, 0, 3])
    decoder = self._get_decoder()
    self.assertEqual(decoder._get_timesteps(raw_token_ids), [0, 3, 7])
def test_get_tokens_and_idxs(self):
    """Collapsing repeats and dropping blanks should yield the final token sequence."""
    # frame-level ids: ["f", "f", "-", "o", "o", "o", "-", "o"]
    raw_token_ids = torch.tensor([2, 2, 0, 3, 3, 3, 0, 3])
    decoder = self._get_decoder()
    token_ids = decoder._get_tokens(raw_token_ids)
    self.assertEqual(token_ids, [2, 3, 3])
    self.assertEqual(decoder.idxs_to_tokens(token_ids), ["f", "o", "o"])
@parameterized.expand([(get_asset_path("decoder/tokens.txt"),), (["-", "|", "f", "o", "b", "a", "r"],)]) @parameterized.expand([(get_asset_path("decoder/tokens.txt"),), (["-", "|", "f", "o", "b", "a", "r"],)])
def test_index_to_tokens(self, tokens): def test_index_to_tokens(self, tokens):
# decoder tokens: '-' '|' 'f' 'o' 'b' 'a' 'r' # decoder tokens: '-' '|' 'f' 'o' 'b' 'a' 'r'
......
...@@ -23,13 +23,17 @@ __all__ = ["Hypothesis", "LexiconDecoder", "lexicon_decoder"] ...@@ -23,13 +23,17 @@ __all__ = ["Hypothesis", "LexiconDecoder", "lexicon_decoder"]
class Hypothesis(NamedTuple):
    r"""Represents hypothesis generated by CTC beam search decoder :py:func:`LexiconDecoder`.

    :ivar torch.LongTensor tokens: Predicted sequence of token IDs. Shape `(L, )`, where
        `L` is the length of the output sequence
    :ivar List[str] words: List of predicted words
    :ivar float score: Score corresponding to hypothesis
    :ivar torch.IntTensor timesteps: Timesteps corresponding to the tokens. Shape `(L, )`,
        where `L` is the length of the output sequence
    """
    tokens: torch.LongTensor
    words: List[str]
    score: float
    timesteps: torch.IntTensor
class LexiconDecoder: class LexiconDecoder:
...@@ -107,6 +111,17 @@ class LexiconDecoder: ...@@ -107,6 +111,17 @@ class LexiconDecoder:
idxs = filter(lambda x: x != self.blank, idxs) idxs = filter(lambda x: x != self.blank, idxs)
return torch.LongTensor(list(idxs)) return torch.LongTensor(list(idxs))
def _get_timesteps(self, idxs: torch.IntTensor) -> torch.IntTensor:
"""Returns frame numbers corresponding to non-blank tokens."""
timesteps = []
for i, idx in enumerate(idxs):
if idx == self.blank:
continue
if i == 0 or idx != idxs[i - 1]:
timesteps.append(i)
return torch.IntTensor(timesteps)
def __call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> List[List[Hypothesis]]: def __call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> List[List[Hypothesis]]:
# Overriding the signature so that the return type is correct on Sphinx # Overriding the signature so that the return type is correct on Sphinx
"""__call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> \ """__call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> \
...@@ -139,9 +154,10 @@ class LexiconDecoder: ...@@ -139,9 +154,10 @@ class LexiconDecoder:
hypos.append( hypos.append(
[ [
Hypothesis( Hypothesis(
tokens=self._get_tokens(result.tokens), # token ids tokens=self._get_tokens(result.tokens),
words=[self.word_dict.get_entry(x) for x in result.words if x >= 0], # words words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
score=result.score, # score score=result.score,
timesteps=self._get_timesteps(result.tokens),
) )
for result in nbest_results for result in nbest_results
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment