Commit 60a2c57a authored by sunzhq2's avatar sunzhq2 Committed by xuxo
Browse files

update conformer

parent 4a699441
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# This code is ported from the following implementation written in Torch.
# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
import logging
import os
import random
import chainer
import h5py
import numpy as np
from chainer.training import extension
from tqdm import tqdm
def load_dataset(path, label_dict, outdir=None):
"""Load and save HDF5 that contains a dataset and stats for LM
Args:
path (str): The path of an input text dataset file
label_dict (dict[str, int]):
dictionary that maps token label string to its ID number
outdir (str): The path of an output dir
Returns:
tuple[list[np.ndarray], int, int]: Tuple of
token IDs in np.int32 converted by `read_tokens`
the number of tokens by `count_tokens`,
and the number of OOVs by `count_tokens`
"""
if outdir is not None:
os.makedirs(outdir, exist_ok=True)
filename = outdir + "/" + os.path.basename(path) + ".h5"
if os.path.exists(filename):
logging.info(f"loading binary dataset: {filename}")
f = h5py.File(filename, "r")
return f["data"][:], f["n_tokens"][()], f["n_oovs"][()]
else:
logging.info("skip dump/load HDF5 because the output dir is not specified")
logging.info(f"reading text dataset: {path}")
ret = read_tokens(path, label_dict)
n_tokens, n_oovs = count_tokens(ret, label_dict["<unk>"])
if outdir is not None:
logging.info(f"saving binary dataset: {filename}")
with h5py.File(filename, "w") as f:
# http://docs.h5py.org/en/stable/special.html#arbitrary-vlen-data
data = f.create_dataset(
"data", (len(ret),), dtype=h5py.special_dtype(vlen=np.int32)
)
data[:] = ret
f["n_tokens"] = n_tokens
f["n_oovs"] = n_oovs
return ret, n_tokens, n_oovs
def read_tokens(filename, label_dict):
"""Read tokens as a sequence of sentences
:param str filename : The name of the input file
:param dict label_dict : dictionary that maps token label string to its ID number
:return list of ID sequences
:rtype list
"""
data = []
unk = label_dict["<unk>"]
for ln in tqdm(open(filename, "r", encoding="utf-8")):
data.append(
np.array(
[label_dict.get(label, unk) for label in ln.split()], dtype=np.int32
)
)
return data
def count_tokens(data, unk_id=None):
"""Count tokens and oovs in token ID sequences.
Args:
data (list[np.ndarray]): list of token ID sequences
unk_id (int): ID of unknown token
Returns:
tuple: tuple of number of token occurrences and number of oov tokens
"""
n_tokens = 0
n_oovs = 0
for sentence in data:
n_tokens += len(sentence)
if unk_id is not None:
n_oovs += np.count_nonzero(sentence == unk_id)
return n_tokens, n_oovs
def compute_perplexity(result):
"""Computes and add the perplexity to the LogReport
:param dict result: The current observations
"""
# Routine to rewrite the result dictionary of LogReport to add perplexity values
result["perplexity"] = np.exp(result["main/loss"] / result["main/count"])
if "validation/main/loss" in result:
result["val_perplexity"] = np.exp(result["validation/main/loss"])
class ParallelSentenceIterator(chainer.dataset.Iterator):
"""Dataset iterator to create a batch of sentences.
This iterator returns a pair of sentences, where one token is shifted
between the sentences like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
Sentence batches are made in order of longer sentences, and then
randomly shuffled.
"""
def __init__(
self, dataset, batch_size, max_length=0, sos=0, eos=0, repeat=True, shuffle=True
):
self.dataset = dataset
self.batch_size = batch_size # batch size
# Number of completed sweeps over the dataset. In this case, it is
# incremented if every word is visited at least once after the last
# increment.
self.epoch = 0
# True if the epoch is incremented at the last iteration.
self.is_new_epoch = False
self.repeat = repeat
length = len(dataset)
self.batch_indices = []
# make mini-batches
if batch_size > 1:
indices = sorted(range(len(dataset)), key=lambda i: -len(dataset[i]))
bs = 0
while bs < length:
be = min(bs + batch_size, length)
# batch size is automatically reduced if the sentence length
# is larger than max_length
if max_length > 0:
sent_length = len(dataset[indices[bs]])
be = min(
be, bs + max(batch_size // (sent_length // max_length + 1), 1)
)
self.batch_indices.append(np.array(indices[bs:be]))
bs = be
if shuffle:
# shuffle batches
random.shuffle(self.batch_indices)
else:
self.batch_indices = [np.array([i]) for i in range(length)]
# NOTE: this is not a count of parameter updates. It is just a count of
# calls of ``__next__``.
self.iteration = 0
self.sos = sos
self.eos = eos
# use -1 instead of None internally
self._previous_epoch_detail = -1.0
def __next__(self):
# This iterator returns a list representing a mini-batch. Each item
# indicates a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
# represented by token IDs.
n_batches = len(self.batch_indices)
if not self.repeat and self.iteration >= n_batches:
# If not self.repeat, this iterator stops at the end of the first
# epoch (i.e., when all words are visited once).
raise StopIteration
batch = []
for idx in self.batch_indices[self.iteration % n_batches]:
batch.append(
(
np.append([self.sos], self.dataset[idx]),
np.append(self.dataset[idx], [self.eos]),
)
)
self._previous_epoch_detail = self.epoch_detail
self.iteration += 1
epoch = self.iteration // n_batches
self.is_new_epoch = self.epoch < epoch
if self.is_new_epoch:
self.epoch = epoch
return batch
def start_shuffle(self):
random.shuffle(self.batch_indices)
@property
def epoch_detail(self):
# Floating point version of epoch.
return self.iteration / len(self.batch_indices)
@property
def previous_epoch_detail(self):
if self._previous_epoch_detail < 0:
return None
return self._previous_epoch_detail
def serialize(self, serializer):
# It is important to serialize the state to be recovered on resume.
self.iteration = serializer("iteration", self.iteration)
self.epoch = serializer("epoch", self.epoch)
try:
self._previous_epoch_detail = serializer(
"previous_epoch_detail", self._previous_epoch_detail
)
except KeyError:
# guess previous_epoch_detail for older version
self._previous_epoch_detail = self.epoch + (
self.current_position - 1
) / len(self.batch_indices)
if self.epoch_detail > 0:
self._previous_epoch_detail = max(self._previous_epoch_detail, 0.0)
else:
self._previous_epoch_detail = -1.0
class MakeSymlinkToBestModel(extension.Extension):
"""Extension that makes a symbolic link to the best model
:param str key: Key of value
:param str prefix: Prefix of model files and link target
:param str suffix: Suffix of link target
"""
def __init__(self, key, prefix="model", suffix="best"):
super(MakeSymlinkToBestModel, self).__init__()
self.best_model = -1
self.min_loss = 0.0
self.key = key
self.prefix = prefix
self.suffix = suffix
def __call__(self, trainer):
observation = trainer.observation
if self.key in observation:
loss = observation[self.key]
if self.best_model == -1 or loss < self.min_loss:
self.min_loss = loss
self.best_model = trainer.updater.epoch
src = "%s.%d" % (self.prefix, self.best_model)
dest = os.path.join(trainer.out, "%s.%s" % (self.prefix, self.suffix))
if os.path.lexists(dest):
os.remove(dest)
os.symlink(src, dest)
logging.info("best model is " + src)
def serialize(self, serializer):
if isinstance(serializer, chainer.serializer.Serializer):
serializer("_best_model", self.best_model)
serializer("_min_loss", self.min_loss)
serializer("_key", self.key)
serializer("_prefix", self.prefix)
serializer("_suffix", self.suffix)
else:
self.best_model = serializer("_best_model", -1)
self.min_loss = serializer("_min_loss", 0.0)
self.key = serializer("_key", "")
self.prefix = serializer("_prefix", "model")
self.suffix = serializer("_suffix", "best")
# TODO(Hori): currently it only works with character-word level LM.
# need to consider any types of subwords-to-word mapping.
def make_lexical_tree(word_dict, subword_dict, word_unk):
"""Make a lexical tree to compute word-level probabilities"""
# node [dict(subword_id -> node), word_id, word_set[start-1, end]]
root = [{}, -1, None]
for w, wid in word_dict.items():
if wid > 0 and wid != word_unk: # skip <blank> and <unk>
if True in [c not in subword_dict for c in w]: # skip unknown subword
continue
succ = root[0] # get successors from root node
for i, c in enumerate(w):
cid = subword_dict[c]
if cid not in succ: # if next node does not exist, make a new node
succ[cid] = [{}, -1, (wid - 1, wid)]
else:
prev = succ[cid][2]
succ[cid][2] = (min(prev[0], wid - 1), max(prev[1], wid))
if i == len(w) - 1: # if word end, set word id
succ[cid][1] = wid
succ = succ[cid][0] # move to the child successors
return root
#!/usr/bin/env python3
# Copyright 2018 Mitsubishi Electric Research Laboratories (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from espnet.lm.lm_utils import make_lexical_tree
from espnet.nets.pytorch_backend.nets_utils import to_device
# Definition of a multi-level (subword/word) language model
class MultiLevelLM(nn.Module):
logzero = -10000000000.0
zero = 1.0e-10
def __init__(
self,
wordlm,
subwordlm,
word_dict,
subword_dict,
subwordlm_weight=0.8,
oov_penalty=1.0,
open_vocab=True,
):
super(MultiLevelLM, self).__init__()
self.wordlm = wordlm
self.subwordlm = subwordlm
self.word_eos = word_dict["<eos>"]
self.word_unk = word_dict["<unk>"]
self.var_word_eos = torch.LongTensor([self.word_eos])
self.var_word_unk = torch.LongTensor([self.word_unk])
self.space = subword_dict["<space>"]
self.eos = subword_dict["<eos>"]
self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
self.log_oov_penalty = math.log(oov_penalty)
self.open_vocab = open_vocab
self.subword_dict_size = len(subword_dict)
self.subwordlm_weight = subwordlm_weight
self.normalized = True
def forward(self, state, x):
# update state with input label x
if state is None: # make initial states and log-prob vectors
self.var_word_eos = to_device(x, self.var_word_eos)
self.var_word_unk = to_device(x, self.var_word_eos)
wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
wlm_logprobs = F.log_softmax(z_wlm, dim=1)
clm_state, z_clm = self.subwordlm(None, x)
log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
new_node = self.lexroot
clm_logprob = 0.0
xi = self.space
else:
clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
xi = int(x)
if xi == self.space: # inter-word transition
if node is not None and node[1] >= 0: # check if the node is word end
w = to_device(x, torch.LongTensor([node[1]]))
else: # this node is not a word end, which means <unk>
w = self.var_word_unk
# update wordlm state and log-prob vector
wlm_state, z_wlm = self.wordlm(wlm_state, w)
wlm_logprobs = F.log_softmax(z_wlm, dim=1)
new_node = self.lexroot # move to the tree root
clm_logprob = 0.0
elif node is not None and xi in node[0]: # intra-word transition
new_node = node[0][xi]
clm_logprob += log_y[0, xi]
elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
new_node = None
clm_logprob += log_y[0, xi]
else: # if open_vocab flag is disabled, return 0 probabilities
log_y = to_device(
x, torch.full((1, self.subword_dict_size), self.logzero)
)
return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.0), log_y
clm_state, z_clm = self.subwordlm(clm_state, x)
log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
# apply word-level probabilies for <space> and <eos> labels
if xi != self.space:
if new_node is not None and new_node[1] >= 0: # if new node is word end
wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
else:
wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
log_y[:, self.space] = wlm_logprob
log_y[:, self.eos] = wlm_logprob
else:
log_y[:, self.space] = self.logzero
log_y[:, self.eos] = self.logzero
return (
(clm_state, wlm_state, wlm_logprobs, new_node, log_y, float(clm_logprob)),
log_y,
)
def final(self, state):
clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
if node is not None and node[1] >= 0: # check if the node is word end
w = to_device(wlm_logprobs, torch.LongTensor([node[1]]))
else: # this node is not a word end, which means <unk>
w = self.var_word_unk
wlm_state, z_wlm = self.wordlm(wlm_state, w)
return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
# Definition of a look-ahead word language model
class LookAheadWordLM(nn.Module):
logzero = -10000000000.0
zero = 1.0e-10
def __init__(
self, wordlm, word_dict, subword_dict, oov_penalty=0.0001, open_vocab=True
):
super(LookAheadWordLM, self).__init__()
self.wordlm = wordlm
self.word_eos = word_dict["<eos>"]
self.word_unk = word_dict["<unk>"]
self.var_word_eos = torch.LongTensor([self.word_eos])
self.var_word_unk = torch.LongTensor([self.word_unk])
self.space = subword_dict["<space>"]
self.eos = subword_dict["<eos>"]
self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
self.oov_penalty = oov_penalty
self.open_vocab = open_vocab
self.subword_dict_size = len(subword_dict)
self.zero_tensor = torch.FloatTensor([self.zero])
self.normalized = True
def forward(self, state, x):
# update state with input label x
if state is None: # make initial states and cumlative probability vector
self.var_word_eos = to_device(x, self.var_word_eos)
self.var_word_unk = to_device(x, self.var_word_eos)
self.zero_tensor = to_device(x, self.zero_tensor)
wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
new_node = self.lexroot
xi = self.space
else:
wlm_state, cumsum_probs, node = state
xi = int(x)
if xi == self.space: # inter-word transition
if node is not None and node[1] >= 0: # check if the node is word end
w = to_device(x, torch.LongTensor([node[1]]))
else: # this node is not a word end, which means <unk>
w = self.var_word_unk
# update wordlm state and cumlative probability vector
wlm_state, z_wlm = self.wordlm(wlm_state, w)
cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
new_node = self.lexroot # move to the tree root
elif node is not None and xi in node[0]: # intra-word transition
new_node = node[0][xi]
elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
new_node = None
else: # if open_vocab flag is disabled, return 0 probabilities
log_y = to_device(
x, torch.full((1, self.subword_dict_size), self.logzero)
)
return (wlm_state, None, None), log_y
if new_node is not None:
succ, wid, wids = new_node
# compute parent node probability
sum_prob = (
(cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]])
if wids is not None
else 1.0
)
if sum_prob < self.zero:
log_y = to_device(
x, torch.full((1, self.subword_dict_size), self.logzero)
)
return (wlm_state, cumsum_probs, new_node), log_y
# set <unk> probability as a default value
unk_prob = (
cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
)
y = to_device(
x,
torch.full(
(1, self.subword_dict_size), float(unk_prob) * self.oov_penalty
),
)
# compute transition probabilities to child nodes
for cid, nd in succ.items():
y[:, cid] = (
cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]
) / sum_prob
# apply word-level probabilies for <space> and <eos> labels
if wid >= 0:
wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
y[:, self.space] = wlm_prob
y[:, self.eos] = wlm_prob
elif xi == self.space:
y[:, self.space] = self.zero
y[:, self.eos] = self.zero
log_y = torch.log(torch.max(y, self.zero_tensor)) # clip to avoid log(0)
else: # if no path in the tree, transition probability is one
log_y = to_device(x, torch.zeros(1, self.subword_dict_size))
return (wlm_state, cumsum_probs, new_node), log_y
def final(self, state):
wlm_state, cumsum_probs, node = state
if node is not None and node[1] >= 0: # check if the node is word end
w = to_device(cumsum_probs, torch.LongTensor([node[1]]))
else: # this node is not a word end, which means <unk>
w = self.var_word_unk
wlm_state, z_wlm = self.wordlm(wlm_state, w)
return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# This code is ported from the following implementation written in Torch.
# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
"""LM training in pytorch."""
import copy
import json
import logging
import numpy as np
import torch
import torch.nn as nn
from chainer import Chain, reporter, training
from chainer.dataset import convert
from chainer.training import extensions
from torch.nn.parallel import data_parallel
from espnet.asr.asr_utils import (
snapshot_object,
torch_load,
torch_resume,
torch_snapshot,
)
from espnet.lm.lm_utils import (
MakeSymlinkToBestModel,
ParallelSentenceIterator,
count_tokens,
load_dataset,
read_tokens,
)
from espnet.nets.lm_interface import LMInterface, dynamic_import_lm
from espnet.optimizer.factory import dynamic_import_optimizer
from espnet.scheduler.pytorch import PyTorchScheduler
from espnet.scheduler.scheduler import dynamic_import_scheduler
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop, set_early_stop
def compute_perplexity(result):
"""Compute and add the perplexity to the LogReport.
:param dict result: The current observations
"""
# Routine to rewrite the result dictionary of LogReport to add perplexity values
result["perplexity"] = np.exp(result["main/nll"] / result["main/count"])
if "validation/main/nll" in result:
result["val_perplexity"] = np.exp(
result["validation/main/nll"] / result["validation/main/count"]
)
class Reporter(Chain):
"""Dummy module to use chainer's trainer."""
def report(self, loss):
"""Report nothing."""
pass
def concat_examples(batch, device=None, padding=None):
"""Concat examples in minibatch.
:param np.ndarray batch: The batch to concatenate
:param int device: The device to send to
:param Tuple[int,int] padding: The padding to use
:return: (inputs, targets)
:rtype (torch.Tensor, torch.Tensor)
"""
x, t = convert.concat_examples(batch, padding=padding)
x = torch.from_numpy(x)
t = torch.from_numpy(t)
if device is not None and device >= 0:
x = x.cuda(device)
t = t.cuda(device)
return x, t
class BPTTUpdater(training.StandardUpdater):
"""An updater for a pytorch LM."""
def __init__(
self,
train_iter,
model,
optimizer,
schedulers,
device,
gradclip=None,
use_apex=False,
accum_grad=1,
):
"""Initialize class.
Args:
train_iter (chainer.dataset.Iterator): The train iterator
model (LMInterface) : The model to update
optimizer (torch.optim.Optimizer): The optimizer for training
schedulers (espnet.scheduler.scheduler.SchedulerInterface):
The schedulers of `optimizer`
device (int): The device id
gradclip (float): The gradient clipping value to use
use_apex (bool): The flag to use Apex in backprop.
accum_grad (int): The number of gradient accumulation.
"""
super(BPTTUpdater, self).__init__(train_iter, optimizer)
self.model = model
self.device = device
self.gradclip = gradclip
self.use_apex = use_apex
self.scheduler = PyTorchScheduler(schedulers, optimizer)
self.accum_grad = accum_grad
# The core part of the update routine can be customized by overriding.
def update_core(self):
"""Update the model."""
# When we pass one iterator and optimizer to StandardUpdater.__init__,
# they are automatically named 'main'.
train_iter = self.get_iterator("main")
optimizer = self.get_optimizer("main")
# Progress the dataset iterator for sentences at each iteration.
self.model.zero_grad() # Clear the parameter gradients
accum = {"loss": 0.0, "nll": 0.0, "count": 0}
for _ in range(self.accum_grad):
batch = train_iter.__next__()
# Concatenate the token IDs to matrices and send them to the device
# self.converter does this job
# (it is chainer.dataset.concat_examples by default)
x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
if self.device[0] == -1:
loss, nll, count = self.model(x, t)
else:
# apex does not support torch.nn.DataParallel
loss, nll, count = data_parallel(self.model, (x, t), self.device)
# backward
loss = loss.mean() / self.accum_grad
if self.use_apex:
from apex import amp
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward() # Backprop
# accumulate stats
accum["loss"] += float(loss)
accum["nll"] += float(nll.sum())
accum["count"] += int(count.sum())
for k, v in accum.items():
reporter.report({k: v}, optimizer.target)
if self.gradclip is not None:
nn.utils.clip_grad_norm_(self.model.parameters(), self.gradclip)
optimizer.step() # Update the parameters
self.scheduler.step(n_iter=self.iteration)
class LMEvaluator(BaseEvaluator):
"""A custom evaluator for a pytorch LM."""
def __init__(self, val_iter, eval_model, reporter, device):
"""Initialize class.
:param chainer.dataset.Iterator val_iter : The validation iterator
:param LMInterface eval_model : The model to evaluate
:param chainer.Reporter reporter : The observations reporter
:param int device : The device id to use
"""
super(LMEvaluator, self).__init__(val_iter, reporter, device=-1)
self.model = eval_model
self.device = device
def evaluate(self):
"""Evaluate the model."""
val_iter = self.get_iterator("main")
loss = 0
nll = 0
count = 0
self.model.eval()
with torch.no_grad():
for batch in copy.copy(val_iter):
x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
if self.device[0] == -1:
l, n, c = self.model(x, t)
else:
# apex does not support torch.nn.DataParallel
l, n, c = data_parallel(self.model, (x, t), self.device)
loss += float(l.sum())
nll += float(n.sum())
count += int(c.sum())
self.model.train()
# report validation loss
observation = {}
with reporter.report_scope(observation):
reporter.report({"loss": loss}, self.model.reporter)
reporter.report({"nll": nll}, self.model.reporter)
reporter.report({"count": count}, self.model.reporter)
return observation
def train(args):
"""Train with the given args.
:param Namespace args: The program arguments
:param type model_class: LMInterface class for training
"""
model_class = dynamic_import_lm(args.model_module, args.backend)
assert issubclass(model_class, LMInterface), "model should implement LMInterface"
# display torch version
logging.info("torch version = " + torch.__version__)
set_deterministic_pytorch(args)
# check cuda and cudnn availability
if not torch.cuda.is_available():
logging.warning("cuda is not available")
# get special label ids
unk = args.char_list_dict["<unk>"]
eos = args.char_list_dict["<eos>"]
# read tokens as a sequence of sentences
val, n_val_tokens, n_val_oovs = load_dataset(
args.valid_label, args.char_list_dict, args.dump_hdf5_path
)
train, n_train_tokens, n_train_oovs = load_dataset(
args.train_label, args.char_list_dict, args.dump_hdf5_path
)
logging.info("#vocab = " + str(args.n_vocab))
logging.info("#sentences in the training data = " + str(len(train)))
logging.info("#tokens in the training data = " + str(n_train_tokens))
logging.info(
"oov rate in the training data = %.2f %%"
% (n_train_oovs / n_train_tokens * 100)
)
logging.info("#sentences in the validation data = " + str(len(val)))
logging.info("#tokens in the validation data = " + str(n_val_tokens))
logging.info(
"oov rate in the validation data = %.2f %%" % (n_val_oovs / n_val_tokens * 100)
)
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
# Create the dataset iterators
batch_size = args.batchsize * max(args.ngpu, 1)
if batch_size * args.accum_grad > args.batchsize:
logging.info(
f"batch size is automatically increased "
f"({args.batchsize} -> {batch_size * args.accum_grad})"
)
train_iter = ParallelSentenceIterator(
train,
batch_size,
max_length=args.maxlen,
sos=eos,
eos=eos,
shuffle=not use_sortagrad,
)
val_iter = ParallelSentenceIterator(
val, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
)
epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad)
logging.info("#iterations per epoch = %d" % epoch_iters)
logging.info("#total iterations = " + str(args.epoch * epoch_iters))
# Prepare an RNNLM model
if args.train_dtype in ("float16", "float32", "float64"):
dtype = getattr(torch, args.train_dtype)
else:
dtype = torch.float32
model = model_class(args.n_vocab, args).to(dtype=dtype)
if args.ngpu > 0:
model.to("cuda")
gpu_id = list(range(args.ngpu))
else:
gpu_id = [-1]
# Save model conf to json
model_conf = args.outdir + "/model.json"
with open(model_conf, "wb") as f:
logging.info("writing a model config file to " + model_conf)
f.write(
json.dumps(vars(args), indent=4, ensure_ascii=False, sort_keys=True).encode(
"utf_8"
)
)
logging.warning(
"num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
sum(p.numel() for p in model.parameters() if p.requires_grad)
* 100.0
/ sum(p.numel() for p in model.parameters()),
)
)
# Set up an optimizer
opt_class = dynamic_import_optimizer(args.opt, args.backend)
optimizer = opt_class.from_args(model.parameters(), args)
if args.schedulers is None:
schedulers = []
else:
schedulers = [dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers]
# setup apex.amp
if args.train_dtype in ("O0", "O1", "O2", "O3"):
try:
from apex import amp
except ImportError as e:
logging.error(
f"You need to install apex for --train-dtype {args.train_dtype}. "
"See https://github.com/NVIDIA/apex#linux"
)
raise e
model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype)
use_apex = True
else:
use_apex = False
# FIXME: TOO DIRTY HACK
reporter = Reporter()
setattr(model, "reporter", reporter)
setattr(optimizer, "target", reporter)
setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
updater = BPTTUpdater(
train_iter,
model,
optimizer,
schedulers,
gpu_id,
gradclip=args.gradclip,
use_apex=use_apex,
accum_grad=args.accum_grad,
)
trainer = training.Trainer(updater, (args.epoch, "epoch"), out=args.outdir)
trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
trainer.extend(
extensions.LogReport(
postprocess=compute_perplexity,
trigger=(args.report_interval_iters, "iteration"),
)
)
trainer.extend(
extensions.PrintReport(
[
"epoch",
"iteration",
"main/loss",
"perplexity",
"val_perplexity",
"elapsed_time",
]
),
trigger=(args.report_interval_iters, "iteration"),
)
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
# Save best models
trainer.extend(torch_snapshot(filename="snapshot.ep.{.updater.epoch}"))
trainer.extend(snapshot_object(model, "rnnlm.model.{.updater.epoch}"))
# T.Hori: MinValueTrigger should be used, but it fails when resuming
trainer.extend(MakeSymlinkToBestModel("validation/main/loss", "rnnlm.model"))
if use_sortagrad:
trainer.extend(
ShufflingEnabler([train_iter]),
trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, "epoch"),
)
if args.resume:
logging.info("resumed from %s" % args.resume)
torch_resume(args.resume, trainer)
set_early_stop(trainer, args, is_lm=True)
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(args.tensorboard_dir)
trainer.extend(
TensorboardLogger(writer), trigger=(args.report_interval_iters, "iteration")
)
trainer.run()
check_early_stop(trainer, args.epoch)
# compute perplexity for test set
if args.test_label:
logging.info("test the best model")
torch_load(args.outdir + "/rnnlm.model.best", model)
test = read_tokens(args.test_label, args.char_list_dict)
n_test_tokens, n_test_oovs = count_tokens(test, unk)
logging.info("#sentences in the test data = " + str(len(test)))
logging.info("#tokens in the test data = " + str(n_test_tokens))
logging.info(
"oov rate in the test data = %.2f %%" % (n_test_oovs / n_test_tokens * 100)
)
test_iter = ParallelSentenceIterator(
test, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
)
evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
result = evaluator()
compute_perplexity(result)
logging.info(f"test perplexity: {result['perplexity']}")
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Utility functions for the text translation task."""
import logging
# * ------------------ recognition related ------------------ *
def parse_hypothesis(hyp, char_list):
"""Parse hypothesis.
:param list hyp: recognition hypothesis
:param list char_list: list of characters
:return: recognition text string
:return: recognition token string
:return: recognition tokenid string
"""
# remove sos and get results
tokenid_as_list = list(map(int, hyp["yseq"][1:]))
token_as_list = [char_list[idx] for idx in tokenid_as_list]
score = float(hyp["score"])
# convert to string
tokenid = " ".join([str(idx) for idx in tokenid_as_list])
token = " ".join(token_as_list)
text = "".join(token_as_list).replace("<space>", " ")
return text, token, tokenid, score
def add_results_to_json(js, nbest_hyps, char_list):
"""Add N-best results to json.
:param dict js: groundtruth utterance dict
:param list nbest_hyps: list of hypothesis
:param list char_list: list of characters
:return: N-best results added utterance dict
"""
# copy old json info
new_js = dict()
if "utt2spk" in js.keys():
new_js["utt2spk"] = js["utt2spk"]
new_js["output"] = []
for n, hyp in enumerate(nbest_hyps, 1):
# parse hypothesis
rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)
# copy ground-truth
if len(js["output"]) > 0:
out_dic = dict(js["output"][0].items())
else:
out_dic = {"name": ""}
# update name
out_dic["name"] += "[%d]" % n
# add recognition results
out_dic["rec_text"] = rec_text
out_dic["rec_token"] = rec_token
out_dic["rec_tokenid"] = rec_tokenid
out_dic["score"] = score
# add source reference
out_dic["text_src"] = js["output"][1]["text"]
out_dic["token_src"] = js["output"][1]["token"]
out_dic["tokenid_src"] = js["output"][1]["tokenid"]
# add to list of N-best result dicts
new_js["output"].append(out_dic)
# show 1-best result
if n == 1:
if "text" in out_dic.keys():
logging.info("groundtruth: %s" % out_dic["text"])
logging.info("prediction : %s" % out_dic["rec_text"])
logging.info("source : %s" % out_dic["token_src"])
return new_js
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Training/decoding definition for the text translation task."""
import itertools
import json
import logging
import os
import numpy as np
import torch
from chainer import training
from chainer.training import extensions
from espnet.asr.asr_utils import (
CompareValueTrigger,
adadelta_eps_decay,
adam_lr_decay,
add_results_to_json,
restore_snapshot,
snapshot_object,
torch_load,
torch_resume,
torch_snapshot,
)
from espnet.asr.pytorch_backend.asr import (
CustomEvaluator,
CustomUpdater,
load_trained_model,
)
from espnet.nets.mt_interface import MTInterface
from espnet.nets.pytorch_backend.e2e_asr import pad_list
from espnet.utils.dataset import ChainerDataLoader, TransformDataset
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop, set_early_stop
class CustomConverter(object):
"""Custom batch converter for Pytorch."""
def __init__(self):
"""Construct a CustomConverter object."""
self.ignore_id = -1
self.pad = 0
# NOTE: we reserve index:0 for <pad> although this is reserved for a blank class
# in ASR. However,
# blank labels are not used in NMT. To keep the vocabulary size,
# we use index:0 for padding instead of adding one more class.
def __call__(self, batch, device=torch.device("cpu")):
"""Transform a batch and send it to a device.
Args:
batch (list): The batch to transform.
device (torch.device): The device to send to.
Returns:
tuple(torch.Tensor, torch.Tensor, torch.Tensor)
"""
# batch should be located in list
assert len(batch) == 1
xs, ys = batch[0]
# get batch of lengths of input sequences
ilens = np.array([x.shape[0] for x in xs])
# perform padding and convert to tensor
xs_pad = pad_list([torch.from_numpy(x).long() for x in xs], self.pad).to(device)
ilens = torch.from_numpy(ilens).to(device)
ys_pad = pad_list([torch.from_numpy(y).long() for y in ys], self.ignore_id).to(
device
)
return xs_pad, ilens, ys_pad
def train(args):
"""Train with the given args.
Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
# check cuda availability
if not torch.cuda.is_available():
logging.warning("cuda is not available")
# get input and output dimension info
with open(args.valid_json, "rb") as f:
valid_json = json.load(f)["utts"]
utts = list(valid_json.keys())
idim = int(valid_json[utts[0]]["output"][1]["shape"][1])
odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
logging.info("#input dims : " + str(idim))
logging.info("#output dims: " + str(odim))
# specify model architecture
model_class = dynamic_import(args.model_module)
model = model_class(idim, odim, args)
assert isinstance(model, MTInterface)
# write model config
if not os.path.exists(args.outdir):
os.makedirs(args.outdir)
model_conf = args.outdir + "/model.json"
with open(model_conf, "wb") as f:
logging.info("writing a model config file to " + model_conf)
f.write(
json.dumps(
(idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
).encode("utf_8")
)
for key in sorted(vars(args).keys()):
logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
reporter = model.reporter
# check the use of multi-gpu
if args.ngpu > 1:
if args.batch_size != 0:
logging.warning(
"batch size is automatically increased (%d -> %d)"
% (args.batch_size, args.batch_size * args.ngpu)
)
args.batch_size *= args.ngpu
# set torch device
device = torch.device("cuda" if args.ngpu > 0 else "cpu")
if args.train_dtype in ("float16", "float32", "float64"):
dtype = getattr(torch, args.train_dtype)
else:
dtype = torch.float32
model = model.to(device=device, dtype=dtype)
logging.warning(
"num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
sum(p.numel() for p in model.parameters() if p.requires_grad)
* 100.0
/ sum(p.numel() for p in model.parameters()),
)
)
# Setup an optimizer
if args.opt == "adadelta":
optimizer = torch.optim.Adadelta(
model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
)
elif args.opt == "adam":
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
elif args.opt == "noam":
from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
optimizer = get_std_opt(
model.parameters(),
args.adim,
args.transformer_warmup_steps,
args.transformer_lr,
)
else:
raise NotImplementedError("unknown optimizer: " + args.opt)
# setup apex.amp
if args.train_dtype in ("O0", "O1", "O2", "O3"):
try:
from apex import amp
except ImportError as e:
logging.error(
f"You need to install apex for --train-dtype {args.train_dtype}. "
"See https://github.com/NVIDIA/apex#linux"
)
raise e
if args.opt == "noam":
model, optimizer.optimizer = amp.initialize(
model, optimizer.optimizer, opt_level=args.train_dtype
)
else:
model, optimizer = amp.initialize(
model, optimizer, opt_level=args.train_dtype
)
use_apex = True
else:
use_apex = False
# FIXME: TOO DIRTY HACK
setattr(optimizer, "target", reporter)
setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
# Setup a converter
converter = CustomConverter()
# read json data
with open(args.train_json, "rb") as f:
train_json = json.load(f)["utts"]
with open(args.valid_json, "rb") as f:
valid_json = json.load(f)["utts"]
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
# make minibatch list (variable length)
train = make_batchset(
train_json,
args.batch_size,
args.maxlen_in,
args.maxlen_out,
args.minibatches,
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
shortest_first=use_sortagrad,
count=args.batch_count,
batch_bins=args.batch_bins,
batch_frames_in=args.batch_frames_in,
batch_frames_out=args.batch_frames_out,
batch_frames_inout=args.batch_frames_inout,
mt=True,
iaxis=1,
oaxis=0,
)
valid = make_batchset(
valid_json,
args.batch_size,
args.maxlen_in,
args.maxlen_out,
args.minibatches,
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
count=args.batch_count,
batch_bins=args.batch_bins,
batch_frames_in=args.batch_frames_in,
batch_frames_out=args.batch_frames_out,
batch_frames_inout=args.batch_frames_inout,
mt=True,
iaxis=1,
oaxis=0,
)
load_tr = LoadInputsAndTargets(mode="mt", load_output=True)
load_cv = LoadInputsAndTargets(mode="mt", load_output=True)
# hack to make batchsize argument as 1
# actual bathsize is included in a list
# default collate function converts numpy array to pytorch tensor
# we used an empty collate function instead which returns list
train_iter = ChainerDataLoader(
dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
batch_size=1,
num_workers=args.n_iter_processes,
shuffle=not use_sortagrad,
collate_fn=lambda x: x[0],
)
valid_iter = ChainerDataLoader(
dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
batch_size=1,
shuffle=False,
collate_fn=lambda x: x[0],
num_workers=args.n_iter_processes,
)
# Set up a trainer
updater = CustomUpdater(
model,
args.grad_clip,
{"main": train_iter},
optimizer,
device,
args.ngpu,
False,
args.accum_grad,
use_apex=use_apex,
)
trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
if use_sortagrad:
trainer.extend(
ShufflingEnabler([train_iter]),
trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
)
# Resume from a snapshot
if args.resume:
logging.info("resumed from %s" % args.resume)
torch_resume(args.resume, trainer)
# Evaluate the model with the test dataset for each epoch
if args.save_interval_iters > 0:
trainer.extend(
CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
trigger=(args.save_interval_iters, "iteration"),
)
else:
trainer.extend(
CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
)
# Save attention weight each epoch
if args.num_save_attention > 0:
# NOTE: sort it by output lengths
data = sorted(
list(valid_json.items())[: args.num_save_attention],
key=lambda x: int(x[1]["output"][0]["shape"][0]),
reverse=True,
)
if hasattr(model, "module"):
att_vis_fn = model.module.calculate_all_attentions
plot_class = model.module.attention_plot_class
else:
att_vis_fn = model.calculate_all_attentions
plot_class = model.attention_plot_class
att_reporter = plot_class(
att_vis_fn,
data,
args.outdir + "/att_ws",
converter=converter,
transform=load_cv,
device=device,
ikey="output",
iaxis=1,
)
trainer.extend(att_reporter, trigger=(1, "epoch"))
else:
att_reporter = None
# Make a plot for training and validation values
trainer.extend(
extensions.PlotReport(
["main/loss", "validation/main/loss"], "epoch", file_name="loss.png"
)
)
trainer.extend(
extensions.PlotReport(
["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
)
)
trainer.extend(
extensions.PlotReport(
["main/ppl", "validation/main/ppl"], "epoch", file_name="ppl.png"
)
)
trainer.extend(
extensions.PlotReport(
["main/bleu", "validation/main/bleu"], "epoch", file_name="bleu.png"
)
)
# Save best models
trainer.extend(
snapshot_object(model, "model.loss.best"),
trigger=training.triggers.MinValueTrigger("validation/main/loss"),
)
trainer.extend(
snapshot_object(model, "model.acc.best"),
trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
)
# save snapshot which contains model and optimizer states
if args.save_interval_iters > 0:
trainer.extend(
torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
trigger=(args.save_interval_iters, "iteration"),
)
else:
trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
# epsilon decay in the optimizer
if args.opt == "adadelta":
if args.criterion == "acc":
trainer.extend(
restore_snapshot(
model, args.outdir + "/model.acc.best", load_fn=torch_load
),
trigger=CompareValueTrigger(
"validation/main/acc",
lambda best_value, current_value: best_value > current_value,
),
)
trainer.extend(
adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
"validation/main/acc",
lambda best_value, current_value: best_value > current_value,
),
)
elif args.criterion == "loss":
trainer.extend(
restore_snapshot(
model, args.outdir + "/model.loss.best", load_fn=torch_load
),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
trainer.extend(
adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
elif args.opt == "adam":
if args.criterion == "acc":
trainer.extend(
restore_snapshot(
model, args.outdir + "/model.acc.best", load_fn=torch_load
),
trigger=CompareValueTrigger(
"validation/main/acc",
lambda best_value, current_value: best_value > current_value,
),
)
trainer.extend(
adam_lr_decay(args.lr_decay),
trigger=CompareValueTrigger(
"validation/main/acc",
lambda best_value, current_value: best_value > current_value,
),
)
elif args.criterion == "loss":
trainer.extend(
restore_snapshot(
model, args.outdir + "/model.loss.best", load_fn=torch_load
),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
trainer.extend(
adam_lr_decay(args.lr_decay),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
# Write a log of evaluation statistics for each epoch
trainer.extend(
extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
)
report_keys = [
"epoch",
"iteration",
"main/loss",
"validation/main/loss",
"main/acc",
"validation/main/acc",
"main/ppl",
"validation/main/ppl",
"elapsed_time",
]
if args.opt == "adadelta":
trainer.extend(
extensions.observe_value(
"eps",
lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
"eps"
],
),
trigger=(args.report_interval_iters, "iteration"),
)
report_keys.append("eps")
elif args.opt in ["adam", "noam"]:
trainer.extend(
extensions.observe_value(
"lr",
lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
"lr"
],
),
trigger=(args.report_interval_iters, "iteration"),
)
report_keys.append("lr")
if args.report_bleu:
report_keys.append("main/bleu")
report_keys.append("validation/main/bleu")
trainer.extend(
extensions.PrintReport(report_keys),
trigger=(args.report_interval_iters, "iteration"),
)
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
set_early_stop(trainer, args)
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
from torch.utils.tensorboard import SummaryWriter
trainer.extend(
TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
trigger=(args.report_interval_iters, "iteration"),
)
# Run the training
trainer.run()
check_early_stop(trainer, args.epochs)
def trans(args):
"""Decode with the given args.
Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
model, train_args = load_trained_model(args.model)
assert isinstance(model, MTInterface)
model.trans_args = args
# gpu
if args.ngpu == 1:
gpu_id = list(range(args.ngpu))
logging.info("gpu id: " + str(gpu_id))
model.cuda()
# read json data
with open(args.trans_json, "rb") as f:
js = json.load(f)["utts"]
new_js = {}
# remove enmpy utterances
if train_args.multilingual:
js = {
k: v
for k, v in js.items()
if v["output"][0]["shape"][0] > 1 and v["output"][1]["shape"][0] > 1
}
else:
js = {
k: v
for k, v in js.items()
if v["output"][0]["shape"][0] > 0 and v["output"][1]["shape"][0] > 0
}
if args.batchsize == 0:
with torch.no_grad():
for idx, name in enumerate(js.keys(), 1):
logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
feat = [js[name]["output"][1]["tokenid"].split()]
nbest_hyps = model.translate(feat, args, train_args.char_list)
new_js[name] = add_results_to_json(
js[name], nbest_hyps, train_args.char_list
)
else:
def grouper(n, iterable, fillvalue=None):
kargs = [iter(iterable)] * n
return itertools.zip_longest(*kargs, fillvalue=fillvalue)
# sort data
keys = list(js.keys())
feat_lens = [js[key]["output"][1]["shape"][0] for key in keys]
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
keys = [keys[i] for i in sorted_index]
with torch.no_grad():
for names in grouper(args.batchsize, keys, None):
names = [name for name in names if name]
feats = [
np.fromiter(
map(int, js[name]["output"][1]["tokenid"].split()),
dtype=np.int64,
)
for name in names
]
nbest_hyps = model.translate_batch(
feats,
args,
train_args.char_list,
)
for i, nbest_hyp in enumerate(nbest_hyps):
name = names[i]
new_js[name] = add_results_to_json(
js[name], nbest_hyp, train_args.char_list
)
with open(args.result_label, "wb") as f:
f.write(
json.dumps(
{"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
).encode("utf_8")
)
"""ASR Interface module."""
import argparse
from espnet.bin.asr_train import get_parser
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.fill_missing_args import fill_missing_args
class ASRInterface:
"""ASR Interface for ESPnet model implementation."""
@staticmethod
def add_arguments(parser):
"""Add arguments to parser."""
return parser
@classmethod
def build(cls, idim: int, odim: int, **kwargs):
"""Initialize this class with python-level args.
Args:
idim (int): The number of an input feature dim.
odim (int): The number of output vocab.
Returns:
ASRinterface: A new instance of ASRInterface.
"""
def wrap(parser):
return get_parser(parser, required=False)
args = argparse.Namespace(**kwargs)
args = fill_missing_args(args, wrap)
args = fill_missing_args(args, cls.add_arguments)
return cls(idim, odim, args)
def forward(self, xs, ilens, ys):
"""Compute loss for training.
:param xs:
For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim)
For chainer, list of source sequences chainer.Variable
:param ilens: batch of lengths of source sequences (B)
For pytorch, torch.Tensor
For chainer, list of int
:param ys:
For pytorch, batch of padded source sequences torch.Tensor (B, Lmax)
For chainer, list of source sequences chainer.Variable
:return: loss value
:rtype: torch.Tensor for pytorch, chainer.Variable for chainer
"""
raise NotImplementedError("forward method is not implemented")
def recognize(self, x, recog_args, char_list=None, rnnlm=None):
"""Recognize x for evaluation.
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
:param namespace recog_args: argment namespace contraining options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("recognize method is not implemented")
def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
"""Beam search implementation for batch.
:param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
:param namespace recog_args: argument namespace containing options
:param list char_list: list of characters
:param torch.nn.Module rnnlm: language model module
:return: N-best decoding results
:rtype: list
"""
raise NotImplementedError("Batch decoding is not supported yet.")
def calculate_all_attentions(self, xs, ilens, ys):
"""Calculate attention.
:param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
:param ndarray ilens: batch of lengths of input sequences (B)
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
:return: attention weights (B, Lmax, Tmax)
:rtype: float ndarray
"""
raise NotImplementedError("calculate_all_attentions method is not implemented")
def calculate_all_ctc_probs(self, xs, ilens, ys):
"""Calculate CTC probability.
:param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...]
:param ndarray ilens: batch of lengths of input sequences (B)
:param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
:return: CTC probabilities (B, Tmax, vocab)
:rtype: float ndarray
"""
raise NotImplementedError("calculate_all_ctc_probs method is not implemented")
@property
def attention_plot_class(self):
"""Get attention plot class."""
from espnet.asr.asr_utils import PlotAttentionReport
return PlotAttentionReport
@property
def ctc_plot_class(self):
"""Get CTC plot class."""
from espnet.asr.asr_utils import PlotCTCReport
return PlotCTCReport
def get_total_subsampling_factor(self):
"""Get total subsampling factor."""
raise NotImplementedError(
"get_total_subsampling_factor method is not implemented"
)
def encode(self, feat):
"""Encode feature in `beam_search` (optional).
Args:
x (numpy.ndarray): input feature (T, D)
Returns:
torch.Tensor for pytorch, chainer.Variable for chainer:
encoded feature (T, D)
"""
raise NotImplementedError("encode method is not implemented")
def scorers(self):
"""Get scorers for `beam_search` (optional).
Returns:
dict[str, ScorerInterface]: dict of `ScorerInterface` objects
"""
raise NotImplementedError("decoders method is not implemented")
predefined_asr = {
"pytorch": {
"rnn": "espnet.nets.pytorch_backend.e2e_asr:E2E",
"transducer": "espnet.nets.pytorch_backend.e2e_asr_transducer:E2E",
"transformer": "espnet.nets.pytorch_backend.e2e_asr_transformer:E2E",
"conformer": "espnet.nets.pytorch_backend.e2e_asr_conformer:E2E",
},
"chainer": {
"rnn": "espnet.nets.chainer_backend.e2e_asr:E2E",
"transformer": "espnet.nets.chainer_backend.e2e_asr_transformer:E2E",
},
}
def dynamic_import_asr(module, backend):
"""Import ASR models dynamically.
Args:
module (str): module_name:class_name or alias in `predefined_asr`
backend (str): NN backend. e.g., pytorch, chainer
Returns:
type: ASR class
"""
model_class = dynamic_import(module, predefined_asr.get(backend, dict()))
assert issubclass(
model_class, ASRInterface
), f"{module} does not implement ASRInterface"
return model_class
"""Parallel beam search module."""
import logging
from typing import Any, Dict, List, NamedTuple, Tuple
import torch
from packaging.version import parse as V
from torch.nn.utils.rnn import pad_sequence
from espnet.nets.beam_search import BeamSearch, Hypothesis
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
class BatchHypothesis(NamedTuple):
"""Batchfied/Vectorized hypothesis data type."""
yseq: torch.Tensor = torch.tensor([]) # (batch, maxlen)
score: torch.Tensor = torch.tensor([]) # (batch,)
length: torch.Tensor = torch.tensor([]) # (batch,)
scores: Dict[str, torch.Tensor] = dict() # values: (batch,)
states: Dict[str, Dict] = dict()
def __len__(self) -> int:
"""Return a batch size."""
return len(self.length)
class BatchBeamSearch(BeamSearch):
"""Batch beam search implementation."""
def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis:
"""Convert list to batch."""
if len(hyps) == 0:
return BatchHypothesis()
return BatchHypothesis(
yseq=pad_sequence(
[h.yseq for h in hyps], batch_first=True, padding_value=self.eos
),
length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64),
score=torch.tensor([h.score for h in hyps]),
scores={k: torch.tensor([h.scores[k] for h in hyps]) for k in self.scorers},
states={k: [h.states[k] for h in hyps] for k in self.scorers},
)
def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis:
return BatchHypothesis(
yseq=hyps.yseq[ids],
score=hyps.score[ids],
length=hyps.length[ids],
scores={k: v[ids] for k, v in hyps.scores.items()},
states={
k: [self.scorers[k].select_state(v, i) for i in ids]
for k, v in hyps.states.items()
},
)
def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis:
return Hypothesis(
yseq=hyps.yseq[i, : hyps.length[i]],
score=hyps.score[i],
scores={k: v[i] for k, v in hyps.scores.items()},
states={
k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items()
},
)
def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]:
"""Revert batch to list."""
return [
Hypothesis(
yseq=batch_hyps.yseq[i][: batch_hyps.length[i]],
score=batch_hyps.score[i],
scores={k: batch_hyps.scores[k][i] for k in self.scorers},
states={
k: v.select_state(batch_hyps.states[k], i)
for k, v in self.scorers.items()
},
)
for i in range(len(batch_hyps.length))
]
def batch_beam(
self, weighted_scores: torch.Tensor, ids: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Batch-compute topk full token ids and partial token ids.
Args:
weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
Its shape is `(n_beam, self.vocab_size)`.
ids (torch.Tensor): The partial token ids to compute topk.
Its shape is `(n_beam, self.pre_beam_size)`.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
The topk full (prev_hyp, new_token) ids
and partial (prev_hyp, new_token) ids.
Their shapes are all `(self.beam_size,)`
"""
top_ids = weighted_scores.view(-1).topk(self.beam_size)[1]
# Because of the flatten above, `top_ids` is organized as:
# [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
# where V is `self.n_vocab` and K is `self.beam_size`
if is_torch_1_9_plus:
prev_hyp_ids = torch.div(top_ids, self.n_vocab, rounding_mode="trunc")
else:
prev_hyp_ids = top_ids // self.n_vocab
new_token_ids = top_ids % self.n_vocab
return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids
def init_hyp(self, x: torch.Tensor) -> BatchHypothesis:
"""Get an initial hypothesis data.
Args:
x (torch.Tensor): The encoder output feature
Returns:
Hypothesis: The initial hypothesis.
"""
init_states = dict()
init_scores = dict()
for k, d in self.scorers.items():
init_states[k] = d.batch_init_state(x)
init_scores[k] = 0.0
# NOTE (Shih-Lun): added for OpenAI Whisper ASR
primer = [self.sos] if self.hyp_primer is None else self.hyp_primer
return self.batchfy(
[
Hypothesis(
score=0.0,
scores=init_scores,
states=init_states,
yseq=torch.tensor(primer, device=x.device),
)
]
)
def score_full(
self, hyp: BatchHypothesis, x: torch.Tensor
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.full_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.full_scorers`
and tensor score values of shape: `(self.n_vocab,)`,
and state dict that has string keys
and state values of `self.full_scorers`
"""
scores = dict()
states = dict()
for k, d in self.full_scorers.items():
scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
return scores, states
def score_partial(
self, hyp: BatchHypothesis, ids: torch.Tensor, x: torch.Tensor
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.full_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
ids (torch.Tensor): 2D tensor of new partial tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.full_scorers`
and tensor score values of shape: `(self.n_vocab,)`,
and state dict that has string keys
and state values of `self.full_scorers`
"""
scores = dict()
states = dict()
for k, d in self.part_scorers.items():
scores[k], states[k] = d.batch_score_partial(
hyp.yseq, ids, hyp.states[k], x
)
return scores, states
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
"""Merge states for new hypothesis.
Args:
states: states of `self.full_scorers`
part_states: states of `self.part_scorers`
part_idx (int): The new token id for `part_scores`
Returns:
Dict[str, torch.Tensor]: The new score dict.
Its keys are names of `self.full_scorers` and `self.part_scorers`.
Its values are states of the scorers.
"""
new_states = dict()
for k, v in states.items():
new_states[k] = v
for k, v in part_states.items():
new_states[k] = v
return new_states
def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis:
"""Search new tokens for running hypotheses and encoded speech x.
Args:
running_hyps (BatchHypothesis): Running hypotheses on beam
x (torch.Tensor): Encoded speech feature (T, D)
Returns:
BatchHypothesis: Best sorted hypotheses
"""
n_batch = len(running_hyps)
part_ids = None # no pre-beam
# batch scoring
weighted_scores = torch.zeros(
n_batch, self.n_vocab, dtype=x.dtype, device=x.device
)
scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
for k in self.full_scorers:
weighted_scores += self.weights[k] * scores[k]
# partial scoring
if self.do_pre_beam:
pre_beam_scores = (
weighted_scores
if self.pre_beam_score_key == "full"
else scores[self.pre_beam_score_key]
)
part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
# NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
# full-size score matrices, which has non-zero scores for part_ids and zeros
# for others.
part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
for k in self.part_scorers:
weighted_scores += self.weights[k] * part_scores[k]
# add previous hyp scores
weighted_scores += running_hyps.score.to(
dtype=x.dtype, device=x.device
).unsqueeze(1)
# TODO(karita): do not use list. use batch instead
# see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
# update hyps
best_hyps = []
prev_hyps = self.unbatchfy(running_hyps)
for (
full_prev_hyp_id,
full_new_token_id,
part_prev_hyp_id,
part_new_token_id,
) in zip(*self.batch_beam(weighted_scores, part_ids)):
prev_hyp = prev_hyps[full_prev_hyp_id]
best_hyps.append(
Hypothesis(
score=weighted_scores[full_prev_hyp_id, full_new_token_id],
yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
scores=self.merge_scores(
prev_hyp.scores,
{k: v[full_prev_hyp_id] for k, v in scores.items()},
full_new_token_id,
{k: v[part_prev_hyp_id] for k, v in part_scores.items()},
part_new_token_id,
),
states=self.merge_states(
{
k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
for k, v in states.items()
},
{
k: self.part_scorers[k].select_state(
v, part_prev_hyp_id, part_new_token_id
)
for k, v in part_states.items()
},
part_new_token_id,
),
)
)
return self.batchfy(best_hyps)
def post_process(
self,
i: int,
maxlen: int,
maxlenratio: float,
running_hyps: BatchHypothesis,
ended_hyps: List[Hypothesis],
) -> BatchHypothesis:
"""Perform post-processing of beam search iterations.
Args:
i (int): The length of hypothesis tokens.
maxlen (int): The maximum length of tokens in beam search.
maxlenratio (int): The maximum length ratio in beam search.
running_hyps (BatchHypothesis): The running hypotheses in beam search.
ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
Returns:
BatchHypothesis: The new running hypotheses.
"""
n_batch = running_hyps.yseq.shape[0]
logging.debug(f"the number of running hypothes: {n_batch}")
if self.token_list is not None:
logging.debug(
"best hypo: "
+ "".join(
[
self.token_list[x]
for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
]
)
)
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logging.info("adding <eos> in the last position in the loop")
yseq_eos = torch.cat(
(
running_hyps.yseq,
torch.full(
(n_batch, 1),
self.eos,
device=running_hyps.yseq.device,
dtype=torch.int64,
),
),
1,
)
running_hyps.yseq.resize_as_(yseq_eos)
running_hyps.yseq[:] = yseq_eos
running_hyps.length[:] = yseq_eos.shape[1]
# add ended hypotheses to a final list, and removed them from current hypotheses
# (this will be a probmlem, number of hyps < beam)
is_eos = (
running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
== self.eos
)
for b in torch.nonzero(is_eos, as_tuple=False).view(-1):
hyp = self._select(running_hyps, b)
ended_hyps.append(hyp)
remained_ids = torch.nonzero(is_eos == 0, as_tuple=False).view(-1).cpu()
return self._batch_select(running_hyps, remained_ids)
"""Parallel beam search module for online simulation."""
import logging
from typing import Any # noqa: H301
from typing import Dict # noqa: H301
from typing import List # noqa: H301
from typing import Tuple # noqa: H301
import torch
from espnet.nets.batch_beam_search import BatchBeamSearch # noqa: H301
from espnet.nets.batch_beam_search import BatchHypothesis # noqa: H301
from espnet.nets.beam_search import Hypothesis
from espnet.nets.e2e_asr_common import end_detect
class BatchBeamSearchOnline(BatchBeamSearch):
"""Online beam search implementation.
This simulates streaming decoding.
It requires encoded features of entire utterance and
extracts block by block from it as it shoud be done
in streaming processing.
This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR
WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
(https://arxiv.org/abs/2006.14941).
"""
def __init__(
self,
*args,
block_size=40,
hop_size=16,
look_ahead=16,
disable_repetition_detection=False,
encoded_feat_length_limit=0,
decoder_text_length_limit=0,
**kwargs,
):
"""Initialize beam search."""
super().__init__(*args, **kwargs)
self.block_size = block_size
self.hop_size = hop_size
self.look_ahead = look_ahead
self.disable_repetition_detection = disable_repetition_detection
self.encoded_feat_length_limit = encoded_feat_length_limit
self.decoder_text_length_limit = decoder_text_length_limit
self.reset()
def reset(self):
"""Reset parameters."""
self.encbuffer = None
self.running_hyps = None
self.prev_hyps = []
self.ended_hyps = []
self.processed_block = 0
self.process_idx = 0
self.prev_output = None
def score_full(
self, hyp: BatchHypothesis, x: torch.Tensor
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.full_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.full_scorers`
and tensor score values of shape: `(self.n_vocab,)`,
and state dict that has string keys
and state values of `self.full_scorers`
"""
scores = dict()
states = dict()
for k, d in self.full_scorers.items():
if (
self.decoder_text_length_limit > 0
and len(hyp.yseq) > 0
and len(hyp.yseq[0]) > self.decoder_text_length_limit
):
temp_yseq = hyp.yseq.narrow(
1, -self.decoder_text_length_limit, self.decoder_text_length_limit
).clone()
temp_yseq[:, 0] = self.sos
self.running_hyps.states["decoder"] = [
None for _ in self.running_hyps.states["decoder"]
]
scores[k], states[k] = d.batch_score(temp_yseq, hyp.states[k], x)
else:
scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
return scores, states
def forward(
self,
x: torch.Tensor,
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
is_final: bool = True,
) -> List[Hypothesis]:
"""Perform beam search.
Args:
x (torch.Tensor): Encoded speech feature (T, D)
maxlenratio (float): Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses a end-detect function
to automatically find maximum hypothesis lengths
minlenratio (float): Input length ratio to obtain min output length.
Returns:
list[Hypothesis]: N-best decoding results
"""
if self.encbuffer is None:
self.encbuffer = x
else:
self.encbuffer = torch.cat([self.encbuffer, x], axis=0)
x = self.encbuffer
# set length bounds
if maxlenratio == 0:
maxlen = x.shape[0]
else:
maxlen = max(1, int(maxlenratio * x.size(0)))
ret = None
while True:
cur_end_frame = (
self.block_size - self.look_ahead + self.hop_size * self.processed_block
)
if cur_end_frame < x.shape[0]:
h = x.narrow(0, 0, cur_end_frame)
block_is_final = False
else:
if is_final:
h = x
block_is_final = True
else:
break
logging.debug("Start processing block: %d", self.processed_block)
logging.debug(
" Feature length: {}, current position: {}".format(
h.shape[0], self.process_idx
)
)
if (
self.encoded_feat_length_limit > 0
and h.shape[0] > self.encoded_feat_length_limit
):
h = h.narrow(
0,
h.shape[0] - self.encoded_feat_length_limit,
self.encoded_feat_length_limit,
)
if self.running_hyps is None:
self.running_hyps = self.init_hyp(h)
ret = self.process_one_block(h, block_is_final, maxlen, maxlenratio)
logging.debug("Finished processing block: %d", self.processed_block)
self.processed_block += 1
if block_is_final:
return ret
if ret is None:
if self.prev_output is None:
return []
else:
return self.prev_output
else:
self.prev_output = ret
# N-best results
return ret
def process_one_block(self, h, is_final, maxlen, maxlenratio):
"""Recognize one block."""
# extend states for ctc
self.extend(h, self.running_hyps)
while self.process_idx < maxlen:
logging.debug("position " + str(self.process_idx))
best = self.search(self.running_hyps, h)
if self.process_idx == maxlen - 1:
# end decoding
self.running_hyps = self.post_process(
self.process_idx, maxlen, maxlenratio, best, self.ended_hyps
)
n_batch = best.yseq.shape[0]
local_ended_hyps = []
is_local_eos = best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
prev_repeat = False
for i in range(is_local_eos.shape[0]):
if is_local_eos[i]:
hyp = self._select(best, i)
local_ended_hyps.append(hyp)
# NOTE(tsunoo): check repetitions here
# This is a implicit implementation of
# Eq (11) in https://arxiv.org/abs/2006.14941
# A flag prev_repeat is used instead of using set
# NOTE(fujihara): I made it possible to turned off
# the below lines using disable_repetition_detection flag,
# because this criteria is too sensitive that the beam
# search starts only after the entire inputs are available.
# Empirically, this flag didn't affect the performance.
elif (
not self.disable_repetition_detection
and not prev_repeat
and best.yseq[i, -1] in best.yseq[i, :-1]
and not is_final
):
prev_repeat = True
if prev_repeat:
logging.info("Detected repetition.")
break
if (
is_final
and maxlenratio == 0.0
and end_detect(
[lh.asdict() for lh in self.ended_hyps], self.process_idx
)
):
logging.info(f"end detected at {self.process_idx}")
return self.assemble_hyps(self.ended_hyps)
if len(local_ended_hyps) > 0 and not is_final:
logging.info("Detected hyp(s) reaching EOS in this block.")
break
self.prev_hyps = self.running_hyps
self.running_hyps = self.post_process(
self.process_idx, maxlen, maxlenratio, best, self.ended_hyps
)
if is_final:
for hyp in local_ended_hyps:
self.ended_hyps.append(hyp)
if len(self.running_hyps) == 0:
logging.info("no hypothesis. Finish decoding.")
return self.assemble_hyps(self.ended_hyps)
else:
logging.debug(f"remained hypotheses: {len(self.running_hyps)}")
# increment number
self.process_idx += 1
if is_final:
return self.assemble_hyps(self.ended_hyps)
else:
for hyp in self.ended_hyps:
local_ended_hyps.append(hyp)
rets = self.assemble_hyps(local_ended_hyps)
if self.process_idx > 1 and len(self.prev_hyps) > 0:
self.running_hyps = self.prev_hyps
self.process_idx -= 1
self.prev_hyps = []
# N-best results
return rets
def assemble_hyps(self, ended_hyps):
"""Assemble the hypotheses."""
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
# check the number of hypotheses reaching to eos
if len(nbest_hyps) == 0:
logging.warning(
"there is no N-best results, perform recognition "
"again with smaller minlenratio."
)
return []
# report the best result
best = nbest_hyps[0]
for k, v in best.scores.items():
logging.info(
f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
)
logging.info(f"total log probability: {best.score:.2f}")
logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
if self.token_list is not None:
logging.info(
"best hypo: "
+ "".join([self.token_list[x] for x in best.yseq[1:-1]])
+ "\n"
)
return nbest_hyps
def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
"""Extend probabilities and states with more encoded chunks.
Args:
x (torch.Tensor): The extended encoder output feature
hyps (Hypothesis): Current list of hypothesis
Returns:
Hypothesis: The extended hypothesis
"""
for k, d in self.scorers.items():
if hasattr(d, "extend_prob"):
d.extend_prob(x)
if hasattr(d, "extend_state"):
hyps.states[k] = d.extend_state(hyps.states[k])
"""Parallel beam search module for online simulation."""
import logging
from pathlib import Path
from typing import List
import torch
import yaml
from espnet.nets.batch_beam_search import BatchBeamSearch
from espnet.nets.beam_search import Hypothesis
from espnet.nets.e2e_asr_common import end_detect
class BatchBeamSearchOnlineSim(BatchBeamSearch):
"""Online beam search implementation.
This simulates streaming decoding.
It requires encoded features of entire utterance and
extracts block by block from it as it shoud be done
in streaming processing.
This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR
WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
(https://arxiv.org/abs/2006.14941).
"""
def set_streaming_config(self, asr_config: str):
"""Set config file for streaming decoding.
Args:
asr_config (str): The config file for asr training
"""
train_config_file = Path(asr_config)
self.block_size = None
self.hop_size = None
self.look_ahead = None
config = None
with train_config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
if "encoder_conf" in args.keys():
if "block_size" in args["encoder_conf"].keys():
self.block_size = args["encoder_conf"]["block_size"]
if "hop_size" in args["encoder_conf"].keys():
self.hop_size = args["encoder_conf"]["hop_size"]
if "look_ahead" in args["encoder_conf"].keys():
self.look_ahead = args["encoder_conf"]["look_ahead"]
elif "config" in args.keys():
config = args["config"]
if config is None:
logging.info(
"Cannot find config file for streaming decoding: "
+ "apply batch beam search instead."
)
return
if (
self.block_size is None or self.hop_size is None or self.look_ahead is None
) and config is not None:
config_file = Path(config)
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
if "encoder_conf" in args.keys():
enc_args = args["encoder_conf"]
if enc_args and "block_size" in enc_args:
self.block_size = enc_args["block_size"]
if enc_args and "hop_size" in enc_args:
self.hop_size = enc_args["hop_size"]
if enc_args and "look_ahead" in enc_args:
self.look_ahead = enc_args["look_ahead"]
def set_block_size(self, block_size: int):
"""Set block size for streaming decoding.
Args:
block_size (int): The block size of encoder
"""
self.block_size = block_size
def set_hop_size(self, hop_size: int):
"""Set hop size for streaming decoding.
Args:
hop_size (int): The hop size of encoder
"""
self.hop_size = hop_size
def set_look_ahead(self, look_ahead: int):
"""Set look ahead size for streaming decoding.
Args:
look_ahead (int): The look ahead size of encoder
"""
self.look_ahead = look_ahead
def forward(
self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
) -> List[Hypothesis]:
"""Perform beam search.
Args:
x (torch.Tensor): Encoded speech feature (T, D)
maxlenratio (float): Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses a end-detect function
to automatically find maximum hypothesis lengths
minlenratio (float): Input length ratio to obtain min output length.
Returns:
list[Hypothesis]: N-best decoding results
"""
self.conservative = True # always true
if self.block_size and self.hop_size and self.look_ahead:
cur_end_frame = int(self.block_size - self.look_ahead)
else:
cur_end_frame = x.shape[0]
process_idx = 0
if cur_end_frame < x.shape[0]:
h = x.narrow(0, 0, cur_end_frame)
else:
h = x
# set length bounds
if maxlenratio == 0:
maxlen = x.shape[0]
else:
maxlen = max(1, int(maxlenratio * x.size(0)))
minlen = int(minlenratio * x.size(0))
logging.info("decoder input length: " + str(x.shape[0]))
logging.info("max output length: " + str(maxlen))
logging.info("min output length: " + str(minlen))
# main loop of prefix search
running_hyps = self.init_hyp(h)
prev_hyps = []
ended_hyps = []
prev_repeat = False
continue_decode = True
while continue_decode:
move_to_next_block = False
if cur_end_frame < x.shape[0]:
h = x.narrow(0, 0, cur_end_frame)
else:
h = x
# extend states for ctc
self.extend(h, running_hyps)
while process_idx < maxlen:
logging.debug("position " + str(process_idx))
best = self.search(running_hyps, h)
if process_idx == maxlen - 1:
# end decoding
running_hyps = self.post_process(
process_idx, maxlen, maxlenratio, best, ended_hyps
)
n_batch = best.yseq.shape[0]
local_ended_hyps = []
is_local_eos = (
best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
)
for i in range(is_local_eos.shape[0]):
if is_local_eos[i]:
hyp = self._select(best, i)
local_ended_hyps.append(hyp)
# NOTE(tsunoo): check repetitions here
# This is a implicit implementation of
# Eq (11) in https://arxiv.org/abs/2006.14941
# A flag prev_repeat is used instead of using set
elif (
not prev_repeat
and best.yseq[i, -1] in best.yseq[i, :-1]
and cur_end_frame < x.shape[0]
):
move_to_next_block = True
prev_repeat = True
if maxlenratio == 0.0 and end_detect(
[lh.asdict() for lh in local_ended_hyps], process_idx
):
logging.info(f"end detected at {process_idx}")
continue_decode = False
break
if len(local_ended_hyps) > 0 and cur_end_frame < x.shape[0]:
move_to_next_block = True
if move_to_next_block:
if (
self.hop_size
and cur_end_frame + int(self.hop_size) + int(self.look_ahead)
< x.shape[0]
):
cur_end_frame += int(self.hop_size)
else:
cur_end_frame = x.shape[0]
logging.debug("Going to next block: %d", cur_end_frame)
if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
running_hyps = prev_hyps
process_idx -= 1
prev_hyps = []
break
prev_repeat = False
prev_hyps = running_hyps
running_hyps = self.post_process(
process_idx, maxlen, maxlenratio, best, ended_hyps
)
if cur_end_frame >= x.shape[0]:
for hyp in local_ended_hyps:
ended_hyps.append(hyp)
if len(running_hyps) == 0:
logging.info("no hypothesis. Finish decoding.")
continue_decode = False
break
else:
logging.debug(f"remained hypotheses: {len(running_hyps)}")
# increment number
process_idx += 1
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
# check the number of hypotheses reaching to eos
if len(nbest_hyps) == 0:
logging.warning(
"there is no N-best results, perform recognition "
"again with smaller minlenratio."
)
return (
[]
if minlenratio < 0.1
else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
)
# report the best result
best = nbest_hyps[0]
for k, v in best.scores.items():
logging.info(
f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
)
logging.info(f"total log probability: {best.score:.2f}")
logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
if self.token_list is not None:
logging.info(
"best hypo: "
+ "".join([self.token_list[x] for x in best.yseq[1:-1]])
+ "\n"
)
if best.yseq[1:-1].shape[0] == x.shape[0]:
logging.warning(
"best hypo length: {} == max output length: {}".format(
best.yseq[1:-1].shape[0], maxlen
)
)
logging.warning(
"decoding may be stopped by the max output length limitation, "
+ "please consider to increase the maxlenratio."
)
return nbest_hyps
def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
"""Extend probabilities and states with more encoded chunks.
Args:
x (torch.Tensor): The extended encoder output feature
hyps (Hypothesis): Current list of hypothesis
Returns:
Hypothesis: The extended hypothesis
"""
for k, d in self.scorers.items():
if hasattr(d, "extend_prob"):
d.extend_prob(x)
if hasattr(d, "extend_state"):
hyps.states[k] = d.extend_state(hyps.states[k])
"""Beam search module."""
import logging
from itertools import chain
from typing import Any, Dict, List, NamedTuple, Tuple, Union
import torch
from espnet.nets.e2e_asr_common import end_detect
from espnet.nets.scorer_interface import PartialScorerInterface, ScorerInterface
class Hypothesis(NamedTuple):
"""Hypothesis data type."""
yseq: torch.Tensor
score: Union[float, torch.Tensor] = 0
scores: Dict[str, Union[float, torch.Tensor]] = dict()
states: Dict[str, Any] = dict()
def asdict(self) -> dict:
"""Convert data to JSON-friendly dict."""
return self._replace(
yseq=self.yseq.tolist(),
score=float(self.score),
scores={k: float(v) for k, v in self.scores.items()},
)._asdict()
class BeamSearch(torch.nn.Module):
"""Beam search implementation."""
def __init__(
self,
scorers: Dict[str, ScorerInterface],
weights: Dict[str, float],
beam_size: int,
vocab_size: int,
sos: int,
eos: int,
token_list: List[str] = None,
pre_beam_ratio: float = 1.5,
pre_beam_score_key: str = None,
hyp_primer: List[int] = None,
):
"""Initialize beam search.
Args:
scorers (dict[str, ScorerInterface]): Dict of decoder modules
e.g., Decoder, CTCPrefixScorer, LM
The scorer will be ignored if it is `None`
weights (dict[str, float]): Dict of weights for each scorers
The scorer will be ignored if its weight is 0
beam_size (int): The number of hypotheses kept during search
vocab_size (int): The number of vocabulary
sos (int): Start of sequence id
eos (int): End of sequence id
token_list (list[str]): List of tokens for debug log
pre_beam_score_key (str): key of scores to perform pre-beam search
pre_beam_ratio (float): beam size in the pre-beam search
will be `int(pre_beam_ratio * beam_size)`
"""
super().__init__()
# set scorers
self.weights = weights
self.scorers = dict()
self.full_scorers = dict()
self.part_scorers = dict()
# this module dict is required for recursive cast
# `self.to(device, dtype)` in `recog.py`
self.nn_dict = torch.nn.ModuleDict()
for k, v in scorers.items():
w = weights.get(k, 0)
if w == 0 or v is None:
continue
assert isinstance(
v, ScorerInterface
), f"{k} ({type(v)}) does not implement ScorerInterface"
self.scorers[k] = v
if isinstance(v, PartialScorerInterface):
self.part_scorers[k] = v
else:
self.full_scorers[k] = v
if isinstance(v, torch.nn.Module):
self.nn_dict[k] = v
# set configurations
self.sos = sos
self.eos = eos
# added for OpenAI Whisper decoding
self.hyp_primer = hyp_primer
self.token_list = token_list
self.pre_beam_size = int(pre_beam_ratio * beam_size)
self.beam_size = beam_size
self.n_vocab = vocab_size
if (
pre_beam_score_key is not None
and pre_beam_score_key != "full"
and pre_beam_score_key not in self.full_scorers
):
raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
self.pre_beam_score_key = pre_beam_score_key
self.do_pre_beam = (
self.pre_beam_score_key is not None
and self.pre_beam_size < self.n_vocab
and len(self.part_scorers) > 0
)
def set_hyp_primer(self, hyp_primer: List[int] = None) -> None:
"""Set the primer sequence for decoding.
Used for OpenAI Whisper models.
"""
self.hyp_primer = hyp_primer
def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
"""Get an initial hypothesis data.
Args:
x (torch.Tensor): The encoder output feature
Returns:
Hypothesis: The initial hypothesis.
"""
init_states = dict()
init_scores = dict()
for k, d in self.scorers.items():
init_states[k] = d.init_state(x)
init_scores[k] = 0.0
# NOTE (Shih-Lun): added for OpenAI Whisper ASR
primer = [self.sos] if self.hyp_primer is None else self.hyp_primer
return [
Hypothesis(
score=0.0,
scores=init_scores,
states=init_states,
yseq=torch.tensor(primer, device=x.device),
)
]
@staticmethod
def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
"""Append new token to prefix tokens.
Args:
xs (torch.Tensor): The prefix token
x (int): The new token to append
Returns:
torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
"""
x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
return torch.cat((xs, x))
def score_full(
self, hyp: Hypothesis, x: torch.Tensor
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.full_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.full_scorers`
and tensor score values of shape: `(self.n_vocab,)`,
and state dict that has string keys
and state values of `self.full_scorers`
"""
scores = dict()
states = dict()
for k, d in self.full_scorers.items():
scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
return scores, states
def score_partial(
self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
"""Score new hypothesis by `self.part_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
ids (torch.Tensor): 1D tensor of new partial tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.part_scorers`
and tensor score values of shape: `(len(ids),)`,
and state dict that has string keys
and state values of `self.part_scorers`
"""
scores = dict()
states = dict()
for k, d in self.part_scorers.items():
scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
return scores, states
def beam(
self, weighted_scores: torch.Tensor, ids: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute topk full token ids and partial token ids.
Args:
weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
Its shape is `(self.n_vocab,)`.
ids (torch.Tensor): The partial token ids to compute topk
Returns:
Tuple[torch.Tensor, torch.Tensor]:
The topk full token ids and partial token ids.
Their shapes are `(self.beam_size,)`
"""
# no pre beam performed
if weighted_scores.size(0) == ids.size(0):
top_ids = weighted_scores.topk(self.beam_size)[1]
return top_ids, top_ids
# mask pruned in pre-beam not to select in topk
tmp = weighted_scores[ids]
weighted_scores[:] = -float("inf")
weighted_scores[ids] = tmp
top_ids = weighted_scores.topk(self.beam_size)[1]
local_ids = weighted_scores[ids].topk(self.beam_size)[1]
return top_ids, local_ids
@staticmethod
def merge_scores(
prev_scores: Dict[str, float],
next_full_scores: Dict[str, torch.Tensor],
full_idx: int,
next_part_scores: Dict[str, torch.Tensor],
part_idx: int,
) -> Dict[str, torch.Tensor]:
"""Merge scores for new hypothesis.
Args:
prev_scores (Dict[str, float]):
The previous hypothesis scores by `self.scorers`
next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
full_idx (int): The next token id for `next_full_scores`
next_part_scores (Dict[str, torch.Tensor]):
scores of partial tokens by `self.part_scorers`
part_idx (int): The new token id for `next_part_scores`
Returns:
Dict[str, torch.Tensor]: The new score dict.
Its keys are names of `self.full_scorers` and `self.part_scorers`.
Its values are scalar tensors by the scorers.
"""
new_scores = dict()
for k, v in next_full_scores.items():
new_scores[k] = prev_scores[k] + v[full_idx]
for k, v in next_part_scores.items():
new_scores[k] = prev_scores[k] + v[part_idx]
return new_scores
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
"""Merge states for new hypothesis.
Args:
states: states of `self.full_scorers`
part_states: states of `self.part_scorers`
part_idx (int): The new token id for `part_scores`
Returns:
Dict[str, torch.Tensor]: The new score dict.
Its keys are names of `self.full_scorers` and `self.part_scorers`.
Its values are states of the scorers.
"""
new_states = dict()
for k, v in states.items():
new_states[k] = v
for k, d in self.part_scorers.items():
new_states[k] = d.select_state(part_states[k], part_idx)
return new_states
def search(
self, running_hyps: List[Hypothesis], x: torch.Tensor
) -> List[Hypothesis]:
"""Search new tokens for running hypotheses and encoded speech x.
Args:
running_hyps (List[Hypothesis]): Running hypotheses on beam
x (torch.Tensor): Encoded speech feature (T, D)
Returns:
List[Hypotheses]: Best sorted hypotheses
"""
best_hyps = []
part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
for hyp in running_hyps:
# scoring
weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
scores, states = self.score_full(hyp, x)
for k in self.full_scorers:
weighted_scores += self.weights[k] * scores[k]
# partial scoring
if self.do_pre_beam:
pre_beam_scores = (
weighted_scores
if self.pre_beam_score_key == "full"
else scores[self.pre_beam_score_key]
)
part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
part_scores, part_states = self.score_partial(hyp, part_ids, x)
for k in self.part_scorers:
weighted_scores[part_ids] += self.weights[k] * part_scores[k]
# add previous hyp score
weighted_scores += hyp.score
# update hyps
for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
# will be (2 x beam at most)
best_hyps.append(
Hypothesis(
score=weighted_scores[j],
yseq=self.append_token(hyp.yseq, j),
scores=self.merge_scores(
hyp.scores, scores, j, part_scores, part_j
),
states=self.merge_states(states, part_states, part_j),
)
)
# sort and prune 2 x beam -> beam
best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
: min(len(best_hyps), self.beam_size)
]
return best_hyps
def forward(
self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
) -> List[Hypothesis]:
"""Perform beam search.
Args:
x (torch.Tensor): Encoded speech feature (T, D)
maxlenratio (float): Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses a end-detect function
to automatically find maximum hypothesis lengths
If maxlenratio<0.0, its absolute value is interpreted
as a constant max output length.
minlenratio (float): Input length ratio to obtain min output length.
If minlenratio<0.0, its absolute value is interpreted
as a constant min output length.
Returns:
list[Hypothesis]: N-best decoding results
"""
# set length bounds
if maxlenratio == 0:
maxlen = x.shape[0]
elif maxlenratio < 0:
maxlen = -1 * int(maxlenratio)
else:
maxlen = max(1, int(maxlenratio * x.size(0)))
if minlenratio < 0:
minlen = -1 * int(minlenratio)
else:
minlen = int(minlenratio * x.size(0))
logging.info("decoder input length: " + str(x.shape[0]))
logging.info("max output length: " + str(maxlen))
logging.info("min output length: " + str(minlen))
# main loop of prefix search
running_hyps = self.init_hyp(x)
ended_hyps = []
for i in range(maxlen):
logging.debug("position " + str(i))
best = self.search(running_hyps, x)
# post process of one iteration
running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
# end detection
if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
logging.info(f"end detected at {i}")
break
if len(running_hyps) == 0:
logging.info("no hypothesis. Finish decoding.")
break
else:
logging.debug(f"remained hypotheses: {len(running_hyps)}")
nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
# check the number of hypotheses reaching to eos
if len(nbest_hyps) == 0:
logging.warning(
"there is no N-best results, perform recognition "
"again with smaller minlenratio."
)
return (
[]
if minlenratio < 0.1
else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
)
# report the best result
best = nbest_hyps[0]
for k, v in best.scores.items():
logging.info(
f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
)
logging.info(f"total log probability: {best.score:.2f}")
logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
if self.token_list is not None:
logging.info(
"best hypo: "
+ "".join([self.token_list[x] for x in best.yseq[1:-1]])
+ "\n"
)
if best.yseq[1:-1].shape[0] == maxlen:
logging.warning(
"best hypo length: {} == max output length: {}".format(
best.yseq[1:-1].shape[0], maxlen
)
)
logging.warning(
"decoding may be stopped by the max output length limitation, "
+ "please consider to increase the maxlenratio."
)
return nbest_hyps
def post_process(
self,
i: int,
maxlen: int,
maxlenratio: float,
running_hyps: List[Hypothesis],
ended_hyps: List[Hypothesis],
) -> List[Hypothesis]:
"""Perform post-processing of beam search iterations.
Args:
i (int): The length of hypothesis tokens.
maxlen (int): The maximum length of tokens in beam search.
maxlenratio (int): The maximum length ratio in beam search.
running_hyps (List[Hypothesis]): The running hypotheses in beam search.
ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
Returns:
List[Hypothesis]: The new running hypotheses.
"""
logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
if self.token_list is not None:
logging.debug(
"best hypo: "
+ "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
)
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
logging.info("adding <eos> in the last position in the loop")
running_hyps = [
h._replace(yseq=self.append_token(h.yseq, self.eos))
for h in running_hyps
]
# add ended hypotheses to a final list, and removed them from current hypotheses
# (this will be a problem, number of hyps < beam)
remained_hyps = []
for hyp in running_hyps:
if hyp.yseq[-1] == self.eos:
# e.g., Word LM needs to add final <eos> score
for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
s = d.final_score(hyp.states[k])
hyp.scores[k] += s
hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
ended_hyps.append(hyp)
else:
remained_hyps.append(hyp)
return remained_hyps
def beam_search(
x: torch.Tensor,
sos: int,
eos: int,
beam_size: int,
vocab_size: int,
scorers: Dict[str, ScorerInterface],
weights: Dict[str, float],
token_list: List[str] = None,
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
pre_beam_ratio: float = 1.5,
pre_beam_score_key: str = "full",
) -> list:
"""Perform beam search with scorers.
Args:
x (torch.Tensor): Encoded speech feature (T, D)
sos (int): Start of sequence id
eos (int): End of sequence id
beam_size (int): The number of hypotheses kept during search
vocab_size (int): The number of vocabulary
scorers (dict[str, ScorerInterface]): Dict of decoder modules
e.g., Decoder, CTCPrefixScorer, LM
The scorer will be ignored if it is `None`
weights (dict[str, float]): Dict of weights for each scorers
The scorer will be ignored if its weight is 0
token_list (list[str]): List of tokens for debug log
maxlenratio (float): Input length ratio to obtain max output length.
If maxlenratio=0.0 (default), it uses a end-detect function
to automatically find maximum hypothesis lengths
minlenratio (float): Input length ratio to obtain min output length.
pre_beam_score_key (str): key of scores to perform pre-beam search
pre_beam_ratio (float): beam size in the pre-beam search
will be `int(pre_beam_ratio * beam_size)`
Returns:
list: N-best decoding results
"""
ret = BeamSearch(
scorers,
weights,
beam_size=beam_size,
vocab_size=vocab_size,
pre_beam_ratio=pre_beam_ratio,
pre_beam_score_key=pre_beam_score_key,
sos=sos,
eos=eos,
token_list=token_list,
).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
return [h.asdict() for h in ret]
"""
Time Synchronous One-Pass Beam Search.
Implements joint CTC/attention decoding where
hypotheses are expanded along the time (input) axis,
as described in https://arxiv.org/abs/2210.05200.
Supports CPU and GPU inference.
References: https://arxiv.org/abs/1408.2873 for CTC beam search
Author: Brian Yan
"""
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
import numpy as np
import torch
from espnet.nets.beam_search import Hypothesis
from espnet.nets.scorer_interface import ScorerInterface
@dataclass
class CacheItem:
"""For caching attentional decoder and LM states."""
state: Any
scores: Any
log_sum: float
class BeamSearchTimeSync(torch.nn.Module):
"""Time synchronous beam search algorithm."""
def __init__(
self,
sos: int,
beam_size: int,
scorers: Dict[str, ScorerInterface],
weights: Dict[str, float],
token_list=dict,
pre_beam_ratio: float = 1.5,
blank: int = 0,
force_lid: bool = False,
temp: float = 1.0,
):
"""Initialize beam search.
Args:
beam_size: num hyps
sos: sos index
ctc: CTC module
pre_beam_ratio: pre_beam_ratio * beam_size = pre_beam
pre_beam is used to select candidates from vocab to extend hypotheses
decoder: decoder ScorerInterface
ctc_weight: ctc_weight
blank: blank index
"""
super().__init__()
self.ctc = scorers["ctc"]
self.decoder = scorers["decoder"]
self.lm = scorers["lm"] if "lm" in scorers else None
self.beam_size = beam_size
self.pre_beam_size = int(pre_beam_ratio * beam_size)
self.ctc_weight = weights["ctc"]
self.lm_weight = weights["lm"]
self.decoder_weight = weights["decoder"]
self.penalty = weights["length_bonus"]
self.sos = sos
self.sos_th = torch.tensor([self.sos])
self.blank = blank
self.attn_cache = dict() # cache for p_attn(Y|X)
self.lm_cache = dict() # cache for p_lm(Y)
self.enc_output = None # log p_ctc(Z|X)
self.force_lid = force_lid
self.temp = temp
self.token_list = token_list
def reset(self, enc_output: torch.Tensor):
"""Reset object for a new utterance."""
self.attn_cache = dict()
self.lm_cache = dict()
self.enc_output = enc_output
self.sos_th = self.sos_th.to(enc_output.device)
if self.decoder is not None:
init_decoder_state = self.decoder.init_state(enc_output)
decoder_scores, decoder_state = self.decoder.score(
self.sos_th, init_decoder_state, enc_output
)
self.attn_cache[(self.sos,)] = CacheItem(
state=decoder_state,
scores=decoder_scores,
log_sum=0.0,
)
if self.lm is not None:
init_lm_state = self.lm.init_state(enc_output)
lm_scores, lm_state = self.lm.score(self.sos_th, init_lm_state, enc_output)
self.lm_cache[(self.sos,)] = CacheItem(
state=lm_state,
scores=lm_scores,
log_sum=0.0,
)
def cached_score(self, h: Tuple[int], cache: dict, scorer: ScorerInterface) -> Any:
"""Retrieve decoder/LM scores which may be cached."""
root = h[:-1] # prefix
if root in cache:
root_scores = cache[root].scores
root_state = cache[root].state
root_log_sum = cache[root].log_sum
else: # run decoder fwd one step and update cache
root_root = root[:-1]
root_root_state = cache[root_root].state
root_scores, root_state = scorer.score(
torch.tensor(root, device=self.enc_output.device).long(),
root_root_state,
self.enc_output,
)
root_log_sum = cache[root_root].log_sum + float(
cache[root_root].scores[root[-1]]
)
cache[root] = CacheItem(
state=root_state, scores=root_scores, log_sum=root_log_sum
)
cand_score = float(root_scores[h[-1]])
score = root_log_sum + cand_score
return score
def joint_score(self, hyps: Any, ctc_score_dp: Any) -> Any:
"""Calculate joint score for hyps."""
scores = dict()
for h in hyps:
score = self.ctc_weight * np.logaddexp(*ctc_score_dp[h]) # ctc score
if len(h) > 1 and self.decoder_weight > 0 and self.decoder is not None:
score += (
self.cached_score(h, self.attn_cache, self.decoder)
* self.decoder_weight
) # attn score
if len(h) > 1 and self.lm is not None and self.lm_weight > 0:
score += (
self.cached_score(h, self.lm_cache, self.lm) * self.lm_weight
) # lm score
score += self.penalty * (len(h) - 1) # penalty score
scores[h] = score
return scores
def time_step(self, p_ctc: Any, ctc_score_dp: Any, hyps: Any) -> Any:
"""Execute a single time step."""
pre_beam_threshold = np.sort(p_ctc)[-self.pre_beam_size]
cands = set(np.where(p_ctc >= pre_beam_threshold)[0])
if len(cands) == 0:
cands = {np.argmax(p_ctc)}
new_hyps = set()
ctc_score_dp_next = defaultdict(
lambda: (float("-inf"), float("-inf"))
) # (p_nb, p_b)
tmp = []
for hyp_l in hyps:
p_prev_l = np.logaddexp(*ctc_score_dp[hyp_l])
for c in cands:
if c == self.blank:
logging.debug("blank cand, hypothesis is " + str(hyp_l))
p_nb, p_b = ctc_score_dp_next[hyp_l]
p_b = np.logaddexp(p_b, p_ctc[c] + p_prev_l)
ctc_score_dp_next[hyp_l] = (p_nb, p_b)
new_hyps.add(hyp_l)
else:
l_plus = hyp_l + (int(c),)
logging.debug("non-blank cand, hypothesis is " + str(l_plus))
p_nb, p_b = ctc_score_dp_next[l_plus]
if c == hyp_l[-1]:
logging.debug("repeat cand, hypothesis is " + str(hyp_l))
p_nb_prev, p_b_prev = ctc_score_dp[hyp_l]
p_nb = np.logaddexp(p_nb, p_ctc[c] + p_b_prev)
p_nb_l, p_b_l = ctc_score_dp_next[hyp_l]
p_nb_l = np.logaddexp(p_nb_l, p_ctc[c] + p_nb_prev)
ctc_score_dp_next[hyp_l] = (p_nb_l, p_b_l)
else:
p_nb = np.logaddexp(p_nb, p_ctc[c] + p_prev_l)
if l_plus not in hyps and l_plus in ctc_score_dp:
p_b = np.logaddexp(
p_b, p_ctc[self.blank] + np.logaddexp(*ctc_score_dp[l_plus])
)
p_nb = np.logaddexp(p_nb, p_ctc[c] + ctc_score_dp[l_plus][0])
tmp.append(l_plus)
ctc_score_dp_next[l_plus] = (p_nb, p_b)
new_hyps.add(l_plus)
scores = self.joint_score(new_hyps, ctc_score_dp_next)
hyps = sorted(new_hyps, key=lambda ll: scores[ll], reverse=True)[
: self.beam_size
]
ctc_score_dp = ctc_score_dp_next.copy()
return ctc_score_dp, hyps, scores
def forward(
self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
) -> List[Hypothesis]:
"""Perform beam search.
Args:
enc_output (torch.Tensor)
Return:
list[Hypothesis]
"""
logging.info("decoder input lengths: " + str(x.shape[0]))
lpz = self.ctc.log_softmax(x.unsqueeze(0))
lpz = lpz.squeeze(0)
lpz = lpz.cpu().detach().numpy()
self.reset(x)
hyps = [(self.sos,)]
ctc_score_dp = defaultdict(
lambda: (float("-inf"), float("-inf"))
) # (p_nb, p_b) - dp object tracking p_ctc
ctc_score_dp[(self.sos,)] = (float("-inf"), 0.0)
for t in range(lpz.shape[0]):
logging.debug("position " + str(t))
ctc_score_dp, hyps, scores = self.time_step(lpz[t, :], ctc_score_dp, hyps)
ret = [
Hypothesis(yseq=torch.tensor(list(h) + [self.sos]), score=scores[h])
for h in hyps
]
best_hyp = "".join([self.token_list[x] for x in ret[0].yseq.tolist()])
best_hyp_len = len(ret[0].yseq)
best_score = ret[0].score
logging.info(f"output length: {best_hyp_len}")
logging.info(f"total log probability: {best_score:.2f}")
logging.info(f"best hypo: {best_hyp}")
return ret
"""Search algorithms for Transducer models."""
import logging
from typing import List, Union
import numpy as np
import torch
from espnet.nets.pytorch_backend.transducer.custom_decoder import CustomDecoder
from espnet.nets.pytorch_backend.transducer.joint_network import JointNetwork
from espnet.nets.pytorch_backend.transducer.rnn_decoder import RNNDecoder
from espnet.nets.pytorch_backend.transducer.utils import (
create_lm_batch_states,
init_lm_state,
is_prefix,
recombine_hyps,
select_k_expansions,
select_lm_state,
subtract,
)
from espnet.nets.transducer_decoder_interface import ExtendedHypothesis, Hypothesis
class BeamSearchTransducer:
"""Beam search implementation for Transducer."""
def __init__(
self,
decoder: Union[RNNDecoder, CustomDecoder],
joint_network: JointNetwork,
beam_size: int,
lm: torch.nn.Module = None,
lm_weight: float = 0.1,
search_type: str = "default",
max_sym_exp: int = 2,
u_max: int = 50,
nstep: int = 1,
prefix_alpha: int = 1,
expansion_gamma: int = 2.3,
expansion_beta: int = 2,
score_norm: bool = True,
softmax_temperature: float = 1.0,
nbest: int = 1,
quantization: bool = False,
):
"""Initialize Transducer search module.
Args:
decoder: Decoder module.
joint_network: Joint network module.
beam_size: Beam size.
lm: LM class.
lm_weight: LM weight for soft fusion.
search_type: Search algorithm to use during inference.
max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
u_max: Maximum output sequence length. (ALSD)
nstep: Number of maximum expansion steps at each time step. (NSC/mAES)
prefix_alpha: Maximum prefix length in prefix search. (NSC/mAES)
expansion_beta:
Number of additional candidates for expanded hypotheses selection. (mAES)
expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
score_norm: Normalize final scores by length. ("default")
softmax_temperature: Penalization term for softmax function.
nbest: Number of final hypothesis.
quantization: Whether dynamic quantization is used.
"""
self.decoder = decoder
self.joint_network = joint_network
self.beam_size = beam_size
self.hidden_size = decoder.dunits
self.vocab_size = decoder.odim
self.blank_id = decoder.blank_id
if self.beam_size <= 1:
self.search_algorithm = self.greedy_search
elif search_type == "default":
self.search_algorithm = self.default_beam_search
elif search_type == "tsd":
self.max_sym_exp = max_sym_exp
self.search_algorithm = self.time_sync_decoding
elif search_type == "alsd":
self.u_max = u_max
self.search_algorithm = self.align_length_sync_decoding
elif search_type == "nsc":
self.nstep = nstep
self.prefix_alpha = prefix_alpha
self.search_algorithm = self.nsc_beam_search
elif search_type == "maes":
self.nstep = nstep if nstep > 1 else 2
self.prefix_alpha = prefix_alpha
self.expansion_gamma = expansion_gamma
assert self.vocab_size >= beam_size + expansion_beta, (
"beam_size (%d) + expansion_beta (%d) "
"should be smaller or equal to vocabulary size (%d)."
% (beam_size, expansion_beta, self.vocab_size)
)
self.max_candidates = beam_size + expansion_beta
self.search_algorithm = self.modified_adaptive_expansion_search
else:
raise NotImplementedError
if lm is not None:
self.use_lm = True
self.lm = lm
self.is_wordlm = True if hasattr(lm.predictor, "wordlm") else False
self.lm_predictor = lm.predictor.wordlm if self.is_wordlm else lm.predictor
self.lm_layers = len(self.lm_predictor.rnn)
self.lm_weight = lm_weight
else:
self.use_lm = False
if softmax_temperature > 1.0 and lm is not None:
logging.warning(
"Softmax temperature is not supported with LM decoding."
"Setting softmax-temperature value to 1.0."
)
self.softmax_temperature = 1.0
else:
self.softmax_temperature = softmax_temperature
self.quantization = quantization
self.score_norm = score_norm
self.nbest = nbest
def __call__(
self, enc_out: torch.Tensor
) -> Union[List[Hypothesis], List[ExtendedHypothesis]]:
"""Perform beam search.
Args:
enc_out: Encoder output sequence. (T, D_enc)
Returns:
nbest_hyps: N-best decoding results
"""
self.decoder.set_device(enc_out.device)
nbest_hyps = self.search_algorithm(enc_out)
return nbest_hyps
def sort_nbest(
self, hyps: Union[List[Hypothesis], List[ExtendedHypothesis]]
) -> Union[List[Hypothesis], List[ExtendedHypothesis]]:
"""Sort hypotheses by score or score given sequence length.
Args:
hyps: Hypothesis.
Return:
hyps: Sorted hypothesis.
"""
if self.score_norm:
hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
else:
hyps.sort(key=lambda x: x.score, reverse=True)
return hyps[: self.nbest]
def prefix_search(
self, hyps: List[ExtendedHypothesis], enc_out_t: torch.Tensor
) -> List[ExtendedHypothesis]:
"""Prefix search for NSC and mAES strategies.
Based on https://arxiv.org/pdf/1211.3711.pdf
"""
for j, hyp_j in enumerate(hyps[:-1]):
for hyp_i in hyps[(j + 1) :]:
curr_id = len(hyp_j.yseq)
pref_id = len(hyp_i.yseq)
if (
is_prefix(hyp_j.yseq, hyp_i.yseq)
and (curr_id - pref_id) <= self.prefix_alpha
):
logp = torch.log_softmax(
self.joint_network(
enc_out_t, hyp_i.dec_out[-1], quantization=self.quantization
)
/ self.softmax_temperature,
dim=-1,
)
curr_score = hyp_i.score + float(logp[hyp_j.yseq[pref_id]])
for k in range(pref_id, (curr_id - 1)):
logp = torch.log_softmax(
self.joint_network(
enc_out_t,
hyp_j.dec_out[k],
quantization=self.quantization,
)
/ self.softmax_temperature,
dim=-1,
)
curr_score += float(logp[hyp_j.yseq[k + 1]])
hyp_j.score = np.logaddexp(hyp_j.score, curr_score)
return hyps
def greedy_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
"""Greedy search implementation.
Args:
enc_out: Encoder output sequence. (T, D_enc)
Returns:
hyp: 1-best hypotheses.
"""
dec_state = self.decoder.init_state(1)
hyp = Hypothesis(score=0.0, yseq=[self.blank_id], dec_state=dec_state)
cache = {}
dec_out, state, _ = self.decoder.score(hyp, cache)
for enc_out_t in enc_out:
logp = torch.log_softmax(
self.joint_network(enc_out_t, dec_out, quantization=self.quantization)
/ self.softmax_temperature,
dim=-1,
)
top_logp, pred = torch.max(logp, dim=-1)
if pred != self.blank_id:
hyp.yseq.append(int(pred))
hyp.score += float(top_logp)
hyp.dec_state = state
dec_out, state, _ = self.decoder.score(hyp, cache)
return [hyp]
def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
"""Beam search implementation.
Modified from https://arxiv.org/pdf/1211.3711.pdf
Args:
enc_out: Encoder output sequence. (T, D)
Returns:
nbest_hyps: N-best hypothesis.
"""
beam = min(self.beam_size, self.vocab_size)
beam_k = min(beam, (self.vocab_size - 1))
dec_state = self.decoder.init_state(1)
kept_hyps = [Hypothesis(score=0.0, yseq=[self.blank_id], dec_state=dec_state)]
cache = {}
for enc_out_t in enc_out:
hyps = kept_hyps
kept_hyps = []
while True:
max_hyp = max(hyps, key=lambda x: x.score)
hyps.remove(max_hyp)
dec_out, state, lm_tokens = self.decoder.score(max_hyp, cache)
logp = torch.log_softmax(
self.joint_network(
enc_out_t, dec_out, quantization=self.quantization
)
/ self.softmax_temperature,
dim=-1,
)
top_k = logp[1:].topk(beam_k, dim=-1)
kept_hyps.append(
Hypothesis(
score=(max_hyp.score + float(logp[0:1])),
yseq=max_hyp.yseq[:],
dec_state=max_hyp.dec_state,
lm_state=max_hyp.lm_state,
)
)
if self.use_lm:
lm_state, lm_scores = self.lm.predict(max_hyp.lm_state, lm_tokens)
else:
lm_state = max_hyp.lm_state
for logp, k in zip(*top_k):
score = max_hyp.score + float(logp)
if self.use_lm:
score += self.lm_weight * lm_scores[0][k + 1]
hyps.append(
Hypothesis(
score=score,
yseq=max_hyp.yseq[:] + [int(k + 1)],
dec_state=state,
lm_state=lm_state,
)
)
hyps_max = float(max(hyps, key=lambda x: x.score).score)
kept_most_prob = sorted(
[hyp for hyp in kept_hyps if hyp.score > hyps_max],
key=lambda x: x.score,
)
if len(kept_most_prob) >= beam:
kept_hyps = kept_most_prob
break
return self.sort_nbest(kept_hyps)
def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
"""Time synchronous beam search implementation.
Based on https://ieeexplore.ieee.org/document/9053040
Args:
enc_out: Encoder output sequence. (T, D)
Returns:
nbest_hyps: N-best hypothesis.
"""
beam = min(self.beam_size, self.vocab_size)
beam_state = self.decoder.init_state(beam)
B = [
Hypothesis(
yseq=[self.blank_id],
score=0.0,
dec_state=self.decoder.select_state(beam_state, 0),
)
]
cache = {}
if self.use_lm and not self.is_wordlm:
B[0].lm_state = init_lm_state(self.lm_predictor)
for enc_out_t in enc_out:
A = []
C = B
enc_out_t = enc_out_t.unsqueeze(0)
for v in range(self.max_sym_exp):
D = []
beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
C,
beam_state,
cache,
self.use_lm,
)
beam_logp = torch.log_softmax(
self.joint_network(enc_out_t, beam_dec_out)
/ self.softmax_temperature,
dim=-1,
)
beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)
seq_A = [h.yseq for h in A]
for i, hyp in enumerate(C):
if hyp.yseq not in seq_A:
A.append(
Hypothesis(
score=(hyp.score + float(beam_logp[i, 0])),
yseq=hyp.yseq[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
)
)
else:
dict_pos = seq_A.index(hyp.yseq)
A[dict_pos].score = np.logaddexp(
A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
)
if v < (self.max_sym_exp - 1):
if self.use_lm:
beam_lm_states = create_lm_batch_states(
[c.lm_state for c in C], self.lm_layers, self.is_wordlm
)
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
beam_lm_states, beam_lm_tokens, len(C)
)
for i, hyp in enumerate(C):
for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
new_hyp = Hypothesis(
score=(hyp.score + float(logp)),
yseq=(hyp.yseq + [int(k)]),
dec_state=self.decoder.select_state(beam_state, i),
lm_state=hyp.lm_state,
)
if self.use_lm:
new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
new_hyp.lm_state = select_lm_state(
beam_lm_states, i, self.lm_layers, self.is_wordlm
)
D.append(new_hyp)
C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]
B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
return self.sort_nbest(B)
def align_length_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
"""Alignment-length synchronous beam search implementation.
Based on https://ieeexplore.ieee.org/document/9053040
Args:
h: Encoder output sequences. (T, D)
Returns:
nbest_hyps: N-best hypothesis.
"""
beam = min(self.beam_size, self.vocab_size)
t_max = int(enc_out.size(0))
u_max = min(self.u_max, (t_max - 1))
beam_state = self.decoder.init_state(beam)
B = [
Hypothesis(
yseq=[self.blank_id],
score=0.0,
dec_state=self.decoder.select_state(beam_state, 0),
)
]
final = []
cache = {}
if self.use_lm and not self.is_wordlm:
B[0].lm_state = init_lm_state(self.lm_predictor)
for i in range(t_max + u_max):
A = []
B_ = []
B_enc_out = []
for hyp in B:
u = len(hyp.yseq) - 1
t = i - u
if t > (t_max - 1):
continue
B_.append(hyp)
B_enc_out.append((t, enc_out[t]))
if B_:
beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
B_,
beam_state,
cache,
self.use_lm,
)
beam_enc_out = torch.stack([x[1] for x in B_enc_out])
beam_logp = torch.log_softmax(
self.joint_network(beam_enc_out, beam_dec_out)
/ self.softmax_temperature,
dim=-1,
)
beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)
if self.use_lm:
beam_lm_states = create_lm_batch_states(
[b.lm_state for b in B_], self.lm_layers, self.is_wordlm
)
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
beam_lm_states, beam_lm_tokens, len(B_)
)
for i, hyp in enumerate(B_):
new_hyp = Hypothesis(
score=(hyp.score + float(beam_logp[i, 0])),
yseq=hyp.yseq[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
)
A.append(new_hyp)
if B_enc_out[i][0] == (t_max - 1):
final.append(new_hyp)
for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
new_hyp = Hypothesis(
score=(hyp.score + float(logp)),
yseq=(hyp.yseq[:] + [int(k)]),
dec_state=self.decoder.select_state(beam_state, i),
lm_state=hyp.lm_state,
)
if self.use_lm:
new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
new_hyp.lm_state = select_lm_state(
beam_lm_states, i, self.lm_layers, self.is_wordlm
)
A.append(new_hyp)
B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
B = recombine_hyps(B)
if final:
return self.sort_nbest(final)
else:
return B
def nsc_beam_search(self, enc_out: torch.Tensor) -> List[ExtendedHypothesis]:
"""N-step constrained beam search implementation.
Based on/Modified from https://arxiv.org/pdf/2002.03577.pdf.
Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
until further modifications.
Args:
enc_out: Encoder output sequence. (T, D_enc)
Returns:
nbest_hyps: N-best hypothesis.
"""
beam = min(self.beam_size, self.vocab_size)
beam_k = min(beam, (self.vocab_size - 1))
beam_state = self.decoder.init_state(beam)
init_tokens = [
ExtendedHypothesis(
yseq=[self.blank_id],
score=0.0,
dec_state=self.decoder.select_state(beam_state, 0),
)
]
cache = {}
beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
init_tokens,
beam_state,
cache,
self.use_lm,
)
state = self.decoder.select_state(beam_state, 0)
if self.use_lm:
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
None, beam_lm_tokens, 1
)
lm_state = select_lm_state(
beam_lm_states, 0, self.lm_layers, self.is_wordlm
)
lm_scores = beam_lm_scores[0]
else:
lm_state = None
lm_scores = None
kept_hyps = [
ExtendedHypothesis(
yseq=[self.blank_id],
score=0.0,
dec_state=state,
dec_out=[beam_dec_out[0]],
lm_state=lm_state,
lm_scores=lm_scores,
)
]
for enc_out_t in enc_out:
hyps = self.prefix_search(
sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True),
enc_out_t,
)
kept_hyps = []
beam_enc_out = enc_out_t.unsqueeze(0)
S = []
V = []
for n in range(self.nstep):
beam_dec_out = torch.stack([hyp.dec_out[-1] for hyp in hyps])
beam_logp = torch.log_softmax(
self.joint_network(beam_enc_out, beam_dec_out)
/ self.softmax_temperature,
dim=-1,
)
beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)
for i, hyp in enumerate(hyps):
S.append(
ExtendedHypothesis(
yseq=hyp.yseq[:],
score=hyp.score + float(beam_logp[i, 0:1]),
dec_out=hyp.dec_out[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
lm_scores=hyp.lm_scores,
)
)
for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
score = hyp.score + float(logp)
if self.use_lm:
score += self.lm_weight * float(hyp.lm_scores[k])
V.append(
ExtendedHypothesis(
yseq=hyp.yseq[:] + [int(k)],
score=score,
dec_out=hyp.dec_out[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
lm_scores=hyp.lm_scores,
)
)
V.sort(key=lambda x: x.score, reverse=True)
V = subtract(V, hyps)[:beam]
beam_state = self.decoder.create_batch_states(
beam_state,
[v.dec_state for v in V],
[v.yseq for v in V],
)
beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
V,
beam_state,
cache,
self.use_lm,
)
if self.use_lm:
beam_lm_states = create_lm_batch_states(
[v.lm_state for v in V], self.lm_layers, self.is_wordlm
)
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
beam_lm_states, beam_lm_tokens, len(V)
)
if n < (self.nstep - 1):
for i, v in enumerate(V):
v.dec_out.append(beam_dec_out[i])
v.dec_state = self.decoder.select_state(beam_state, i)
if self.use_lm:
v.lm_state = select_lm_state(
beam_lm_states, i, self.lm_layers, self.is_wordlm
)
v.lm_scores = beam_lm_scores[i]
hyps = V[:]
else:
beam_logp = torch.log_softmax(
self.joint_network(beam_enc_out, beam_dec_out)
/ self.softmax_temperature,
dim=-1,
)
for i, v in enumerate(V):
if self.nstep != 1:
v.score += float(beam_logp[i, 0])
v.dec_out.append(beam_dec_out[i])
v.dec_state = self.decoder.select_state(beam_state, i)
if self.use_lm:
v.lm_state = select_lm_state(
beam_lm_states, i, self.lm_layers, self.is_wordlm
)
v.lm_scores = beam_lm_scores[i]
kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]
return self.sort_nbest(kept_hyps)
def modified_adaptive_expansion_search(
self, enc_out: torch.Tensor
) -> List[ExtendedHypothesis]:
"""It's the modified Adaptive Expansion Search (mAES) implementation.
Based on/modified from https://ieeexplore.ieee.org/document/9250505 and NSC.
Args:
enc_out: Encoder output sequence. (T, D_enc)
Returns:
nbest_hyps: N-best hypothesis.
"""
beam = min(self.beam_size, self.vocab_size)
beam_state = self.decoder.init_state(beam)
init_tokens = [
ExtendedHypothesis(
yseq=[self.blank_id],
score=0.0,
dec_state=self.decoder.select_state(beam_state, 0),
)
]
cache = {}
beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
init_tokens,
beam_state,
cache,
self.use_lm,
)
state = self.decoder.select_state(beam_state, 0)
if self.use_lm:
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
None, beam_lm_tokens, 1
)
lm_state = select_lm_state(
beam_lm_states, 0, self.lm_layers, self.is_wordlm
)
lm_scores = beam_lm_scores[0]
else:
lm_state = None
lm_scores = None
kept_hyps = [
ExtendedHypothesis(
yseq=[self.blank_id],
score=0.0,
dec_state=state,
dec_out=[beam_dec_out[0]],
lm_state=lm_state,
lm_scores=lm_scores,
)
]
for enc_out_t in enc_out:
hyps = self.prefix_search(
sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True),
enc_out_t,
)
kept_hyps = []
beam_enc_out = enc_out_t.unsqueeze(0)
list_b = []
duplication_check = [hyp.yseq for hyp in hyps]
for n in range(self.nstep):
beam_dec_out = torch.stack([h.dec_out[-1] for h in hyps])
beam_logp, beam_idx = torch.log_softmax(
self.joint_network(beam_enc_out, beam_dec_out)
/ self.softmax_temperature,
dim=-1,
).topk(self.max_candidates, dim=-1)
k_expansions = select_k_expansions(
hyps,
beam_idx,
beam_logp,
self.expansion_gamma,
)
list_exp = []
for i, hyp in enumerate(hyps):
for k, new_score in k_expansions[i]:
new_hyp = ExtendedHypothesis(
yseq=hyp.yseq[:],
score=new_score,
dec_out=hyp.dec_out[:],
dec_state=hyp.dec_state,
lm_state=hyp.lm_state,
lm_scores=hyp.lm_scores,
)
if k == 0:
list_b.append(new_hyp)
else:
if new_hyp.yseq + [int(k)] not in duplication_check:
new_hyp.yseq.append(int(k))
if self.use_lm:
new_hyp.score += self.lm_weight * float(
hyp.lm_scores[k]
)
list_exp.append(new_hyp)
if not list_exp:
kept_hyps = sorted(list_b, key=lambda x: x.score, reverse=True)[
:beam
]
break
else:
beam_state = self.decoder.create_batch_states(
beam_state,
[hyp.dec_state for hyp in list_exp],
[hyp.yseq for hyp in list_exp],
)
beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
list_exp,
beam_state,
cache,
self.use_lm,
)
if self.use_lm:
beam_lm_states = create_lm_batch_states(
[hyp.lm_state for hyp in list_exp],
self.lm_layers,
self.is_wordlm,
)
beam_lm_states, beam_lm_scores = self.lm.buff_predict(
beam_lm_states, beam_lm_tokens, len(list_exp)
)
if n < (self.nstep - 1):
for i, hyp in enumerate(list_exp):
hyp.dec_out.append(beam_dec_out[i])
hyp.dec_state = self.decoder.select_state(beam_state, i)
if self.use_lm:
hyp.lm_state = select_lm_state(
beam_lm_states, i, self.lm_layers, self.is_wordlm
)
hyp.lm_scores = beam_lm_scores[i]
hyps = list_exp[:]
else:
beam_logp = torch.log_softmax(
self.joint_network(beam_enc_out, beam_dec_out)
/ self.softmax_temperature,
dim=-1,
)
for i, hyp in enumerate(list_exp):
hyp.score += float(beam_logp[i, 0])
hyp.dec_out.append(beam_dec_out[i])
hyp.dec_state = self.decoder.select_state(beam_state, i)
if self.use_lm:
hyp.lm_state = select_lm_state(
beam_lm_states, i, self.lm_layers, self.is_wordlm
)
hyp.lm_scores = beam_lm_scores[i]
kept_hyps = sorted(
list_b + list_exp, key=lambda x: x.score, reverse=True
)[:beam]
return self.sort_nbest(kept_hyps)
"""ASR Interface module."""
import chainer
from espnet.nets.asr_interface import ASRInterface
class ChainerASRInterface(ASRInterface, chainer.Chain):
"""ASR Interface for ESPnet model implementation."""
@staticmethod
def custom_converter(*args, **kw):
"""Get customconverter of the model (Chainer only)."""
raise NotImplementedError("custom converter method is not implemented")
@staticmethod
def custom_updater(*args, **kw):
"""Get custom_updater of the model (Chainer only)."""
raise NotImplementedError("custom updater method is not implemented")
@staticmethod
def custom_parallel_updater(*args, **kw):
"""Get custom_parallel_updater of the model (Chainer only)."""
raise NotImplementedError("custom parallel updater method is not implemented")
def get_total_subsampling_factor(self):
"""Get total subsampling factor."""
raise NotImplementedError(
"get_total_subsampling_factor method is not implemented"
)
import logging
import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np
class CTC(chainer.Chain):
"""Chainer implementation of ctc layer.
Args:
odim (int): The output dimension.
eprojs (int | None): Dimension of input vectors from encoder.
dropout_rate (float): Dropout rate.
"""
def __init__(self, odim, eprojs, dropout_rate):
super(CTC, self).__init__()
self.dropout_rate = dropout_rate
self.loss = None
with self.init_scope():
self.ctc_lo = L.Linear(eprojs, odim)
def __call__(self, hs, ys):
"""CTC forward.
Args:
hs (list of chainer.Variable | N-dimension array):
Input variable from encoder.
ys (list of chainer.Variable | N-dimension array):
Input variable of decoder.
Returns:
chainer.Variable: A variable holding a scalar value of the CTC loss.
"""
self.loss = None
ilens = [x.shape[0] for x in hs]
olens = [x.shape[0] for x in ys]
# zero padding for hs
y_hat = self.ctc_lo(
F.dropout(F.pad_sequence(hs), ratio=self.dropout_rate), n_batch_axes=2
)
y_hat = F.separate(y_hat, axis=1) # ilen list of batch x hdim
# zero padding for ys
y_true = F.pad_sequence(ys, padding=-1) # batch x olen
# get length info
input_length = chainer.Variable(self.xp.array(ilens, dtype=np.int32))
label_length = chainer.Variable(self.xp.array(olens, dtype=np.int32))
logging.info(
self.__class__.__name__ + " input lengths: " + str(input_length.data)
)
logging.info(
self.__class__.__name__ + " output lengths: " + str(label_length.data)
)
# get ctc loss
self.loss = F.connectionist_temporal_classification(
y_hat, y_true, 0, input_length, label_length
)
logging.info("ctc loss:" + str(self.loss.data))
return self.loss
def log_softmax(self, hs):
"""Log_softmax of frame activations.
Args:
hs (list of chainer.Variable | N-dimension array):
Input variable from encoder.
Returns:
chainer.Variable: A n-dimension float array.
"""
y_hat = self.ctc_lo(F.pad_sequence(hs), n_batch_axes=2)
return F.log_softmax(y_hat.reshape(-1, y_hat.shape[-1])).reshape(y_hat.shape)
def ctc_for(args, odim):
"""Return the CTC layer corresponding to the args.
Args:
args (Namespace): The program arguments.
odim (int): The output dimension.
Returns:
The CTC module.
"""
ctc_type = args.ctc_type
if ctc_type == "builtin":
logging.info("Using chainer CTC implementation")
ctc = CTC(odim, args.eprojs, args.dropout_rate)
else:
raise ValueError('ctc_type must be "builtin": {}'.format(ctc_type))
return ctc
import chainer
import numpy
# from chainer.functions.connection import embed_id
from chainer import cuda, function_node, link, variable
from chainer.initializers import normal
from chainer.utils import type_check
"""Deterministic EmbedID link and function
copied from chainer/links/connection/embed_id.py
and chainer/functions/connection/embed_id.py,
and modified not to use atomicAdd operation
"""
class EmbedIDFunction(function_node.FunctionNode):
def __init__(self, ignore_label=None):
self.ignore_label = ignore_label
self._w_shape = None
def check_type_forward(self, in_types):
type_check.expect(in_types.size() == 2)
x_type, w_type = in_types
type_check.expect(
x_type.dtype.kind == "i",
x_type.ndim >= 1,
)
type_check.expect(w_type.dtype == numpy.float32, w_type.ndim == 2)
def forward(self, inputs):
self.retain_inputs((0,))
x, W = inputs
self._w_shape = W.shape
if not type_check.same_types(*inputs):
raise ValueError(
"numpy and cupy must not be used together\n"
"type(W): {0}, type(x): {1}".format(type(W), type(x))
)
xp = cuda.get_array_module(*inputs)
if chainer.is_debug():
valid_x = xp.logical_and(0 <= x, x < len(W))
if self.ignore_label is not None:
valid_x = xp.logical_or(valid_x, x == self.ignore_label)
if not valid_x.all():
raise ValueError(
"Each not ignored `x` value need to satisfy" "`0 <= x < len(W)`"
)
if self.ignore_label is not None:
mask = x == self.ignore_label
return (xp.where(mask[..., None], 0, W[xp.where(mask, 0, x)]),)
return (W[x],)
def backward(self, indexes, grad_outputs):
inputs = self.get_retained_inputs()
gW = EmbedIDGrad(self._w_shape, self.ignore_label).apply(inputs + grad_outputs)[
0
]
return None, gW
class EmbedIDGrad(function_node.FunctionNode):
def __init__(self, w_shape, ignore_label=None):
self.w_shape = w_shape
self.ignore_label = ignore_label
self._gy_shape = None
def forward(self, inputs):
self.retain_inputs((0,))
xp = cuda.get_array_module(*inputs)
x, gy = inputs
self._gy_shape = gy.shape
gW = xp.zeros(self.w_shape, dtype=gy.dtype)
if xp is numpy:
# It is equivalent to `numpy.add.at(gW, x, gy)` but ufunc.at is
# too slow.
for ix, igy in zip(x.ravel(), gy.reshape(x.size, -1)):
if ix == self.ignore_label:
continue
gW[ix] += igy
else:
"""
# original code based on cuda elementwise method
if self.ignore_label is None:
cuda.elementwise(
'T gy, S x, S n_out', 'raw T gW',
'ptrdiff_t w_ind[] = {x, i % n_out};'
'atomicAdd(&gW[w_ind], gy)',
'embed_id_bwd')(
gy, xp.expand_dims(x, -1), gW.shape[1], gW)
else:
cuda.elementwise(
'T gy, S x, S n_out, S ignore', 'raw T gW',
'''
if (x != ignore) {
ptrdiff_t w_ind[] = {x, i % n_out};
atomicAdd(&gW[w_ind], gy);
}
''',
'embed_id_bwd_ignore_label')(
gy, xp.expand_dims(x, -1), gW.shape[1],
self.ignore_label, gW)
"""
# EmbedID gradient alternative without atomicAdd, which simply
# creates a one-hot vector and applies dot product
xi = xp.zeros((x.size, len(gW)), dtype=numpy.float32)
idx = xp.arange(x.size, dtype=numpy.int32) * len(gW) + x.ravel()
xi.ravel()[idx] = 1.0
if self.ignore_label is not None:
xi[:, self.ignore_label] = 0.0
gW = xi.T.dot(gy.reshape(x.size, -1)).astype(gW.dtype, copy=False)
return (gW,)
def backward(self, indexes, grads):
xp = cuda.get_array_module(*grads)
x = self.get_retained_inputs()[0].data
ggW = grads[0]
if self.ignore_label is not None:
mask = x == self.ignore_label
# To prevent index out of bounds, we need to check if ignore_label
# is inside of W.
if not (0 <= self.ignore_label < self.w_shape[1]):
x = xp.where(mask, 0, x)
ggy = ggW[x]
if self.ignore_label is not None:
mask, zero, _ = xp.broadcast_arrays(
mask[..., None], xp.zeros((), "f"), ggy.data
)
ggy = chainer.functions.where(mask, zero, ggy)
return None, ggy
def embed_id(x, W, ignore_label=None):
r"""Efficient linear function for one-hot input.
This function implements so called *word embeddings*. It takes two
arguments: a set of IDs (words) ``x`` in :math:`B` dimensional integer
vector, and a set of all ID (word) embeddings ``W`` in :math:`V \\times d`
float32 matrix. It outputs :math:`B \\times d` matrix whose ``i``-th
column is the ``x[i]``-th column of ``W``.
This function is only differentiable on the input ``W``.
Args:
x (chainer.Variable | np.ndarray): Batch vectors of IDs. Each
element must be signed integer.
W (chainer.Variable | np.ndarray): Distributed representation
of each ID (a.k.a. word embeddings).
ignore_label (int): If ignore_label is an int value, i-th column
of return value is filled with 0.
Returns:
chainer.Variable: Embedded variable.
.. rubric:: :class:`~chainer.links.EmbedID`
Examples:
>>> x = np.array([2, 1]).astype('i')
>>> x
array([2, 1], dtype=int32)
>>> W = np.array([[0, 0, 0],
... [1, 1, 1],
... [2, 2, 2]]).astype('f')
>>> W
array([[ 0., 0., 0.],
[ 1., 1., 1.],
[ 2., 2., 2.]], dtype=float32)
>>> F.embed_id(x, W).data
array([[ 2., 2., 2.],
[ 1., 1., 1.]], dtype=float32)
>>> F.embed_id(x, W, ignore_label=1).data
array([[ 2., 2., 2.],
[ 0., 0., 0.]], dtype=float32)
"""
return EmbedIDFunction(ignore_label=ignore_label).apply((x, W))[0]
class EmbedID(link.Link):
"""Efficient linear layer for one-hot input.
This is a link that wraps the :func:`~chainer.functions.embed_id` function.
This link holds the ID (word) embedding matrix ``W`` as a parameter.
Args:
in_size (int): Number of different identifiers (a.k.a. vocabulary size).
out_size (int): Output dimension.
initialW (Initializer): Initializer to initialize the weight.
ignore_label (int): If `ignore_label` is an int value, i-th column of
return value is filled with 0.
.. rubric:: :func:`~chainer.functions.embed_id`
Attributes:
W (~chainer.Variable): Embedding parameter matrix.
Examples:
>>> W = np.array([[0, 0, 0],
... [1, 1, 1],
... [2, 2, 2]]).astype('f')
>>> W
array([[ 0., 0., 0.],
[ 1., 1., 1.],
[ 2., 2., 2.]], dtype=float32)
>>> l = L.EmbedID(W.shape[0], W.shape[1], initialW=W)
>>> x = np.array([2, 1]).astype('i')
>>> x
array([2, 1], dtype=int32)
>>> y = l(x)
>>> y.data
array([[ 2., 2., 2.],
[ 1., 1., 1.]], dtype=float32)
"""
ignore_label = None
def __init__(self, in_size, out_size, initialW=None, ignore_label=None):
super(EmbedID, self).__init__()
self.ignore_label = ignore_label
with self.init_scope():
if initialW is None:
initialW = normal.Normal(1.0)
self.W = variable.Parameter(initialW, (in_size, out_size))
def __call__(self, x):
"""Extracts the word embedding of given IDs.
Args:
x (chainer.Variable): Batch vectors of IDs.
Returns:
chainer.Variable: Batch of corresponding embeddings.
"""
return embed_id(x, self.W, ignore_label=self.ignore_label)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment