Commit 60a85b50 authored by hwangjeff, committed by Facebook GitHub Bot

Add RNN-T beam search decoder (#2028)

Summary:
Adds ``torchaudio.prototype.RNNTBeamSearch``, a beam search decoder for the RNN-T implementation ``torchaudio.prototype.RNNT``. The decoder is TorchScript-able and supports both streaming and non-streaming inference.
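
A minimal usage sketch for the non-streaming path (illustrative only, not part of this diff; ``emformer_rnnt_base`` stands in for a trained model, and the vocabulary size and feature dimension below are assumptions):

import torch
from torchaudio.prototype import RNNTBeamSearch, emformer_rnnt_base

# Illustrative stand-in model; a real application would load trained weights.
model = emformer_rnnt_base().eval()
num_symbols = 4097  # assumption: 4096 tokens + blank in the base configuration
decoder = RNNTBeamSearch(model, blank=num_symbols - 1)

features = torch.rand(65, 80)  # (T, D); 80-dim features assumed
length = torch.tensor([61])    # valid frames, excluding right-context padding
hypos = decoder(features, length, beam_width=10)
print(hypos[0].tokens, hypos[0].score)  # hypotheses are returned best-first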

Pull Request resolved: https://github.com/pytorch/audio/pull/2028

Reviewed By: mthrok

Differential Revision: D32627919

Pulled By: hwangjeff

fbshipit-source-id: aab99e346d6514a3207a9fb69d4b42978b4cdbbd
parent 9c9aef88
@@ -49,6 +49,22 @@ emformer_rnnt_model

.. autofunction:: emformer_rnnt_model

RNNTBeamSearch
~~~~~~~~~~~~~~

.. autoclass:: RNNTBeamSearch

  .. automethod:: forward

  .. automethod:: infer

Hypothesis
~~~~~~~~~~

.. autoclass:: Hypothesis

References
~~~~~~~~~~
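
For reference, a hedged sketch of streaming decoding with ``infer``, mirroring the unit tests below (the segment length, right-context length, input dimension, and blank index are assumptions, not values guaranteed by this commit):

import torch
from torchaudio.prototype import RNNTBeamSearch, emformer_rnnt_base

model = emformer_rnnt_base().eval()  # stand-in; use a trained model in practice
decoder = RNNTBeamSearch(model, blank=4096)  # assumed blank index

segment_length, right_context_length, input_dim = 16, 4, 80
state, hypothesis = None, None
for _ in range(10):  # feed ten consecutive feature segments
    # Each chunk carries the extra right-context frames the Emformer encoder needs.
    segment = torch.rand(segment_length + right_context_length, input_dim)
    length = torch.tensor(segment_length + right_context_length)
    hypos, state = decoder.infer(
        segment, length, beam_width=10, state=state, hypothesis=hypothesis
    )
    hypothesis = hypos[0]  # seed the next call with the current best hypothesis
print(hypothesis.tokens)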
import torch
from torchaudio_unittest.prototype.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
from torchaudio_unittest.common_utils import PytorchTestCase
class RNNTBeamSearchFloat32CPUTest(RNNTBeamSearchTestImpl, PytorchTestCase):
dtype = torch.float32
device = torch.device("cpu")
class RNNTBeamSearchFloat64CPUTest(RNNTBeamSearchTestImpl, PytorchTestCase):
dtype = torch.float64
device = torch.device("cpu")
import torch
from torchaudio_unittest.prototype.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
@skipIfNoCuda
class RNNTBeamSearchFloat32GPUTest(RNNTBeamSearchTestImpl, PytorchTestCase):
dtype = torch.float32
device = torch.device("cuda")
@skipIfNoCuda
class RNNTBeamSearchFloat64GPUTest(RNNTBeamSearchTestImpl, PytorchTestCase):
dtype = torch.float64
device = torch.device("cuda")
import torch
from torchaudio.prototype import RNNTBeamSearch, emformer_rnnt_model
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
class RNNTBeamSearchTestImpl(TestBaseMixin):
def _get_input_config(self):
model_config = self._get_model_config()
return {
"batch_size": 1,
"max_input_length": 61,
"num_symbols": model_config["num_symbols"],
"input_dim": model_config["input_dim"],
"right_context_length": model_config["right_context_length"],
"segment_length": model_config["segment_length"],
}
def _get_model_config(self):
return {
"input_dim": 80,
"encoding_dim": 128,
"num_symbols": 256,
"segment_length": 16,
"right_context_length": 4,
"time_reduction_input_dim": 128,
"time_reduction_stride": 4,
"transformer_num_heads": 4,
"transformer_ffn_dim": 64,
"transformer_num_layers": 3,
"transformer_dropout": 0.0,
"transformer_activation": "relu",
"transformer_left_context_length": 30,
"transformer_max_memory_size": 0,
"transformer_weight_init_scale_strategy": "depthwise",
"transformer_tanh_on_mem": True,
"symbol_embedding_dim": 64,
"num_lstm_layers": 2,
"lstm_layer_norm": True,
"lstm_layer_norm_epsilon": 1e-3,
"lstm_dropout": 0.0,
}
def _get_model(self):
return (
emformer_rnnt_model(**self._get_model_config())
.to(device=self.device, dtype=self.dtype)
.eval()
)
def test_torchscript_consistency_forward(self):
r"""Verify that scripting RNNTBeamSearch does not change the behavior of method `forward`."""
torch.random.manual_seed(31)
input_config = self._get_input_config()
batch_size = input_config["batch_size"]
max_input_length = input_config["max_input_length"]
right_context_length = input_config["right_context_length"]
input_dim = input_config["input_dim"]
num_symbols = input_config["num_symbols"]
blank_idx = num_symbols - 1
beam_width = 5
input = torch.rand(
batch_size, max_input_length + right_context_length, input_dim
).to(device=self.device, dtype=self.dtype)
lengths = torch.randint(1, max_input_length + 1, (batch_size,)).to(
device=self.device, dtype=torch.int32
)
model = self._get_model()
beam_search = RNNTBeamSearch(model, blank_idx)
scripted = torch_script(beam_search)
res = beam_search(input, lengths, beam_width)
scripted_res = scripted(input, lengths, beam_width)
self.assertEqual(res, scripted_res)
def test_torchscript_consistency_infer(self):
r"""Verify that scripting RNNTBeamSearch does not change the behavior of method `infer`."""
torch.random.manual_seed(31)
input_config = self._get_input_config()
segment_length = input_config["segment_length"]
right_context_length = input_config["right_context_length"]
input_dim = input_config["input_dim"]
num_symbols = input_config["num_symbols"]
blank_idx = num_symbols - 1
beam_width = 5
input = torch.rand(
segment_length + right_context_length, input_dim
).to(device=self.device, dtype=self.dtype)
lengths = torch.randint(
1, segment_length + right_context_length + 1, ()
).to(device=self.device, dtype=torch.int32)
model = self._get_model()
state, hypo = None, None
scripted_state, scripted_hypo = None, None
for _ in range(2):
beam_search = RNNTBeamSearch(model, blank_idx)
scripted = torch_script(beam_search)
res = beam_search.infer(input, lengths, beam_width, state=state, hypothesis=hypo)
scripted_res = scripted.infer(
input, lengths, beam_width, state=scripted_state, hypothesis=scripted_hypo
)
self.assertEqual(res, scripted_res)
state = res[1]
hypo = res[0][0]
scripted_state = scripted_res[1]
scripted_hypo = scripted_res[0][0]
from .emformer import Emformer
from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearch

__all__ = [
"Emformer",
"Hypothesis",
"RNNT",
"RNNTBeamSearch",
"emformer_rnnt_base",
"emformer_rnnt_model",
]
from typing import Callable, Dict, List, Optional, NamedTuple, Tuple
import torch
from .rnnt import RNNT
__all__ = ["Hypothesis", "RNNTBeamSearch"]
class Hypothesis(NamedTuple):
r"""Represents hypothesis generated by beam search decoder ``RNNTBeamSearch``.
:ivar List[int] tokens: Predicted sequence of tokens.
:ivar torch.Tensor predictor_out: Prediction network output.
:ivar List[List[torch.Tensor]] state: Prediction network internal state.
:ivar float score: Score of hypothesis.
:ivar List[int] alignment: Sequence of timesteps, with the i-th value mapping
to the i-th predicted token in ``tokens``.
:ivar int blank: Token index corresponding to blank token.
:ivar str key: Value used to determine equivalence in token sequences
between ``Hypothesis`` instances.
"""
tokens: List[int]
predictor_out: torch.Tensor
state: List[List[torch.Tensor]]
score: float
alignment: List[int]
blank: int
key: str
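    # Note: ``key`` is ``str(tokens)`` (see ``_gen_new_hypos``), so hypotheses
    # with identical token sequences share a key and can be deduplicated.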
def _batch_state(hypos: List[Hypothesis]) -> List[List[torch.Tensor]]:
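    # Concatenate each per-layer state tensor across hypotheses along the batch
    # dimension so the prediction network can process the whole beam in one call.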
states: List[List[torch.Tensor]] = []
for i in range(len(hypos[0].state)):
batched_state_components: List[torch.Tensor] = []
for j in range(len(hypos[0].state[i])):
batched_state_components.append(
torch.cat([hypo.state[i][j] for hypo in hypos])
)
states.append(batched_state_components)
return states
def _slice_state(
states: List[List[torch.Tensor]], idx: int, device: torch.device
) -> List[List[torch.Tensor]]:
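    # Select batch row ``idx`` from every batched state tensor, recovering the
    # state for a single hypothesis.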
idx_tensor = torch.tensor([idx], device=device)
return [
[state.index_select(0, idx_tensor) for state in state_tuple]
for state_tuple in states
]
def _default_hypo_sort_key(hypo: Hypothesis) -> float:
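    # Default ranking: log-probability score normalized by token sequence length.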
return hypo.score / (len(hypo.tokens) + 1)
def _compute_updated_scores(
hypos: List[Hypothesis], next_token_probs: torch.Tensor, beam_width: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
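    # Expand every hypothesis by every non-blank token, then keep the overall
    # top ``beam_width`` (hypothesis, token) pairs. The flat top-k indices are
    # decoded back into a hypothesis index (integer division) and a token index
    # (remainder).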
hypo_scores = torch.tensor([h.score for h in hypos]).unsqueeze(1)
nonblank_scores = (
hypo_scores + next_token_probs[:, :-1]
) # [beam_width, num_tokens - 1]
nonblank_nbest_scores, nonblank_nbest_idx = nonblank_scores.reshape(-1).topk(
beam_width
)
nonblank_nbest_hypo_idx = nonblank_nbest_idx.div(
nonblank_scores.shape[1], rounding_mode="trunc"
)
nonblank_nbest_token = nonblank_nbest_idx % nonblank_scores.shape[1]
return nonblank_nbest_scores, nonblank_nbest_hypo_idx, nonblank_nbest_token
def _remove_hypo(hypo: Hypothesis, hypo_list: List[Hypothesis]) -> None:
for i, elem in enumerate(hypo_list):
if hypo.key == elem.key:
del hypo_list[i]
break
class RNNTBeamSearch(torch.nn.Module):
r"""Beam search decoder for RNN-T model.
Args:
model (RNNT): RNN-T model to use.
blank (int): index of blank token in vocabulary.
temperature (float, optional): temperature to apply to joint network output.
Larger values yield more uniform samples. (Default: 1.0)
hypo_sort_key (Callable[[Hypothesis], float] or None, optional): callable that computes a score
for a given hypothesis to rank hypotheses by. If ``None``, defaults to callable that returns
hypothesis score normalized by token sequence length. (Default: None)
step_max_tokens (int, optional): maximum number of tokens to emit per input time step. (Default: 100)
"""
def __init__(
self,
model: RNNT,
blank: int,
temperature: float = 1.0,
hypo_sort_key: Optional[Callable[[Hypothesis], float]] = None,
step_max_tokens: int = 100,
) -> None:
super().__init__()
self.model = model
self.blank = blank
self.temperature = temperature
if hypo_sort_key is None:
self.hypo_sort_key = _default_hypo_sort_key
else:
self.hypo_sort_key = hypo_sort_key
self.step_max_tokens = step_max_tokens
def _init_b_hypos(
self, hypo: Optional[Hypothesis], device: torch.device
) -> List[Hypothesis]:
if hypo is not None:
token = hypo.tokens[-1]
state = hypo.state
else:
token = self.blank
state = None
one_tensor = torch.tensor([1], device=device)
pred_out, _, pred_state = self.model.predict(
torch.tensor([[token]], device=device), one_tensor, state
)
init_hypo = Hypothesis(
tokens=[token],
predictor_out=pred_out[0].detach(),
state=pred_state,
score=0.0,
alignment=[-1],
blank=self.blank,
key=str([token]),
)
return [init_hypo]
def _gen_next_token_probs(
self, enc_out: torch.Tensor, hypos: List[Hypothesis], device: torch.device
) -> torch.Tensor:
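        # Evaluate the joint network once for the current encoder frame against
        # the stacked predictor outputs of every hypothesis in the beam.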
one_tensor = torch.tensor([1], device=device)
predictor_out = torch.stack([h.predictor_out for h in hypos], dim=0)
joined_out, _, _ = self.model.join(
enc_out,
one_tensor,
predictor_out,
torch.tensor([1] * len(hypos), device=device),
) # [beam_width, 1, 1, num_tokens]
joined_out = torch.nn.functional.log_softmax(
joined_out / self.temperature, dim=3
)
        # Suppress the first four token indices (treated as invalid symbols) by
        # pushing their log-probabilities toward -inf.
        joined_out[:, :, :, :4].add_(-99999)
return joined_out[:, 0, 0]
def _gen_b_hypos(
self,
b_hypos: List[Hypothesis],
a_hypos: List[Hypothesis],
next_token_probs: torch.Tensor,
key_to_b_hypo: Dict[str, Hypothesis],
) -> List[Hypothesis]:
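        # Extend each hypothesis in A with the blank token, moving it into B. If
        # a hypothesis with the same token sequence (same ``key``) is already in
        # B, merge the two by log-sum-exp of their scores instead of duplicating.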
for i in range(len(a_hypos)):
h_a = a_hypos[i]
append_blank_score = h_a.score + next_token_probs[i, -1]
if h_a.key in key_to_b_hypo:
h_b = key_to_b_hypo[h_a.key]
_remove_hypo(h_b, b_hypos)
score = float(torch.tensor(h_b.score).logaddexp(append_blank_score))
alignment = h_a.alignment if h_b.score < h_a.score else h_b.alignment
else:
score = float(append_blank_score)
alignment = h_a.alignment
h_b = Hypothesis(
tokens=h_a.tokens,
predictor_out=h_a.predictor_out,
state=h_a.state,
score=score,
alignment=alignment,
blank=self.blank,
key=h_a.key,
)
b_hypos.append(h_b)
key_to_b_hypo[h_b.key] = h_b
_, sorted_idx = torch.tensor([hypo.score for hypo in b_hypos]).sort()
return [b_hypos[idx] for idx in sorted_idx]
def _gen_a_hypos(
self,
a_hypos: List[Hypothesis],
b_hypos: List[Hypothesis],
next_token_probs: torch.Tensor,
t: int,
beam_width: int,
device: torch.device,
) -> List[Hypothesis]:
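        # Prune non-blank expansions that do not score above the current
        # ``beam_width``-th best blank-terminated hypothesis in B (B is kept
        # sorted in ascending score order, so that is ``b_hypos[-beam_width]``).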
(
nonblank_nbest_scores,
nonblank_nbest_hypo_idx,
nonblank_nbest_token,
) = _compute_updated_scores(a_hypos, next_token_probs, beam_width)
if len(b_hypos) < beam_width:
b_nbest_score = -float("inf")
else:
b_nbest_score = b_hypos[-beam_width].score
base_hypos: List[Hypothesis] = []
new_tokens: List[int] = []
new_scores: List[float] = []
for i in range(beam_width):
score = float(nonblank_nbest_scores[i])
if score > b_nbest_score:
a_hypo_idx = int(nonblank_nbest_hypo_idx[i])
base_hypos.append(a_hypos[a_hypo_idx])
new_tokens.append(int(nonblank_nbest_token[i]))
new_scores.append(score)
if base_hypos:
new_hypos = self._gen_new_hypos(
base_hypos, new_tokens, new_scores, t, device
)
else:
new_hypos: List[Hypothesis] = []
return new_hypos
def _gen_new_hypos(
self,
base_hypos: List[Hypothesis],
tokens: List[int],
scores: List[float],
t: int,
device: torch.device,
) -> List[Hypothesis]:
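        # Advance the prediction network one step for all base hypotheses in a
        # single batched call, then slice the batched state back out per hypothesis.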
tgt_tokens = torch.tensor([[token] for token in tokens], device=device)
states = _batch_state(base_hypos)
pred_out, _, pred_states = self.model.predict(
tgt_tokens, torch.tensor([1] * len(base_hypos), device=device), states,
)
new_hypos: List[Hypothesis] = []
for i, h_a in enumerate(base_hypos):
new_tokens = h_a.tokens + [tokens[i]]
new_hypos.append(
Hypothesis(
tokens=new_tokens,
predictor_out=pred_out[i].detach(),
state=_slice_state(pred_states, i, device),
score=scores[i],
alignment=h_a.alignment + [t],
blank=self.blank,
key=str(new_tokens),
)
)
return new_hypos
def _search(
self, enc_out: torch.Tensor, hypo: Optional[Hypothesis], beam_width: int,
) -> List[Hypothesis]:
n_time_steps = enc_out.shape[1]
device = enc_out.device
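        # Classic RNN-T beam search bookkeeping (cf. Graves, 2012): ``a_hypos``
        # holds hypotheses still being extended with non-blank tokens at the
        # current time step; ``b_hypos`` holds hypotheses that have emitted
        # blank and are ready to advance to the next time step.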
a_hypos: List[Hypothesis] = []
b_hypos = self._init_b_hypos(hypo, device)
for t in range(n_time_steps):
a_hypos = b_hypos
b_hypos = torch.jit.annotate(List[Hypothesis], [])
key_to_b_hypo: Dict[str, Hypothesis] = {}
symbols_current_t = 0
while a_hypos:
next_token_probs = self._gen_next_token_probs(
enc_out[:, t: t + 1], a_hypos, device
)
next_token_probs = next_token_probs.cpu()
b_hypos = self._gen_b_hypos(
b_hypos, a_hypos, next_token_probs, key_to_b_hypo,
)
if symbols_current_t == self.step_max_tokens:
break
a_hypos = self._gen_a_hypos(
a_hypos, b_hypos, next_token_probs, t, beam_width, device,
)
if a_hypos:
symbols_current_t += 1
_, sorted_idx = torch.tensor(
[self.hypo_sort_key(hypo) for hypo in b_hypos]
).topk(beam_width)
b_hypos = [b_hypos[idx] for idx in sorted_idx]
return b_hypos
def forward(
self, input: torch.Tensor, length: torch.Tensor, beam_width: int
) -> List[Hypothesis]:
r"""Performs beam search for the given input sequence.
T: number of frames;
D: feature dimension of each frame.
Args:
input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
length (torch.Tensor): number of valid frames in input
sequence, with shape () or (1,).
beam_width (int): beam size to use during search.
Returns:
List[Hypothesis]: top-``beam_width`` hypotheses found by beam search.
"""
assert input.dim() == 2 or (
input.dim() == 3 and input.shape[0] == 1
), "input must be of shape (T, D) or (1, T, D)"
if input.dim() == 2:
input = input.unsqueeze(0)
assert length.shape == () or length.shape == (
1,
), "length must be of shape () or (1,)"
        if length.shape == ():
            length = length.unsqueeze(0)
enc_out, _ = self.model.transcribe(input, length)
return self._search(enc_out, None, beam_width)
@torch.jit.export
def infer(
self,
input: torch.Tensor,
length: torch.Tensor,
beam_width: int,
state: Optional[List[List[torch.Tensor]]] = None,
hypothesis: Optional[Hypothesis] = None,
) -> Tuple[List[Hypothesis], List[List[torch.Tensor]]]:
r"""Performs beam search for the given input sequence in streaming mode.
T: number of frames;
D: feature dimension of each frame.
Args:
input (torch.Tensor): sequence of input frames, with shape (T, D) or (1, T, D).
length (torch.Tensor): number of valid frames in input
sequence, with shape () or (1,).
beam_width (int): beam size to use during search.
state (List[List[torch.Tensor]] or None, optional): list of lists of tensors
representing transcription network internal state generated in preceding
invocation. (Default: ``None``)
hypothesis (Hypothesis or None): hypothesis from preceding invocation to seed
search with. (Default: ``None``)
Returns:
(List[Hypothesis], List[List[torch.Tensor]]):
List[Hypothesis]
top-``beam_width`` hypotheses found by beam search.
List[List[torch.Tensor]]
list of lists of tensors representing transcription network
internal state generated in current invocation.
"""
assert input.dim() == 2 or (
input.dim() == 3 and input.shape[0] == 1
), "input must be of shape (T, D) or (1, T, D)"
if input.dim() == 2:
input = input.unsqueeze(0)
assert length.shape == () or length.shape == (
1,
), "length must be of shape () or (1,)"
        if length.shape == ():
            length = length.unsqueeze(0)
enc_out, _, state = self.model.transcribe_streaming(input, length, state)
return self._search(enc_out, hypothesis, beam_width), state