"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "18fdcc94e55d8ca393be9d01b30246dbbca6f6af"
Commit 6fbc1e68 authored by moto, committed by Facebook GitHub Bot

Add incremental decoding support to CTC decoder (#3594)

Summary:
Add incremental decoding support to CTC decoder.

Resolves https://github.com/pytorch/audio/issues/3574
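
A rough sketch of the new incremental decoding flow (the decoder construction
arguments, file names, and emission shape below are placeholders for illustration,
not part of this change):

    import torch
    from torchaudio.models.decoder import ctc_decoder

    # Placeholder lexicon/tokens/LM files; any valid CTC decoder configuration works.
    decoder = ctc_decoder(lexicon="lexicon.txt", tokens="tokens.txt", lm="lm.bin")

    # Placeholder emission of shape (batch, frame, num_tokens); must be float32, CPU, contiguous.
    emission = torch.rand(1, 100, 29)

    decoder.decode_begin()                              # initialize internal state
    for t in range(emission.size(1)):
        decoder.decode_step(emission[0, t : t + 1, :])  # feed frames incrementally, shape (frame, num_tokens)
    decoder.decode_end()                                # finalize internal state
    hypos = decoder.get_final_hypothesis()              # List[CTCHypothesis], sorted best-first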

Pull Request resolved: https://github.com/pytorch/audio/pull/3594

Reviewed By: nateanl

Differential Revision: D48940584

Pulled By: mthrok

fbshipit-source-id: 31871614008cf197cf3900f7183ec6cff34d2905
parent 3e7e696c
......@@ -387,6 +387,47 @@ print(f"WER: {beam_search_wer}")
# and “shoktd”.
#
######################################################################
# Incremental decoding
# ~~~~~~~~~~~~~~~~~~~~
#
# If the input speech is long, you can decode the emission in an
# incremental manner.
#
# You need to first initialize the internal state of the decoder with
# :py:meth:`~torchaudio.models.decoder.CTCDecoder.decode_begin`.
beam_search_decoder.decode_begin()
######################################################################
# Then, you can pass emissions to
# :py:meth:`~torchaudio.models.decoder.CTCDecoder.decode_step`.
# Here we use the same emission but pass it to the decoder one frame
# at a time.
for t in range(emission.size(1)):
beam_search_decoder.decode_step(emission[0, t:t + 1, :])
######################################################################
# Finally, finalize the internal state of the decoder, and retrieve the
# result.
beam_search_decoder.decode_end()
beam_search_result_inc = beam_search_decoder.get_final_hypothesis()
######################################################################
# The result of incremental decoding is identical to that of batch decoding.
#
beam_search_transcript_inc = " ".join(beam_search_result_inc[0].words).strip()
beam_search_wer_inc = torchaudio.functional.edit_distance(
actual_transcript, beam_search_result_inc[0].words) / len(actual_transcript)
print(f"Transcript: {beam_search_transcript_inc}")
print(f"WER: {beam_search_wer_inc}")
assert beam_search_result[0][0].words == beam_search_result_inc[0].words
assert beam_search_result[0][0].score == beam_search_result_inc[0].score
torch.testing.assert_close(beam_search_result[0][0].timesteps, beam_search_result_inc[0].timesteps)
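######################################################################
# The decoder does not have to receive a single frame per call; you can
# also pass multiple frames at a time. The following is a sketch (the
# chunk size of 50 frames is arbitrary) that decodes the same emission
# in fixed-size chunks, which is closer to how one would consume a
# streaming acoustic model. The result should again match the batch
# decoding result.
beam_search_decoder.decode_begin()
chunk_size = 50
for start in range(0, emission.size(1), chunk_size):
    beam_search_decoder.decode_step(emission[0, start:start + chunk_size, :])
beam_search_decoder.decode_end()
beam_search_result_chunk = beam_search_decoder.get_final_hypothesis()
print(f"Transcript: {' '.join(beam_search_result_chunk[0].words).strip()}")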
######################################################################
# Timestep Alignments
......
......@@ -261,10 +261,102 @@ class CTCDecoder:
timesteps.append(i)
return torch.IntTensor(timesteps)
def decode_begin(self):
"""Initialize the internal state of the decoder.
See :py:meth:`decode_step` for the usage.
.. note::
This method is required only when performing online decoding.
It is not necessary when performing batch decoding with :py:meth:`__call__`.
"""
self.decoder.decode_begin()
def decode_end(self):
"""Finalize the internal state of the decoder.
See :py:meth:`decode_step` for the usage.
.. note::
This method is required only when performing online decoding.
It is not necessary when performing batch decoding with :py:meth:`__call__`.
"""
self.decoder.decode_end()
def decode_step(self, emissions: torch.FloatTensor):
"""Perform incremental decoding on top of the curent internal state.
.. note::
This method is required only when performing online decoding.
It is not necessary when performing batch decoding with :py:meth:`__call__`.
Args:
emissions (torch.FloatTensor): CPU tensor of shape `(frame, num_tokens)` storing a sequence of
probability distributions over labels; output of the acoustic model.
Example:
>>> decoder = torchaudio.models.decoder.ctc_decoder(...)
>>> decoder.decode_begin()
>>> decoder.decode_step(emission1)
>>> decoder.decode_step(emission2)
>>> decoder.decode_end()
>>> result = decoder.get_final_hypothesis()
"""
if emissions.dtype != torch.float32:
raise ValueError("emissions must be float32.")
if not emissions.is_cpu:
raise RuntimeError("emissions must be a CPU tensor.")
if not emissions.is_contiguous():
raise RuntimeError("emissions must be contiguous.")
if emissions.ndim != 2:
raise RuntimeError(f"emissions must be 2D. Found {emissions.shape}")
T, N = emissions.size()
self.decoder.decode_step(emissions.data_ptr(), T, N)
def _to_hypo(self, results) -> List[CTCHypothesis]:
return [
CTCHypothesis(
tokens=self._get_tokens(result.tokens),
words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
score=result.score,
timesteps=self._get_timesteps(result.tokens),
)
for result in results
]
def get_final_hypothesis(self) -> List[CTCHypothesis]:
"""Get the final hypothesis
Returns:
List[CTCHypothesis]:
List of sorted best hypotheses.
.. note::
This method is required only when performing online decoding.
It is not necessary when performing batch decoding with :py:meth:`__call__`.
"""
results = self.decoder.get_all_final_hypothesis()
return self._to_hypo(results[: self.nbest])
def __call__(
self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
) -> List[List[CTCHypothesis]]:
"""
Performs batched offline decoding.
.. note::
This method performs offline decoding in one go. To perform incremental decoding,
please refer to :py:meth:`decode_step`.
Args:
emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
probability distributions over labels; output of the acoustic model.
......@@ -279,13 +371,16 @@ class CTCDecoder:
if emissions.dtype != torch.float32:
raise ValueError("emissions must be float32.")
if emissions.is_cuda:
if not emissions.is_cpu:
raise RuntimeError("emissions must be a CPU tensor.")
if not emissions.is_contiguous():
raise RuntimeError("emissions must be contiguous.")
if lengths is not None and lengths.is_cuda:
if emissions.ndim != 3:
raise RuntimeError(f"emissions must be 3D. Found {emissions.shape}")
if lengths is not None and not lengths.is_cpu:
raise RuntimeError("lengths must be a CPU tensor.")
B, T, N = emissions.size()
......@@ -298,20 +393,7 @@ class CTCDecoder:
for b in range(B):
emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
results = self.decoder.decode(emissions_ptr, lengths[b], N)
nbest_results = results[: self.nbest]
hypos.append(
[
CTCHypothesis(
tokens=self._get_tokens(result.tokens),
words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
score=result.score,
timesteps=self._get_timesteps(result.tokens),
)
for result in nbest_results
]
)
hypos.append(self._to_hypo(results[: self.nbest]))
return hypos
def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
......