Commit 6b242c29 authored by hwangjeff, committed by Facebook GitHub Bot

Change underlying implementation of RNN-T hypothesis to tuple (#2339)

Summary:
PyTorch Lite, which is becoming a standard for mobile PyTorch usage, does not support containers containing custom classes. Consequently, because TorchAudio's RNN-T decoder currently returns and accepts lists of `Hypothesis` namedtuples, it is not compatible with PyTorch Lite. This PR resolves said incompatibility by changing the underlying implementation of `Hypothesis` to tuple.
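
For downstream code, the practical effect is that named-field access becomes positional indexing. A minimal before/after sketch (illustrative only; `hypos` stands for any hypothesis list returned by `RNNTBeamSearch`):

    # Before: Hypothesis was a NamedTuple with named fields.
    best = hypos[0]
    tokens, score = best.tokens, best.score

    # After: Hypothesis is a plain tuple laid out as
    # (tokens, predictor_out, state, score).
    best = hypos[0]
    tokens, score = best[0], best[3]

The old namedtuple's `alignment`, `blank`, and `key` fields are dropped; the deduplication key is now derived on demand as `str(hypo[0])`.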

Pull Request resolved: https://github.com/pytorch/audio/pull/2339

Reviewed By: nateanl

Differential Revision: D35806529

Pulled By: hwangjeff

fbshipit-source-id: 9cbae5504722390511d35e7f9966af2519ccede5
parent 9465b6bf
@@ -91,8 +91,10 @@ RNNTBeamSearch
 Hypothesis
 ^^^^^^^^^^
 
-.. autoclass:: Hypothesis
+.. container:: py attribute
+
+   .. autodata:: Hypothesis
+      :no-value:
 
 Tacotron2
 ~~~~~~~~~

@@ -50,20 +50,21 @@ def batch_by_token_count(idx_target_lengths, token_limit):
 def post_process_hypos(
     hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
 ) -> List[Tuple[str, float, List[int], List[int]]]:
+    tokens_idx = 0
+    score_idx = 3
     post_process_remove_list = [
         sp_model.unk_id(),
         sp_model.eos_id(),
         sp_model.pad_id(),
     ]
     filtered_hypo_tokens = [
-        [token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
+        [token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
     ]
     hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
-    hypos_ali = [h.alignment[1:] for h in hypos]
-    hypos_ids = [h.tokens[1:] for h in hypos]
-    hypos_score = [[math.exp(h.score)] for h in hypos]
+    hypos_ids = [h[tokens_idx][1:] for h in hypos]
+    hypos_score = [[math.exp(h[score_idx])] for h in hypos]
 
-    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))
+    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
 
     return nbest_batch

@@ -67,7 +67,7 @@ def run_eval_streaming(args):
             features, length = streaming_feature_extractor(segment)
             hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
             hypothesis = hypos[0]
-            transcript = token_processor(hypothesis.tokens, lstrip=False)
+            transcript = token_processor(hypothesis[0], lstrip=False)
             print(transcript, end="", flush=True)
     print()

@@ -75,7 +75,7 @@ def run_eval_streaming(args):
     with torch.no_grad():
         features, length = feature_extractor(waveform)
         hypos = decoder(features, length, 10)
-        print(token_processor(hypos[0].tokens))
+        print(token_processor(hypos[0][0]))
     print()

@@ -34,7 +34,7 @@ def _eval_subset(tedlium_path, subset, feature_extractor, decoder, token_processor):
         features, length = feature_extractor(waveform)
         hypos = decoder(features, length, 20)
         hypothesis = hypos[0]
-        hypothesis = token_processor(hypothesis.tokens)
+        hypothesis = token_processor(hypothesis[0])
         total_edit_distance += compute_word_level_distance(actual, hypothesis)
         total_length += len(actual.split())
         if idx % 100 == 0:

@@ -169,20 +169,21 @@ class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
 def post_process_hypos(
     hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
 ) -> List[Tuple[str, float, List[int], List[int]]]:
+    tokens_idx = 0
+    score_idx = 3
     post_process_remove_list = [
         sp_model.unk_id(),
         sp_model.eos_id(),
         sp_model.pad_id(),
     ]
     filtered_hypo_tokens = [
-        [token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
+        [token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
     ]
     hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
-    hypos_ali = [h.alignment[1:] for h in hypos]
-    hypos_ids = [h.tokens[1:] for h in hypos]
-    hypos_score = [[math.exp(h.score)] for h in hypos]
+    hypos_ids = [h[tokens_idx][1:] for h in hypos]
+    hypos_score = [[math.exp(h[score_idx])] for h in hypos]
 
-    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))
+    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
 
     return nbest_batch

@@ -221,7 +221,7 @@ class Pipeline:
             features, length, self.beam_width, state=self.state, hypothesis=self.hypothesis
         )
         self.hypothesis = hypos[0]
-        transcript = self.token_processor(self.hypothesis.tokens, lstrip=False)
+        transcript = self.token_processor(self.hypothesis[0], lstrip=False)
         return transcript

@@ -215,7 +215,7 @@ def run_inference(num_iter=200):
             features, length = feature_extractor(segment)
             hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
             hypothesis = hypos[0]
-            transcript = token_processor(hypothesis.tokens, lstrip=False)
+            transcript = token_processor(hypothesis[0], lstrip=False)
             print(transcript, end="", flush=True)
         chunks.append(chunk)

@@ -14,5 +14,5 @@ def test_rnnt(bundle, sample_speech, expected):
     waveform, _ = torchaudio.load(sample_speech)
     features, length = feature_extractor(waveform.squeeze())
     hypotheses = decoder(features, length, 10)
-    text = token_processor(hypotheses[0].tokens)
+    text = token_processor(hypotheses[0][0])
     assert text == expected

-from typing import Callable, Dict, List, Optional, NamedTuple, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
 from torchaudio.models import RNNT
@@ -7,34 +7,38 @@ from torchaudio.models import RNNT
 __all__ = ["Hypothesis", "RNNTBeamSearch"]
 
 
-class Hypothesis(NamedTuple):
-    r"""Represents hypothesis generated by beam search decoder ``RNNTBeamSearch``.
-
-    :ivar List[int] tokens: Predicted sequence of tokens.
-    :ivar torch.Tensor predictor_out: Prediction network output.
-    :ivar List[List[torch.Tensor]] state: Prediction network internal state.
-    :ivar float score: Score of hypothesis.
-    :ivar List[int] alignment: Sequence of timesteps, with the i-th value mapping
-        to the i-th predicted token in ``tokens``.
-    :ivar int blank: Token index corresponding to blank token.
-    :ivar str key: Value used to determine equivalence in token sequences
-        between ``Hypothesis`` instances.
-    """
-
-    tokens: List[int]
-    predictor_out: torch.Tensor
-    state: List[List[torch.Tensor]]
-    score: float
-    alignment: List[int]
-    blank: int
-    key: str
+Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]
+Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
+    represented as tuple of (tokens, prediction network output, prediction network state, score).
+    """
+
+
+def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
+    return hypo[0]
+
+
+def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
+    return hypo[1]
+
+
+def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
+    return hypo[2]
+
+
+def _get_hypo_score(hypo: Hypothesis) -> float:
+    return hypo[3]
+
+
+def _get_hypo_key(hypo: Hypothesis) -> str:
+    return str(hypo[0])
 
 
 def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
     states: List[List[torch.Tensor]] = []
-    for i in range(len(hypos[0].state)):
+    for i in range(len(_get_hypo_state(hypos[0]))):
         batched_state_components: List[torch.Tensor] = []
-        for j in range(len(hypos[0].state[i])):
-            batched_state_components.append(torch.cat([hypo.state[i][j] for hypo in hypos]))
+        for j in range(len(_get_hypo_state(hypos[0])[i])):
+            batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
         states.append(batched_state_components)
     return states

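As a standalone illustration (not part of the diff), the new alias and its private accessors behave as follows; the example values are made up:

    import torch

    # A hypothesis is built positionally: (tokens, predictor_out, state, score).
    hypo = ([0, 12, 47], torch.zeros(1, 4), [[torch.zeros(1, 8)]], -1.25)

    assert _get_hypo_tokens(hypo) == [0, 12, 47]
    assert _get_hypo_score(hypo) == -1.25
    assert _get_hypo_key(hypo) == "[0, 12, 47]"  # key is just the stringified token list

Because the tuple holds only `List[int]`, `Tensor`, nested `Tensor` lists, and `float`, it stays within the container types that TorchScript and PyTorch Lite can handle, which is the point of the change.
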
@@ -45,7 +49,7 @@ def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device
 
 def _default_hypo_sort_key(hypo: Hypothesis) -> float:
-    return hypo.score / (len(hypo.tokens) + 1)
+    return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1)
 
 
 def _compute_updated_scores(
@@ -53,7 +57,7 @@ def _compute_updated_scores(
     next_token_probs: torch.Tensor,
     beam_width: int,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    hypo_scores = torch.tensor([h.score for h in hypos]).unsqueeze(1)
+    hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
     nonblank_scores = hypo_scores + next_token_probs[:, :-1]  # [beam_width, num_tokens - 1]
     nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
     nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")

@@ -63,7 +67,7 @@ def _compute_updated_scores(
 
 def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
     for i, elem in enumerate(hypo_list):
-        if hypo.key == elem.key:
+        if _get_hypo_key(hypo) == _get_hypo_key(elem):
             del hypo_list[i]
             break

@@ -104,22 +108,19 @@ class RNNTBeamSearch(torch.nn.Module):
     def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
         if hypo is not None:
-            token = hypo.tokens[-1]
-            state = hypo.state
+            token = _get_hypo_tokens(hypo)[-1]
+            state = _get_hypo_state(hypo)
         else:
             token = self.blank
             state = None
 
         one_tensor = torch.tensor([1], device=device)
         pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
-        init_hypo = Hypothesis(
-            tokens=[token],
-            predictor_out=pred_out[0].detach(),
-            state=pred_state,
-            score=0.0,
-            alignment=[-1],
-            blank=self.blank,
-            key=str([token]),
-        )
+        init_hypo = (
+            [token],
+            pred_out[0].detach(),
+            pred_state,
+            0.0,
+        )
         return [init_hypo]

@@ -127,7 +128,7 @@ class RNNTBeamSearch(torch.nn.Module):
         self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
     ) -> torch.Tensor:
         one_tensor = torch.tensor([1], device=device)
-        predictor_out = torch.stack([h.predictor_out for h in hypos], dim=0)
+        predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
         joined_out, _, _ = self.model.join(
             enc_out,
             one_tensor,

@@ -146,27 +147,22 @@ class RNNTBeamSearch(torch.nn.Module):
     ) -> List[Hypothesis]:
         for i in range(len(a_hypos)):
             h_a = a_hypos[i]
-            append_blank_score = h_a.score + next_token_probs[i, -1]
-            if h_a.key in key_to_b_hypo:
-                h_b = key_to_b_hypo[h_a.key]
+            append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
+            if _get_hypo_key(h_a) in key_to_b_hypo:
+                h_b = key_to_b_hypo[_get_hypo_key(h_a)]
                 _remove_hypo(h_b, b_hypos)
-                score = float(torch.tensor(h_b.score).logaddexp(append_blank_score))
-                alignment = h_a.alignment if h_b.score < h_a.score else h_b.alignment
+                score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
             else:
                 score = float(append_blank_score)
-                alignment = h_a.alignment
-            h_b = Hypothesis(
-                tokens=h_a.tokens,
-                predictor_out=h_a.predictor_out,
-                state=h_a.state,
-                score=score,
-                alignment=alignment,
-                blank=self.blank,
-                key=h_a.key,
-            )
+            h_b = (
+                _get_hypo_tokens(h_a),
+                _get_hypo_predictor_out(h_a),
+                _get_hypo_state(h_a),
+                score,
+            )
             b_hypos.append(h_b)
-            key_to_b_hypo[h_b.key] = h_b
+            key_to_b_hypo[_get_hypo_key(h_b)] = h_b
 
-        _, sorted_idx = torch.tensor([hypo.score for hypo in b_hypos]).sort()
+        _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
         return [b_hypos[idx] for idx in sorted_idx]
 
     def _gen_a_hypos(

@@ -187,7 +183,7 @@ class RNNTBeamSearch(torch.nn.Module):
         if len(b_hypos) < beam_width:
             b_nbest_score = -float("inf")
         else:
-            b_nbest_score = b_hypos[-beam_width].score
+            b_nbest_score = _get_hypo_score(b_hypos[-beam_width])
 
         base_hypos: List[Hypothesis] = []
         new_tokens: List[int] = []

@@ -224,18 +220,8 @@ class RNNTBeamSearch(torch.nn.Module):
         )
         new_hypos: List[Hypothesis] = []
         for i, h_a in enumerate(base_hypos):
-            new_tokens = h_a.tokens + [tokens[i]]
-            new_hypos.append(
-                Hypothesis(
-                    tokens=new_tokens,
-                    predictor_out=pred_out[i].detach(),
-                    state=_slice_state(pred_states, i, device),
-                    score=scores[i],
-                    alignment=h_a.alignment + [t],
-                    blank=self.blank,
-                    key=str(new_tokens),
-                )
-            )
+            new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
+            new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
 
         return new_hypos
 
     def _search(

@@ -258,12 +244,7 @@ class RNNTBeamSearch(torch.nn.Module):
             while a_hypos:
                 next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
                 next_token_probs = next_token_probs.cpu()
-                b_hypos = self._gen_b_hypos(
-                    b_hypos,
-                    a_hypos,
-                    next_token_probs,
-                    key_to_b_hypo,
-                )
+                b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)
 
                 if symbols_current_t == self.step_max_tokens:
                     break

@@ -194,7 +194,7 @@ class RNNTBundle:
         >>> hypotheses = decoder(features, length, 10)
         >>>
         >>> # For top hypothesis, convert predicted tokens to text.
-        >>> text = token_processor(hypotheses[0].tokens)
+        >>> text = token_processor(hypotheses[0][0])
         >>> print(text)
         he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to [...]
         >>>

@@ -219,7 +219,7 @@ class RNNTBundle:
        >>>     features, length = streaming_feature_extractor(segment)
        >>>     hypotheses, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
        >>>     hypothesis = hypotheses[0]
-       >>>     transcript = token_processor(hypothesis.tokens)
+       >>>     transcript = token_processor(hypothesis[0])
        >>>     if transcript:
        >>>         print(transcript, end=" ", flush=True)
        he hoped there would be stew for dinner turn ips and car rots and bru 'd oes and fat mut ton pieces to [...]