Commit 6b242c29 authored by hwangjeff, committed by Facebook GitHub Bot

Change underlying implementation of RNN-T hypothesis to tuple (#2339)

Summary:
PyTorch Lite, which is becoming a standard for mobile PyTorch usage, does not support containers of custom classes. Because TorchAudio's RNN-T decoder currently accepts and returns lists of `Hypothesis` namedtuples, it is therefore incompatible with PyTorch Lite. This PR resolves the incompatibility by changing the underlying implementation of `Hypothesis` to a plain tuple.
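For downstream callers, the practical effect is that attribute access on a hypothesis becomes positional indexing. A minimal caller-side sketch of the migration; `hypos` stands in for the decoder output (a `List[Hypothesis]`), with a toy hypothesis in place of real decoder state:

    import torch

    # Toy stand-in for one decoder output hypothesis, laid out as
    # (tokens, predictor_out, state, score) per the new tuple definition.
    hypos = [([1, 5, 9], torch.zeros(1, 4), [], -0.5)]
    top = hypos[0]
    tokens = top[0]  # was: top.tokens
    score = top[3]   # was: top.score
    assert tokens == [1, 5, 9] and score == -0.5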

Pull Request resolved: https://github.com/pytorch/audio/pull/2339

Reviewed By: nateanl

Differential Revision: D35806529

Pulled By: hwangjeff

fbshipit-source-id: 9cbae5504722390511d35e7f9966af2519ccede5
parent 9465b6bf
@@ -91,8 +91,10 @@ RNNTBeamSearch
 Hypothesis
 ^^^^^^^^^^
 
-.. autoclass:: Hypothesis
+.. container:: py attribute
+
+   .. autodata:: Hypothesis
+      :no-value:
 
 Tacotron2
 ~~~~~~~~~
@@ -50,20 +50,21 @@ def batch_by_token_count(idx_target_lengths, token_limit):
 def post_process_hypos(
     hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
 ) -> List[Tuple[str, float, List[int], List[int]]]:
+    tokens_idx = 0
+    score_idx = 3
     post_process_remove_list = [
         sp_model.unk_id(),
         sp_model.eos_id(),
         sp_model.pad_id(),
     ]
     filtered_hypo_tokens = [
-        [token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
+        [token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
     ]
     hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
-    hypos_ali = [h.alignment[1:] for h in hypos]
-    hypos_ids = [h.tokens[1:] for h in hypos]
-    hypos_score = [[math.exp(h.score)] for h in hypos]
+    hypos_ids = [h[tokens_idx][1:] for h in hypos]
+    hypos_score = [[math.exp(h[score_idx])] for h in hypos]
 
-    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))
+    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
 
     return nbest_batch
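The `tokens_idx`/`score_idx` constants above mirror the new tuple layout, with tokens at position 0 and score at position 3. A small standalone illustration of the same accesses (toy hypothesis; the tensor and state slots are stubbed with None for brevity):

    import math

    h = ([1, 17, 42], None, None, -2.3)  # (tokens, predictor_out, state, score)
    tokens_idx, score_idx = 0, 3
    print(h[tokens_idx][1:])         # [17, 42] -- drops the leading start token, as above
    print([math.exp(h[score_idx])])  # [0.1002...] -- exponentiated log-score, wrapped in a list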
@@ -67,7 +67,7 @@ def run_eval_streaming(args):
             features, length = streaming_feature_extractor(segment)
             hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
             hypothesis = hypos[0]
-            transcript = token_processor(hypothesis.tokens, lstrip=False)
+            transcript = token_processor(hypothesis[0], lstrip=False)
             print(transcript, end="", flush=True)
     print()
@@ -75,7 +75,7 @@ def run_eval_streaming(args):
     with torch.no_grad():
         features, length = feature_extractor(waveform)
         hypos = decoder(features, length, 10)
-        print(token_processor(hypos[0].tokens))
+        print(token_processor(hypos[0][0]))
     print()
@@ -34,7 +34,7 @@ def _eval_subset(tedlium_path, subset, feature_extractor, decoder, token_processor):
             features, length = feature_extractor(waveform)
             hypos = decoder(features, length, 20)
             hypothesis = hypos[0]
-            hypothesis = token_processor(hypothesis.tokens)
+            hypothesis = token_processor(hypothesis[0])
             total_edit_distance += compute_word_level_distance(actual, hypothesis)
             total_length += len(actual.split())
             if idx % 100 == 0:
@@ -169,20 +169,21 @@ class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
 def post_process_hypos(
     hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
 ) -> List[Tuple[str, float, List[int], List[int]]]:
+    tokens_idx = 0
+    score_idx = 3
     post_process_remove_list = [
         sp_model.unk_id(),
         sp_model.eos_id(),
         sp_model.pad_id(),
     ]
     filtered_hypo_tokens = [
-        [token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
+        [token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
     ]
     hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
-    hypos_ali = [h.alignment[1:] for h in hypos]
-    hypos_ids = [h.tokens[1:] for h in hypos]
-    hypos_score = [[math.exp(h.score)] for h in hypos]
+    hypos_ids = [h[tokens_idx][1:] for h in hypos]
+    hypos_score = [[math.exp(h[score_idx])] for h in hypos]
 
-    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))
+    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
 
     return nbest_batch
@@ -221,7 +221,7 @@ class Pipeline:
             features, length, self.beam_width, state=self.state, hypothesis=self.hypothesis
         )
         self.hypothesis = hypos[0]
-        transcript = self.token_processor(self.hypothesis.tokens, lstrip=False)
+        transcript = self.token_processor(self.hypothesis[0], lstrip=False)
         return transcript
@@ -215,7 +215,7 @@ def run_inference(num_iter=200):
             features, length = feature_extractor(segment)
             hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
             hypothesis = hypos[0]
-            transcript = token_processor(hypothesis.tokens, lstrip=False)
+            transcript = token_processor(hypothesis[0], lstrip=False)
             print(transcript, end="", flush=True)
             chunks.append(chunk)
@@ -14,5 +14,5 @@ def test_rnnt(bundle, sample_speech, expected):
     waveform, _ = torchaudio.load(sample_speech)
     features, length = feature_extractor(waveform.squeeze())
     hypotheses = decoder(features, length, 10)
-    text = token_processor(hypotheses[0].tokens)
+    text = token_processor(hypotheses[0][0])
     assert text == expected
-from typing import Callable, Dict, List, Optional, NamedTuple, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
 from torchaudio.models import RNNT
@@ -7,34 +7,38 @@ from torchaudio.models import RNNT
 __all__ = ["Hypothesis", "RNNTBeamSearch"]
 
-class Hypothesis(NamedTuple):
-    r"""Represents hypothesis generated by beam search decoder ``RNNTBeamSearch``.
-
-    :ivar List[int] tokens: Predicted sequence of tokens.
-    :ivar torch.Tensor predictor_out: Prediction network output.
-    :ivar List[List[torch.Tensor]] state: Prediction network internal state.
-    :ivar float score: Score of hypothesis.
-    :ivar List[int] alignment: Sequence of timesteps, with the i-th value mapping
-        to the i-th predicted token in ``tokens``.
-    :ivar int blank: Token index corresponding to blank token.
-    :ivar str key: Value used to determine equivalence in token sequences
-        between ``Hypothesis`` instances.
-    """
-
-    tokens: List[int]
-    predictor_out: torch.Tensor
-    state: List[List[torch.Tensor]]
-    score: float
-    alignment: List[int]
-    blank: int
-    key: str
+Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]
+Hypothesis.__doc__ = """Hypothesis generated by RNN-T beam search decoder,
+    represented as tuple of (tokens, prediction network output, prediction network state, score).
+    """
+
+
+def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
+    return hypo[0]
+
+
+def _get_hypo_predictor_out(hypo: Hypothesis) -> torch.Tensor:
+    return hypo[1]
+
+
+def _get_hypo_state(hypo: Hypothesis) -> List[List[torch.Tensor]]:
+    return hypo[2]
+
+
+def _get_hypo_score(hypo: Hypothesis) -> float:
+    return hypo[3]
+
+
+def _get_hypo_key(hypo: Hypothesis) -> str:
+    return str(hypo[0])
 
 
 def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
     states: List[List[torch.Tensor]] = []
-    for i in range(len(hypos[0].state)):
+    for i in range(len(_get_hypo_state(hypos[0]))):
         batched_state_components: List[torch.Tensor] = []
-        for j in range(len(hypos[0].state[i])):
-            batched_state_components.append(torch.cat([hypo.state[i][j] for hypo in hypos]))
+        for j in range(len(_get_hypo_state(hypos[0])[i])):
+            batched_state_components.append(torch.cat([_get_hypo_state(hypo)[i][j] for hypo in hypos]))
         states.append(batched_state_components)
     return states
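A self-contained sketch of the pattern introduced above: a tuple alias that TorchScript can type, plus small private getters that keep the positional accesses in one place (the alias and two getters are re-declared here purely for illustration; the real definitions are in the hunk above):

    from typing import List, Tuple

    import torch

    Hypothesis = Tuple[List[int], torch.Tensor, List[List[torch.Tensor]], float]

    def _get_hypo_tokens(hypo: Hypothesis) -> List[int]:
        return hypo[0]

    def _get_hypo_score(hypo: Hypothesis) -> float:
        return hypo[3]

    # Toy hypothesis with an arbitrary predictor output and empty state.
    hypo: Hypothesis = ([0, 5, 12], torch.zeros(1, 8), [], -1.23)
    assert _get_hypo_tokens(hypo) == [0, 5, 12]
    assert _get_hypo_score(hypo) == -1.23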
@@ -45,7 +49,7 @@ def _slice_state(states: List[List[torch.Tensor]], idx: int, device: torch.device
 def _default_hypo_sort_key(hypo: Hypothesis) -> float:
-    return hypo.score / (len(hypo.tokens) + 1)
+    return _get_hypo_score(hypo) / (len(_get_hypo_tokens(hypo)) + 1)
 
 
 def _compute_updated_scores(
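A note on `_default_hypo_sort_key` above: its behavior is unchanged. It length-normalizes the log-probability score so that longer hypotheses, whose summed log-probabilities are necessarily more negative, are not automatically outranked by shorter ones. A quick illustration with made-up scores (`sort_key` is a stand-in for the real function):

    def sort_key(h):
        return h[3] / (len(h[0]) + 1)

    short = ([0, 7], None, None, -2.0)         # per-token average: -2.0 / 3 ~ -0.67
    longer = ([0, 7, 9, 4], None, None, -3.0)  # per-token average: -3.0 / 5 = -0.60
    assert max([short, longer], key=sort_key) is longer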
@@ -53,7 +57,7 @@ def _compute_updated_scores(
     next_token_probs: torch.Tensor,
     beam_width: int,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    hypo_scores = torch.tensor([h.score for h in hypos]).unsqueeze(1)
+    hypo_scores = torch.tensor([_get_hypo_score(h) for h in hypos]).unsqueeze(1)
     nonblank_scores = hypo_scores + next_token_probs[:, :-1]  # [beam_width, num_tokens - 1]
     nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(beam_width)
     nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(nonblank_scores.shape[1], rounding_mode="trunc")
@@ -63,7 +67,7 @@ def _compute_updated_scores(
 def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
     for i, elem in enumerate(hypo_list):
-        if hypo.key == elem.key:
+        if _get_hypo_key(hypo) == _get_hypo_key(elem):
             del hypo_list[i]
             break
@@ -104,22 +108,19 @@ class RNNTBeamSearch(torch.nn.Module):
     def _init_b_hypos(self, hypo: Optional[Hypothesis], device: torch.device) -> List[Hypothesis]:
         if hypo is not None:
-            token = hypo.tokens[-1]
-            state = hypo.state
+            token = _get_hypo_tokens(hypo)[-1]
+            state = _get_hypo_state(hypo)
         else:
             token = self.blank
             state = None
 
         one_tensor = torch.tensor([1], device=device)
         pred_out, _, pred_state = self.model.predict(torch.tensor([[token]], device=device), one_tensor, state)
-        init_hypo = Hypothesis(
-            tokens=[token],
-            predictor_out=pred_out[0].detach(),
-            state=pred_state,
-            score=0.0,
-            alignment=[-1],
-            blank=self.blank,
-            key=str([token]),
-        )
+        init_hypo = (
+            [token],
+            pred_out[0].detach(),
+            pred_state,
+            0.0,
+        )
         return [init_hypo]
@@ -127,7 +128,7 @@
         self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
     ) -> torch.Tensor:
         one_tensor = torch.tensor([1], device=device)
-        predictor_out = torch.stack([h.predictor_out for h in hypos], dim=0)
+        predictor_out = torch.stack([_get_hypo_predictor_out(h) for h in hypos], dim=0)
         joined_out, _, _ = self.model.join(
             enc_out,
             one_tensor,
@@ -146,27 +147,22 @@
     ) -> List[Hypothesis]:
         for i in range(len(a_hypos)):
             h_a = a_hypos[i]
-            append_blank_score = h_a.score + next_token_probs[i, -1]
-            if h_a.key in key_to_b_hypo:
-                h_b = key_to_b_hypo[h_a.key]
+            append_blank_score = _get_hypo_score(h_a) + next_token_probs[i, -1]
+            if _get_hypo_key(h_a) in key_to_b_hypo:
+                h_b = key_to_b_hypo[_get_hypo_key(h_a)]
                 _remove_hypo(h_b, b_hypos)
-                score = float(torch.tensor(h_b.score).logaddexp(append_blank_score))
-                alignment = h_a.alignment if h_b.score < h_a.score else h_b.alignment
+                score = float(torch.tensor(_get_hypo_score(h_b)).logaddexp(append_blank_score))
             else:
                 score = float(append_blank_score)
-                alignment = h_a.alignment
-            h_b = Hypothesis(
-                tokens=h_a.tokens,
-                predictor_out=h_a.predictor_out,
-                state=h_a.state,
-                score=score,
-                alignment=alignment,
-                blank=self.blank,
-                key=h_a.key,
-            )
+            h_b = (
+                _get_hypo_tokens(h_a),
+                _get_hypo_predictor_out(h_a),
+                _get_hypo_state(h_a),
+                score,
+            )
             b_hypos.append(h_b)
-            key_to_b_hypo[h_b.key] = h_b
-        _, sorted_idx = torch.tensor([hypo.score for hypo in b_hypos]).sort()
+            key_to_b_hypo[_get_hypo_key(h_b)] = h_b
+        _, sorted_idx = torch.tensor([_get_hypo_score(hypo) for hypo in b_hypos]).sort()
         return [b_hypos[idx] for idx in sorted_idx]
 
     def _gen_a_hypos(
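A note on the score merge in `_gen_b_hypos` above: when a blank-extended hypothesis collides with an existing one carrying the same token key, the two scores are combined in log space via logaddexp(a, b) = log(exp(a) + exp(b)), which adds the underlying probabilities without leaving the log domain. A quick numeric check of that identity (values illustrative):

    import math

    import torch

    a, b = -1.5, -2.5  # two log-probabilities for the same token sequence
    merged = float(torch.tensor(a).logaddexp(torch.tensor(b)))
    assert abs(merged - math.log(math.exp(a) + math.exp(b))) < 1e-6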
@@ -187,7 +183,7 @@ class RNNTBeamSearch(torch.nn.Module):
         if len(b_hypos) < beam_width:
             b_nbest_score = -float("inf")
         else:
-            b_nbest_score = b_hypos[-beam_width].score
+            b_nbest_score = _get_hypo_score(b_hypos[-beam_width])
 
         base_hypos: List[Hypothesis] = []
         new_tokens: List[int] = []
@@ -224,18 +220,8 @@
         )
         new_hypos: List[Hypothesis] = []
         for i, h_a in enumerate(base_hypos):
-            new_tokens = h_a.tokens + [tokens[i]]
-            new_hypos.append(
-                Hypothesis(
-                    tokens=new_tokens,
-                    predictor_out=pred_out[i].detach(),
-                    state=_slice_state(pred_states, i, device),
-                    score=scores[i],
-                    alignment=h_a.alignment + [t],
-                    blank=self.blank,
-                    key=str(new_tokens),
-                )
-            )
+            new_tokens = _get_hypo_tokens(h_a) + [tokens[i]]
+            new_hypos.append((new_tokens, pred_out[i].detach(), _slice_state(pred_states, i, device), scores[i]))
 
         return new_hypos
 
     def _search(
@@ -258,12 +244,7 @@
         while a_hypos:
             next_token_probs = self._gen_next_token_probs(enc_out[:, t : t + 1], a_hypos, device)
             next_token_probs = next_token_probs.cpu()
-            b_hypos = self._gen_b_hypos(
-                b_hypos,
-                a_hypos,
-                next_token_probs,
-                key_to_b_hypo,
-            )
+            b_hypos = self._gen_b_hypos(b_hypos, a_hypos, next_token_probs, key_to_b_hypo)
 
             if symbols_current_t == self.step_max_tokens:
                 break
@@ -194,7 +194,7 @@ class RNNTBundle:
         >>> hypotheses = decoder(features, length, 10)
         >>>
         >>> # For top hypothesis, convert predicted tokens to text.
-        >>> text = token_processor(hypotheses[0].tokens)
+        >>> text = token_processor(hypotheses[0][0])
         >>> print(text)
         he hoped there would be stew for dinner turnips and carrots and bruised potatoes and fat mutton pieces to [...]
         >>>
@@ -219,7 +219,7 @@ class RNNTBundle:
         >>>     features, length = streaming_feature_extractor(segment)
         >>>     hypotheses, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
         >>>     hypothesis = hypotheses[0]
-        >>>     transcript = token_processor(hypothesis.tokens)
+        >>>     transcript = token_processor(hypothesis[0])
         >>>     if transcript:
         >>>         print(transcript, end=" ", flush=True)
         he hoped there would be stew for dinner turn ips and car rots and bru 'd oes and fat mut ton pieces to [...]