Commit 60a2c57a authored by sunzhq2's avatar sunzhq2 Committed by xuxo
Browse files

update conformer

parent 4a699441
# encoding: utf-8
"""Class Declaration of Transformer's Input layers."""
import logging
import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np
from espnet.nets.chainer_backend.transformer.embedding import PositionalEncoding
class Conv2dSubsampling(chainer.Chain):
    """Convolutional 2D subsampling (reduces time length to 1/4).

    Two stride-2 3x3 convolutions halve the time axis twice, then a linear
    layer projects the flattened channel/feature axes to the model dimension
    and positional encoding is added.

    :param int channels: number of channels of both conv layers
    :param int idim: flattened input dim of the final linear projection
    :param int dims: output (model) dimension
    :param float dropout: dropout rate
    :param initialW: weight initializer factory accepting a ``scale`` kwarg
    :param initial_bias: bias initializer factory accepting a ``scale`` kwarg
    """

    def __init__(
        self, channels, idim, dims, dropout=0.1, initialW=None, initial_bias=None
    ):
        """Initialize Conv2dSubsampling."""
        super(Conv2dSubsampling, self).__init__()
        self.dropout = dropout
        with self.init_scope():
            # Std-dev for a conv with 1 input channel and a 3 x 3 kernel.
            conv1_stvd = 1.0 / np.sqrt(1 * 3 * 3)
            self.conv1 = L.Convolution2D(
                1,
                channels,
                3,
                stride=2,
                pad=1,
                initialW=initialW(scale=conv1_stvd),
                initial_bias=initial_bias(scale=conv1_stvd),
            )
            # Std-dev for a conv with `channels` input channels, 3 x 3 kernel.
            conv2_stvd = 1.0 / np.sqrt(channels * 3 * 3)
            self.conv2 = L.Convolution2D(
                channels,
                channels,
                3,
                stride=2,
                pad=1,
                initialW=initialW(scale=conv2_stvd),
                initial_bias=initial_bias(scale=conv2_stvd),
            )
            out_stvd = 1.0 / np.sqrt(dims)
            self.out = L.Linear(
                idim,
                dims,
                initialW=initialW(scale=out_stvd),
                initial_bias=initial_bias(scale=out_stvd),
            )
            self.pe = PositionalEncoding(dims, dropout)

    def forward(self, xs, ilens):
        """Subsample xs.

        :param chainer.Variable xs: input tensor
        :param ilens: input lengths, rescaled to the subsampled time axis
        :return: subsampled xs and updated ilens
        """
        # add a singleton channel axis: (B, T, F) -> (B, 1, T, F)
        xs = self.xp.array(xs[:, None])
        xs = F.relu(self.conv1(xs))
        xs = F.relu(self.conv2(xs))
        batch, _, length, _ = xs.shape
        # merge channel and feature axes, project to the model dimension
        xs = self.out(F.swapaxes(xs, 1, 2).reshape(batch * length, -1))
        xs = self.pe(xs.reshape(batch, length, -1))
        # each stride-2 convolution halves the valid lengths (ceil mode)
        for _ in range(2):
            ilens = np.ceil(np.array(ilens, dtype=np.float32) / 2).astype(np.int64)
        return xs, ilens
class LinearSampling(chainer.Chain):
    """Linear 1D subsampling.

    Projects input features to the model dimension and adds positional
    encoding; the time axis (and therefore ilens) is left unchanged.

    :param int idim: input dim
    :param int dims: output (model) dimension
    :param float dropout: dropout rate
    :param initialW: weight initializer factory accepting a ``scale`` kwarg
    :param initial_bias: bias initializer factory accepting a ``scale`` kwarg
    """

    def __init__(self, idim, dims, dropout=0.1, initialW=None, initial_bias=None):
        """Initialize LinearSampling."""
        super(LinearSampling, self).__init__()
        stvd = 1.0 / np.sqrt(dims)
        self.dropout = dropout
        with self.init_scope():
            self.linear = L.Linear(
                idim,
                dims,
                initialW=initialW(scale=stvd),
                initial_bias=initial_bias(scale=stvd),
            )
            self.pe = PositionalEncoding(dims, dropout)

    def forward(self, xs, ilens):
        """Project and positionally encode xs.

        :param chainer.Variable xs: input tensor
        :param ilens: input lengths, returned unchanged
        :return: projected xs and ilens
        """
        # Demoted from logging.info: shape dumps on every forward pass are
        # per-step diagnostics, not run-level information.
        logging.debug("LinearSampling input shape: %s", xs.shape)
        xs = self.linear(xs, n_batch_axes=2)
        logging.debug("LinearSampling projected shape: %s", xs.shape)
        xs = self.pe(xs)
        return xs, ilens
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Class Declaration of Transformer's Training Subprocess."""
import collections
import logging
import math
import numpy as np
from chainer import cuda
from chainer import functions as F
from chainer import training
from chainer.training import extension
from chainer.training.updaters.multiprocess_parallel_updater import (
gather_grads,
gather_params,
scatter_grads,
)
# copied from https://github.com/chainer/chainer/blob/master/chainer/optimizer.py
def sum_sqnorm(arr):
    """Calculate the norm of the array.

    Args:
        arr (numpy.ndarray)

    Returns:
        Float: Sum of the norm calculated from the given array.
    """
    # One squared-norm partial sum per device id, so each dot product runs
    # on the device that owns the array.
    sq_sum = collections.defaultdict(float)
    for x in arr:
        # NOTE: the device context is entered even when x is None; the
        # None-check happens inside the with-block (kept as in upstream chainer).
        with cuda.get_device_from_array(x) as dev:
            if x is not None:
                # flatten, then dot with itself -> ||x||^2
                x = x.ravel()
                s = x.dot(x)
                sq_sum[int(dev)] += s
    # combine per-device partial sums on the host
    return sum([float(i) for i in sq_sum.values()])
class CustomUpdater(training.StandardUpdater):
    """Custom updater for chainer.

    Args:
        train_iter (iterator | dict[str, iterator]): Dataset iterator for the
            training dataset. It can also be a dictionary that maps strings to
            iterators. If this is just an iterator, then the iterator is
            registered by the name ``'main'``.
        optimizer (optimizer | dict[str, optimizer]): Optimizer to update
            parameters. It can also be a dictionary that maps strings to
            optimizers. If this is just an optimizer, then the optimizer is
            registered by the name ``'main'``.
        converter (espnet.asr.chainer_backend.asr.CustomConverter): Converter
            function to build input arrays. Each batch extracted by the main
            iterator and the ``device`` option are passed to this function.
            :func:`chainer.dataset.concat_examples` is used by default.
        device (int or dict): The destination device info to send variables. In the
            case of cpu or single gpu, `device=-1 or 0`, respectively.
            In the case of multi-gpu, `device={"main":0, "sub_1": 1, ...}`.
        accum_grad (int):The number of gradient accumulation. if set to 2, the network
            parameters will be updated once in twice,
            i.e. actual batchsize will be doubled.
    """

    def __init__(self, train_iter, optimizer, converter, device, accum_grad=1):
        """Initialize Custom Updater."""
        super(CustomUpdater, self).__init__(
            train_iter, optimizer, converter=converter, device=device
        )
        self.accum_grad = accum_grad
        # number of forward/backward passes since the last optimizer step
        self.forward_count = 0
        # True until the first update_core call; used to clear grads exactly once
        self.start = True
        self.device = device
        logging.debug("using custom converter for transformer")

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Process main update routine for Custom Updater."""
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")
        # Get batch and convert into variables
        batch = train_iter.next()
        x = self.converter(batch, self.device)
        if self.start:
            optimizer.target.cleargrads()
            self.start = False
        # Compute the loss at this time step and accumulate it;
        # dividing by accum_grad averages the accumulated gradients.
        loss = optimizer.target(*x) / self.accum_grad
        loss.backward()  # Backprop
        self.forward_count += 1
        # defer the optimizer step until accum_grad passes have been accumulated
        if self.forward_count != self.accum_grad:
            return
        self.forward_count = 0
        # compute the gradient norm to check if it is normal or not
        grad_norm = np.sqrt(
            sum_sqnorm([p.grad for p in optimizer.target.params(False)])
        )
        logging.info("grad norm={}".format(grad_norm))
        if math.isnan(grad_norm):
            logging.warning("grad norm is nan. Do not update model.")
        else:
            optimizer.update()
        optimizer.target.cleargrads()  # Clear the parameter gradients

    def update(self):
        """Update step for Custom Updater."""
        self.update_core()
        # count an iteration only when a parameter update actually happened
        if self.forward_count == 0:
            self.iteration += 1
class CustomParallelUpdater(training.updaters.MultiprocessParallelUpdater):
    """Custom Parallel Updater for chainer.

    Defines the main update routine.

    Args:
        train_iter (iterator | dict[str, iterator]): Dataset iterator for the
            training dataset. It can also be a dictionary that maps strings to
            iterators. If this is just an iterator, then the iterator is
            registered by the name ``'main'``.
        optimizer (optimizer | dict[str, optimizer]): Optimizer to update
            parameters. It can also be a dictionary that maps strings to
            optimizers. If this is just an optimizer, then the optimizer is
            registered by the name ``'main'``.
        converter (espnet.asr.chainer_backend.asr.CustomConverter): Converter
            function to build input arrays. Each batch extracted by the main
            iterator and the ``device`` option are passed to this function.
            :func:`chainer.dataset.concat_examples` is used by default.
        device (torch.device): Device to which the training data is sent. Negative value
            indicates the host memory (CPU).
        accum_grad (int):The number of gradient accumulation. if set to 2, the network
            parameters will be updated once in twice,
            i.e. actual batchsize will be doubled.
    """

    def __init__(self, train_iters, optimizer, converter, devices, accum_grad=1):
        """Initialize custom parallel updater."""
        # imported lazily so CPU-only environments can import this module
        from cupy.cuda import nccl

        super(CustomParallelUpdater, self).__init__(
            train_iters, optimizer, converter=converter, devices=devices
        )
        self.accum_grad = accum_grad
        # number of forward/backward passes since the last optimizer step
        self.forward_count = 0
        # keep a handle to the nccl module for the reduce/bcast calls below
        self.nccl = nccl
        logging.debug("using custom parallel updater for transformer")

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Process main update routine for Custom Parallel Updater."""
        self.setup_workers()
        self._send_message(("update", None))
        with cuda.Device(self._devices[0]):
            # For reducing memory
            optimizer = self.get_optimizer("main")
            batch = self.get_iterator("main").next()
            x = self.converter(batch, self._devices[0])
            loss = self._master(*x) / self.accum_grad
            loss.backward()
            # NCCL: reduce grads
            null_stream = cuda.Stream.null
            if self.comm is not None:
                # sum all workers' gradients onto the master device (root 0)
                gg = gather_grads(self._master)
                self.comm.reduce(
                    gg.data.ptr,
                    gg.data.ptr,
                    gg.size,
                    self.nccl.NCCL_FLOAT,
                    self.nccl.NCCL_SUM,
                    0,
                    null_stream.ptr,
                )
                scatter_grads(self._master, gg)
                del gg
            # update parameters
            self.forward_count += 1
            # defer the optimizer step until accum_grad passes have accumulated
            if self.forward_count != self.accum_grad:
                return
            self.forward_count = 0
            # check gradient value
            grad_norm = np.sqrt(
                sum_sqnorm([p.grad for p in optimizer.target.params(False)])
            )
            logging.info("grad norm={}".format(grad_norm))
            # update
            if math.isnan(grad_norm):
                logging.warning("grad norm is nan. Do not update model.")
            else:
                optimizer.update()
            self._master.cleargrads()
            if self.comm is not None:
                # broadcast the updated master parameters back to all workers
                gp = gather_params(self._master)
                self.comm.bcast(
                    gp.data.ptr, gp.size, self.nccl.NCCL_FLOAT, 0, null_stream.ptr
                )

    def update(self):
        """Update step for Custom Parallel Updater."""
        self.update_core()
        # count an iteration only when a parameter update actually happened
        if self.forward_count == 0:
            self.iteration += 1
class VaswaniRule(extension.Extension):
    """Trainer extension to shift an optimizer attribute magically by Vaswani.

    Implements the "Noam" learning-rate schedule of Vaswani et al.,
    "Attention Is All You Need":
    ``value = scale * d**-0.5 * min(t**-0.5, t * warmup_steps**-1.5)``.

    Args:
        attr (str): Name of the attribute to shift.
        rate (float): Rate of the exponential shift. This value is multiplied
            to the attribute at each call.
        init (float): Initial value of the attribute. If it is ``None``, the
            extension extracts the attribute at the first call and uses it as
            the initial value.
        target (float): Target value of the attribute. If the attribute reaches
            this value, the shift stops.
        optimizer (~chainer.Optimizer): Target optimizer to adjust the
            attribute. If it is ``None``, the main optimizer of the updater is
            used.

    NOTE(review): the documented ``rate`` parameter does not exist in the
    actual signature (``d``/``warmup_steps``/``scale``), and ``target`` is
    stored but never consulted in ``__call__`` -- the documented "shift
    stops" behavior is not implemented here; confirm against callers.
    """

    def __init__(
        self,
        attr,
        d,
        warmup_steps=4000,
        init=None,
        target=None,
        optimizer=None,
        scale=1.0,
    ):
        """Initialize Vaswani rule extension."""
        self._attr = attr
        # d**-0.5 (model-dimension factor), pre-multiplied by the user scale
        self._d_inv05 = d ** (-0.5) * scale
        # warmup_steps**-1.5, the slope of the linear warmup phase
        self._warmup_steps_inv15 = warmup_steps ** (-1.5)
        self._init = init
        self._target = target
        self._optimizer = optimizer
        # step counter, incremented on every trainer call
        self._t = 0
        self._last_value = None

    def initialize(self, trainer):
        """Initialize Optimizer values."""
        optimizer = self._get_optimizer(trainer)
        # ensure that _init is set
        if self._init is None:
            # schedule value at t=1: d**-0.5 * (1 * warmup_steps**-1.5)
            self._init = self._d_inv05 * (1.0 * self._warmup_steps_inv15)
        if self._last_value is not None:  # resuming from a snapshot
            self._update_value(optimizer, self._last_value)
        else:
            self._update_value(optimizer, self._init)

    def __call__(self, trainer):
        """Forward extension."""
        self._t += 1
        optimizer = self._get_optimizer(trainer)
        # min() switches from linear warmup to t**-0.5 decay at warmup_steps
        value = self._d_inv05 * min(
            self._t ** (-0.5), self._t * self._warmup_steps_inv15
        )
        self._update_value(optimizer, value)

    def serialize(self, serializer):
        """Serialize extension."""
        self._t = serializer("_t", self._t)
        self._last_value = serializer("_last_value", self._last_value)

    def _get_optimizer(self, trainer):
        """Obtain optimizer from trainer."""
        return self._optimizer or trainer.updater.get_optimizer("main")

    def _update_value(self, optimizer, value):
        """Update requested variable values."""
        setattr(optimizer, self._attr, value)
        # remembered for snapshot resume (see initialize/serialize)
        self._last_value = value
class CustomConverter(object):
    """Custom Converter.

    Pads a mini-batch on CPU and records the original (pre-padding)
    length of every input sequence.
    """

    def __init__(self):
        """Initialize subsampling."""
        pass

    def __call__(self, batch, device):
        """Perform subsampling.

        Args:
            batch (list): Batch that will be subsampled.
            device (chainer.backend.Device): CPU or GPU device.

        Returns:
            chainer.Variable: xp.array that are padded and subsampled from batch.
            xp.array: xp.array of the length of the mini-batches.
            chainer.Variable: xp.array that are padded and subsampled from batch.
        """
        # For transformer, data is processed in CPU.
        # batch should be located in list
        assert len(batch) == 1
        xs, ys = batch[0]
        # get batch of lengths of input sequences.
        # BUGFIX: this must happen BEFORE padding -- iterating the padded
        # ndarray would report the max length for every utterance, breaking
        # length-based masking downstream.
        ilens = np.array([x.shape[0] for x in xs], dtype=np.int32)
        xs = F.pad_sequence(xs, padding=-1).data
        return xs, ilens, ys
#!/usr/bin/env python3
# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import numpy as np
import torch
class CTCPrefixScoreTH(object):
    """Batch processing of CTCPrefixScore

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the label probablities for multiple
    hypotheses simultaneously
    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
    """

    def __init__(self, x, xlens, blank, eos, margin=0):
        """Construct CTC prefix scorer

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        :param torch.Tensor xlens: input lengths (B,)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param int margin: margin parameter for windowing (0 means no windowing)
        """
        # In the comment lines,
        # we assume T: input_length, B: batch size, W: beam width, O: output dim.
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.batch = x.size(0)
        self.input_length = x.size(1)
        self.odim = x.size(2)
        self.dtype = x.dtype
        self.device = (
            torch.device("cuda:%d" % x.get_device())
            if x.is_cuda
            else torch.device("cpu")
        )
        # Pad the rest of posteriors in the batch
        # TODO(takaaki-hori): need a better way without for-loops
        for i, l in enumerate(xlens):
            if l < self.input_length:
                # frames past each true length: blank with probability 1
                x[i, l:, :] = self.logzero
                x[i, l:, blank] = 0
        # Reshape input x
        xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
        # stack index 0: label posteriors, index 1: blank posterior per label slot
        self.x = torch.stack([xn, xb])  # (2, T, B, O)
        self.end_frames = torch.as_tensor(xlens) - 1
        # Setup CTC windowing
        self.margin = margin
        if margin > 0:
            self.frame_ids = torch.arange(
                self.input_length, dtype=self.dtype, device=self.device
            )
        # Base indices for index conversion
        self.idx_bh = None
        self.idx_b = torch.arange(self.batch, device=self.device)
        self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)

    def __call__(self, y, state, scoring_ids=None, att_w=None):
        """Compute CTC prefix scores for next labels

        :param list y: prefix label sequences
        :param tuple state: previous CTC state
        :param torch.Tensor pre_scores: scores for pre-selection of hypotheses (BW, O)
        :param torch.Tensor att_w: attention weights to decide CTC window
        :return new_state, ctc_local_scores (BW, O)
        """
        output_length = len(y[0]) - 1  # ignore sos
        last_ids = [yi[-1] for yi in y]  # last output label ids
        n_bh = len(last_ids)  # batch * hyps
        n_hyps = n_bh // self.batch  # assuming each utterance has the same # of hyps
        self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
        # prepare state info
        if state is None:
            r_prev = torch.full(
                (self.input_length, 2, self.batch, n_hyps),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            # blank path of the empty prefix: cumulative sum of blank posteriors
            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, n_bh)
            s_prev = 0.0
            f_min_prev = 0
            f_max_prev = 1
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state
        # select input dimensions for scoring
        if self.scoring_num > 0:
            # maps full-vocabulary ids -> positions in the pruned scoring set
            # (-1 marks ids that were pruned out)
            scoring_idmap = torch.full(
                (n_bh, self.odim), -1, dtype=torch.long, device=self.device
            )
            snum = self.scoring_num
            if self.idx_bh is None or n_bh > len(self.idx_bh):
                self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
                snum, device=self.device
            )
            scoring_idx = (
                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
            ).view(-1)
            x_ = torch.index_select(
                self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
            ).view(2, -1, n_bh, snum)
        else:
            scoring_ids = None
            scoring_idmap = None
            snum = self.odim
            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)
        # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
        # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
        r = torch.full(
            (self.input_length, 2, n_bh, snum),
            self.logzero,
            dtype=self.dtype,
            device=self.device,
        )
        if output_length == 0:
            r[0, 0] = x_[0, 0]
        # log(r_t^n(g) + r_t^b(g)) of the current prefix g
        r_sum = torch.logsumexp(r_prev, 1)
        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
        if scoring_ids is not None:
            for idx in range(n_bh):
                pos = scoring_idmap[idx, last_ids[idx]]
                if pos >= 0:
                    # repeated last label: only blank-ending paths may precede it
                    log_phi[:, idx, pos] = r_prev[:, 1, idx]
        else:
            for idx in range(n_bh):
                log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]
        # decide start and end frames based on attention weights
        if att_w is not None and self.margin > 0:
            f_arg = torch.matmul(att_w, self.frame_ids)
            f_min = max(int(f_arg.min().cpu()), f_min_prev)
            f_max = max(int(f_arg.max().cpu()), f_max_prev)
            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
            end = min(f_max + self.margin, self.input_length)
        else:
            f_min = f_max = 0
            start = max(output_length, 1)
            end = self.input_length
        # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
        for t in range(start, end):
            rp = r[t - 1]
            # pairwise logsumexp over {r^n, phi} (non-blank) and {r^n, r^b} (blank)
            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
                2, 2, n_bh, snum
            )
            r[t] = torch.logsumexp(rr, 1) + x_[:, t]
        # compute log prefix probabilities log(psi)
        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
        if scoring_ids is not None:
            # scatter pruned scores back into the full vocabulary dimension
            log_psi = torch.full(
                (n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
            )
            log_psi_ = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )
            for si in range(n_bh):
                log_psi[si, scoring_ids[si]] = log_psi_[si]
        else:
            log_psi = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )
        for si in range(n_bh):
            # an eos hypothesis scores the whole prefix at its last frame
            log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]
        # exclude blank probs
        log_psi[:, self.blank] = self.logzero
        return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)

    def index_select_state(self, state, best_ids):
        """Select CTC states according to best ids

        :param state : CTC state
        :param best_ids : index numbers selected by beam pruning (B, W)
        :return selected_state
        """
        r, s, f_min, f_max, scoring_idmap = state
        # convert ids to BHO space
        n_bh = len(s)
        n_hyps = n_bh // self.batch
        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
        # select hypothesis scores
        s_new = torch.index_select(s.view(-1), 0, vidx)
        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
        # convert ids to BHS space (S: scoring_num)
        if scoring_idmap is not None:
            snum = self.scoring_num
            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
                -1
            )
            label_ids = torch.fmod(best_ids, self.odim).view(-1)
            score_idx = scoring_idmap[hyp_idx, label_ids]
            # pruned-out labels fall back to scoring position 0
            score_idx[score_idx == -1] = 0
            vidx = score_idx + hyp_idx * snum
        else:
            snum = self.odim
        # select forward probabilities
        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
            -1, 2, n_bh
        )
        return r_new, s_new, f_min, f_max

    def extend_prob(self, x):
        """Extend CTC prob.

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        """
        if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
            # Pad the rest of posteriors in the batch
            # TODO(takaaki-hori): need a better way without for-loops
            # NOTE(review): xlens is taken as [x.size(1)], i.e. a single
            # length -- this path appears to assume batch size 1; confirm.
            xlens = [x.size(1)]
            for i, l in enumerate(xlens):
                if l < self.input_length:
                    x[i, l:, :] = self.logzero
                    x[i, l:, self.blank] = 0
            tmp_x = self.x
            xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
            self.x = torch.stack([xn, xb])  # (2, T, B, O)
            # keep the already-scored prefix frames unchanged
            self.x[:, : tmp_x.shape[1], :, :] = tmp_x
            self.input_length = x.size(1)
            self.end_frames = torch.as_tensor(xlens) - 1

    def extend_state(self, state):
        """Compute CTC prefix state.

        :param state : CTC state
        :return ctc_state
        """
        if state is None:
            # nothing to do
            return state
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state
            r_prev_new = torch.full(
                (self.input_length, 2),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            # copy over the already-computed frames
            start = max(r_prev.shape[0], 1)
            r_prev_new[0:start] = r_prev
            for t in range(start, self.input_length):
                # extend the blank-path probability over the new frames
                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]
            return (r_prev_new, s_prev, f_min_prev, f_max_prev)
class CTCPrefixScore(object):
    """Compute CTC label sequence scores.

    Implements Algorithm 2 of WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    extended to score several candidate next-labels at once.
    """

    def __init__(self, x, blank, eos, xp):
        """Construct the scorer.

        :param x: label posterior sequence (input_length, odim)
        :param int blank: blank symbol id
        :param int eos: end-of-sequence symbol id
        :param xp: array module (numpy or cupy)
        """
        self.xp = xp
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.input_length = len(x)
        self.x = x

    def initial_state(self):
        """Obtain an initial CTC state

        :return: CTC state
        """
        # initial CTC state is made of a frame x 2 tensor that corresponds to
        # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
        # superscripts n and b (non-blank and blank), respectively.
        state = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
        # only the all-blank path is possible before any label is emitted
        state[0, 1] = self.x[0, self.blank]
        for t in range(1, self.input_length):
            state[t, 1] = state[t - 1, 1] + self.x[t, self.blank]
        return state

    def __call__(self, y, cs, r_prev):
        """Compute CTC prefix scores for next labels

        :param y : prefix label sequence
        :param cs : array of next labels
        :param r_prev: previous CTC state
        :return ctc_scores, ctc_states
        """
        output_length = len(y) - 1  # ignore sos
        n_labels = len(cs)
        # new CTC states: a frame x (n or b) x n_labels tensor holding
        # log r_t^n(h) and log r_t^b(h) for every candidate label
        r = self.xp.ndarray((self.input_length, 2, n_labels), dtype=np.float32)
        xs = self.x[:, cs]
        if output_length == 0:
            r[0, 0] = xs[0]
            r[0, 1] = self.logzero
        else:
            r[output_length - 1] = self.logzero
        # log(r_t^n(g) + r_t^b(g)) of the prefix g
        r_sum = self.xp.logaddexp(r_prev[:, 0], r_prev[:, 1])
        last = y[-1]
        if output_length > 0 and last in cs:
            # a candidate repeating the last label may only follow
            # blank-ending paths of the prefix
            log_phi = self.xp.ndarray((self.input_length, n_labels), dtype=np.float32)
            for i in range(n_labels):
                log_phi[:, i] = r_prev[:, 1] if cs[i] == last else r_sum
        else:
            log_phi = r_sum
        # forward recursion for log(r_t^n(h)), log(r_t^b(h)) and the
        # running log prefix probability log(psi)
        start = max(output_length, 1)
        log_psi = r[start - 1, 0]
        for t in range(start, self.input_length):
            r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
            r[t, 1] = (
                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
            )
            log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
        # an eos candidate scores the prefix itself: log(r_T^n(g) + r_T^b(g))
        eos_pos = self.xp.where(cs == self.eos)[0]
        if len(eos_pos) > 0:
            log_psi[eos_pos] = r_sum[-1]
        # blank must never be proposed as an output label
        blank_pos = self.xp.where(cs == self.blank)[0]
        if len(blank_pos) > 0:
            log_psi[blank_pos] = self.logzero
        # move the label axis first so per-label states slice easily
        return log_psi, self.xp.rollaxis(r, 2)
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Common functions for ASR."""
import json
import logging
import sys
from itertools import groupby
import numpy as np
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection.

    described in Eq. (50) of S. Watanabe et al
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"

    :param ended_hyps: list of dicts with ``"score"`` and ``"yseq"`` keys
    :param int i: current output position
    :param int M: number of recent lengths to inspect
    :param float D_end: score-gap threshold (defaults to -10)
    :return: True when the last M hypothesis lengths all score far below
        the overall best ended hypothesis
    """
    if not ended_hyps:
        return False
    # best score over all ended hypotheses
    best_score = max(h["score"] for h in ended_hyps)
    count = 0
    for m in range(M):
        # ended hypotheses whose length is exactly i - m
        target_len = i - m
        same_len = [h for h in ended_hyps if len(h["yseq"]) == target_len]
        if same_len:
            gap = max(h["score"] for h in same_len) - best_score
            if gap < D_end:
                count += 1
    return count == M
# TODO(takaaki-hori): add different smoothing methods
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
    """Obtain label distribution for loss smoothing.

    :param int odim: output (vocabulary) dimension; index odim-1 is <eos>
    :param str lsm_type: label smoothing type; only "unigram" is supported
    :param str transcript: path to a json transcript file with an "utts" entry
    :param int blank: blank label id whose count is zeroed out
    :return: normalized label distribution of shape (odim,)
    """
    if transcript is not None:
        with open(transcript, "rb") as f:
            trans_json = json.load(f)["utts"]

    if lsm_type == "unigram":
        assert transcript is not None, (
            "transcript is required for %s label smoothing" % lsm_type
        )
        labelcount = np.zeros(odim)
        for k, v in trans_json.items():
            ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
            # to avoid an error when there is no text in an utterance
            if len(ids) > 0:
                labelcount[ids] += 1
        # count <eos>: one per utterance.
        # BUGFIX: was len(transcript), which is the character count of the
        # file *path*, not the number of utterances.
        labelcount[odim - 1] = len(trans_json)
        labelcount[labelcount == 0] = 1  # flooring
        labelcount[blank] = 0  # remove counts for blank
        labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
    else:
        logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
        sys.exit()

    return labeldist
def get_vgg2l_odim(idim, in_channel=3, out_channel=128):
    """Return the output size of the VGG frontend.

    :param int idim: input feature dimension
    :param int in_channel: input channel size
    :param int out_channel: output channel size
    :return: output size
    :rtype: int
    """
    freq = idim / in_channel
    # two max-pooling layers, each halving the frequency axis in ceil mode
    for _ in range(2):
        freq = np.ceil(np.array(freq, dtype=np.float32) / 2)
    # flattened size: remaining frequency bins times the number of channels
    return int(freq) * out_channel
class ErrorCalculator(object):
    """Calculate CER and WER for E2E_ASR and CTC models during training.

    :param list char_list: vocabulary list
    :param str sym_space: space symbol
    :param str sym_blank: blank symbol
    :param bool report_cer: compute CER when called
    :param bool report_wer: compute WER when called
    """

    def __init__(
        self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
    ):
        """Construct an ErrorCalculator object."""
        super(ErrorCalculator, self).__init__()
        self.report_cer = report_cer
        self.report_wer = report_wer
        self.char_list = char_list
        self.space = sym_space
        self.blank = sym_blank
        # NOTE (Shih-Lun): else case is for OpenAI Whisper ASR model,
        # which doesn't use <blank> token
        self.idx_blank = (
            self.char_list.index(self.blank) if self.blank in self.char_list else None
        )
        self.idx_space = (
            self.char_list.index(self.space) if self.space in self.char_list else None
        )

    def __call__(self, ys_hat, ys_pad, is_ctc=False):
        """Calculate sentence-level WER/CER score.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :param bool is_ctc: calculate CER score for CTC
        :return: sentence-level (CER, WER) pair; either may be None
        :rtype: tuple
        """
        if is_ctc:
            return self.calculate_cer_ctc(ys_hat, ys_pad)
        cer, wer = None, None
        if self.report_cer or self.report_wer:
            seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
            if self.report_cer:
                cer = self.calculate_cer(seqs_hat, seqs_true)
            if self.report_wer:
                wer = self.calculate_wer(seqs_hat, seqs_true)
        return cer, wer

    def calculate_cer_ctc(self, ys_hat, ys_pad):
        """Calculate sentence-level CER score for CTC.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: average sentence-level CER score
        :rtype: float
        """
        import editdistance

        # padding, blank and space ids are excluded on both sides
        skip = (-1, self.idx_blank, self.idx_space)
        cers, char_ref_lens = [], []
        for i, y in enumerate(ys_hat):
            # collapse CTC repeats before mapping to characters
            collapsed = [grp[0] for grp in groupby(y)]
            seq_hat = [
                self.char_list[int(idx)] for idx in collapsed if int(idx) not in skip
            ]
            seq_true = [
                self.char_list[int(idx)] for idx in ys_pad[i] if int(idx) not in skip
            ]
            hyp_chars = "".join(seq_hat)
            ref_chars = "".join(seq_true)
            if len(ref_chars) > 0:
                cers.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))
        return float(sum(cers)) / sum(char_ref_lens) if cers else None

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index sequences to character strings.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: (predicted texts, reference texts)
        :rtype: tuple of lists
        """
        seqs_hat, seqs_true = [], []
        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]
            pad_pos = np.where(y_true == -1)[0]
            ymax = pad_pos[0] if len(pad_pos) > 0 else len(y_true)
            # NOTE: padding index (-1) in y_true is used to clip y_hat
            hat_text = "".join(self.char_list[int(idx)] for idx in y_hat[:ymax])
            true_text = "".join(
                self.char_list[int(idx)] for idx in y_true if int(idx) != -1
            )
            hat_text = hat_text.replace(self.space, " ").replace(self.blank, "")
            true_text = true_text.replace(self.space, " ")
            seqs_hat.append(hat_text)
            seqs_true.append(true_text)
        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level CER score
        :rtype: float
        """
        import editdistance

        dists, ref_lens = [], []
        for i, hyp in enumerate(seqs_hat):
            # CER ignores spaces on both sides
            hyp_chars = hyp.replace(" ", "")
            ref_chars = seqs_true[i].replace(" ", "")
            dists.append(editdistance.eval(hyp_chars, ref_chars))
            ref_lens.append(len(ref_chars))
        return float(sum(dists)) / sum(ref_lens)

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level WER score
        :rtype: float
        """
        import editdistance

        dists, ref_lens = [], []
        for i, hyp in enumerate(seqs_hat):
            hyp_words = hyp.split()
            ref_words = seqs_true[i].split()
            dists.append(editdistance.eval(hyp_words, ref_words))
            ref_lens.append(len(ref_words))
        return float(sum(dists)) / sum(ref_lens)
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Common functions for ST and MT."""
import nltk
import numpy as np
class ErrorCalculator(object):
    """Calculate BLEU for ST and MT models during training.

    :param list char_list: vocabulary list
    :param str sym_space: space symbol
    :param str sym_pad: pad symbol
    :param bool report_bleu: report BLEU score if True
    """

    def __init__(self, char_list, sym_space, sym_pad, report_bleu=False):
        """Construct an ErrorCalculator object."""
        super(ErrorCalculator, self).__init__()
        self.char_list = char_list
        self.space = sym_space
        self.pad = sym_pad
        self.report_bleu = report_bleu
        self.idx_space = (
            self.char_list.index(self.space) if self.space in self.char_list else None
        )

    def __call__(self, ys_hat, ys_pad):
        """Calculate corpus-level BLEU score.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: corpus-level BLEU score in a mini-batch, or None when
            reporting is disabled
        :rtype: float
        """
        if not self.report_bleu:
            return None
        return self.calculate_corpus_bleu(ys_hat, ys_pad)

    def calculate_corpus_bleu(self, ys_hat, ys_pad):
        """Calculate corpus-level BLEU score in a mini-batch.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: corpus-level BLEU score (scaled to 0-100)
        :rtype: float
        """
        seqs_hat, seqs_true = [], []
        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]
            pad_pos = np.where(y_true == -1)[0]
            ymax = pad_pos[0] if len(pad_pos) > 0 else len(y_true)
            # NOTE: padding index (-1) in y_true is used to clip y_hat
            # because y_hats is not padded with -1
            hat_text = "".join(self.char_list[int(idx)] for idx in y_hat[:ymax])
            true_text = "".join(
                self.char_list[int(idx)] for idx in y_true if int(idx) != -1
            )
            hat_text = hat_text.replace(self.space, " ").replace(self.pad, "")
            true_text = true_text.replace(self.space, " ")
            seqs_hat.append(hat_text)
            seqs_true.append(true_text)
        bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true], seqs_hat)
        return bleu * 100
"""Language model interface."""
import argparse
from espnet.nets.scorer_interface import ScorerInterface
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.fill_missing_args import fill_missing_args
class LMInterface(ScorerInterface):
    """LM Interface for ESPnet model implementation."""

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser."""
        return parser

    @classmethod
    def build(cls, n_vocab: int, **kwargs):
        """Initialize this class with python-level args.

        Args:
            n_vocab (int): The size of the vocabulary.

        Returns:
            LMInterface: A new instance of LMInterface.
        """
        # local import to avoid cyclic import in lm_train
        from espnet.bin.lm_train import get_parser

        def _wrap(parser):
            return get_parser(parser, required=False)

        # Start from caller-provided kwargs, then fill any missing options
        # with defaults from the lm_train parser and the model-specific
        # arguments.
        args = argparse.Namespace(**kwargs)
        args = fill_missing_args(args, _wrap)
        args = fill_missing_args(args, cls.add_arguments)
        return cls(n_vocab, args)

    def forward(self, x, t):
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)
        """
        raise NotImplementedError("forward method is not implemented")
# Short aliases accepted as the ``module`` argument of
# ``dynamic_import_lm``, grouped by NN backend.  Each value is a
# "module.path:ClassName" import spec.
predefined_lms = {
    "pytorch": {
        "default": "espnet.nets.pytorch_backend.lm.default:DefaultRNNLM",
        "seq_rnn": "espnet.nets.pytorch_backend.lm.seq_rnn:SequentialRNNLM",
        "transformer": "espnet.nets.pytorch_backend.lm.transformer:TransformerLM",
    },
    "chainer": {"default": "espnet.lm.chainer_backend.lm:DefaultRNNLM"},
}
def dynamic_import_lm(module, backend):
    """Import LM class dynamically.

    Args:
        module (str): module_name:class_name or alias in `predefined_lms`
        backend (str): NN backend. e.g., pytorch, chainer

    Returns:
        type: LM class
    """
    # Aliases are resolved per backend; unknown backends fall back to an
    # empty alias table so only full "module:Class" specs work.
    alias_table = predefined_lms.get(backend, dict())
    model_class = dynamic_import(module, alias_table)
    assert issubclass(
        model_class, LMInterface
    ), f"{module} does not implement LMInterface"
    return model_class
"""MT Interface module."""
import argparse
from espnet.bin.asr_train import get_parser
from espnet.utils.fill_missing_args import fill_missing_args
class MTInterface:
    """MT Interface for ESPnet model implementation."""

    @staticmethod
    def add_arguments(parser):
        """Add arguments to parser."""
        return parser

    @classmethod
    def build(cls, idim: int, odim: int, **kwargs):
        """Initialize this class with python-level args.

        Args:
            idim (int): The number of an input feature dim.
            odim (int): The number of output vocab.

        Returns:
            MTInterface: A new instance of MTInterface.
        """

        def _wrap(parser):
            return get_parser(parser, required=False)

        # Fill missing options first from the asr_train parser defaults,
        # then from the model-specific arguments.
        args = argparse.Namespace(**kwargs)
        args = fill_missing_args(args, _wrap)
        args = fill_missing_args(args, cls.add_arguments)
        return cls(idim, odim, args)

    def forward(self, xs, ilens, ys):
        """Compute loss for training.

        :param xs:
            For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim)
            For chainer, list of source sequences chainer.Variable
        :param ilens: batch of lengths of source sequences (B)
            For pytorch, torch.Tensor
            For chainer, list of int
        :param ys:
            For pytorch, batch of padded target sequences torch.Tensor (B, Lmax)
            For chainer, list of target sequences chainer.Variable
        :return: loss value
        :rtype: torch.Tensor for pytorch, chainer.Variable for chainer
        """
        raise NotImplementedError("forward method is not implemented")

    def translate(self, x, trans_args, char_list=None, rnnlm=None):
        """Translate x for evaluation.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("translate method is not implemented")

    def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        :param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("Batch decoding is not supported yet.")

    def calculate_all_attentions(self, xs, ilens, ys):
        """Calculate attention.

        :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
        :param ndarray ilens: batch of lengths of input sequences (B)
        :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
        :return: attention weights (B, Lmax, Tmax)
        :rtype: float ndarray
        """
        raise NotImplementedError("calculate_all_attentions method is not implemented")

    @property
    def attention_plot_class(self):
        """Get attention plot class."""
        from espnet.asr.asr_utils import PlotAttentionReport

        return PlotAttentionReport
# Copyright 2020 Hirofumi Inaguma
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Conformer common arguments."""
import logging
from distutils.util import strtobool
def add_arguments_conformer_common(group):
    """Add conformer common arguments.

    Defects fixed: the docstring previously said "Transformer" although
    these flags are conformer-specific, and the ``--zero-triu`` help text
    misspelled "upper".

    Args:
        group (argparse._ArgumentGroup): argument group to extend.

    Returns:
        argparse._ArgumentGroup: the same group, for chaining.
    """
    group.add_argument(
        "--transformer-encoder-pos-enc-layer-type",
        type=str,
        default="abs_pos",
        choices=["abs_pos", "scaled_abs_pos", "rel_pos"],
        help="Transformer encoder positional encoding layer type",
    )
    group.add_argument(
        "--transformer-encoder-activation-type",
        type=str,
        default="swish",
        choices=["relu", "hardtanh", "selu", "swish"],
        help="Transformer encoder activation function type",
    )
    group.add_argument(
        "--macaron-style",
        default=False,
        type=strtobool,
        help="Whether to use macaron style for positionwise layer",
    )
    # Attention
    group.add_argument(
        "--zero-triu",
        default=False,
        type=strtobool,
        help="If true, zero the upper triangular part of attention matrix.",
    )
    # Relative positional encoding
    group.add_argument(
        "--rel-pos-type",
        type=str,
        default="legacy",
        choices=["legacy", "latest"],
        help="Whether to use the latest relative positional encoding or the legacy one."
        "The legacy relative positional encoding will be deprecated in the future."
        "More Details can be found in https://github.com/espnet/espnet/pull/2816.",
    )
    # CNN module
    group.add_argument(
        "--use-cnn-module",
        default=False,
        type=strtobool,
        help="Use convolution module or not",
    )
    group.add_argument(
        "--cnn-module-kernel",
        default=31,
        type=int,
        help="Kernel size of convolution module.",
    )
    return group
def verify_rel_pos_type(args):
    """Verify the relative positional encoding type for compatibility.

    Args:
        args (Namespace): original arguments

    Returns:
        args (Namespace): modified arguments
    """
    rel_pos_type = getattr(args, "rel_pos_type", None)
    # Anything other than the legacy (or unset) type needs no rewriting.
    if rel_pos_type is not None and rel_pos_type != "legacy":
        return args
    # Map the "rel_*" names onto their legacy implementations for
    # backward compatibility with configs predating the latest encoding.
    if args.transformer_encoder_pos_enc_layer_type == "rel_pos":
        args.transformer_encoder_pos_enc_layer_type = "legacy_rel_pos"
        logging.warning(
            "Using legacy_rel_pos and it will be deprecated in the future."
        )
    if args.transformer_encoder_selfattn_layer_type == "rel_selfattn":
        args.transformer_encoder_selfattn_layer_type = "legacy_rel_selfattn"
        logging.warning(
            "Using legacy_rel_selfattn and it will be deprecated in the future."
        )
    return args
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 21 16:57:31 2021.
@author: Keqi Deng (UCAS)
"""
import torch
from torch import nn
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
class ContextualBlockEncoderLayer(nn.Module):
    """Contextual Block Encoder layer module.

    Processes the input as a batch of fixed-size blocks; each block carries
    an extra context frame that is propagated between consecutive blocks
    (and between layers via the ``past_ctx``/``next_ctx`` tensors).

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvlutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        total_layer_num (int): Total number of layers
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        total_layer_num,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an EncoderLayer object."""
        super(ContextualBlockEncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm1 = LayerNorm(size)  # for the self-attention module
        self.norm2 = LayerNorm(size)  # for the feed-forward module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            # Macaron style: each of the two feed-forward modules
            # contributes half of the residual.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        self.total_layer_num = total_layer_num
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)

    def forward(
        self,
        x,
        mask,
        infer_mode=False,
        past_ctx=None,
        next_ctx=None,
        is_short_segment=False,
        layer_idx=0,
        cache=None,
    ):
        """Calculate forward propagation.

        Dispatches to ``forward_train`` during training (or whenever
        ``infer_mode`` is False) and to ``forward_infer`` otherwise.
        """
        if self.training or not infer_mode:
            return self.forward_train(x, mask, past_ctx, next_ctx, layer_idx, cache)
        else:
            return self.forward_infer(
                x, mask, past_ctx, next_ctx, is_short_segment, layer_idx, cache
            )

    def forward_train(
        self, x, mask, past_ctx=None, next_ctx=None, layer_idx=0, cache=None
    ):
        """Compute encoded features (training / teacher-forced path).

        Args:
            x (torch.Tensor): Input tensor (#batch, nblock, block_size + 2, size).
            mask (torch.Tensor): Mask tensor for the input.
            past_ctx (torch.Tensor): Previous contextual vector
                (#batch, nblock, total_layer_num, size), or None.
            next_ctx (torch.Tensor): Next contextual vector, or None
                (allocated here on the first layer when past_ctx is given).
            layer_idx (int): Index of this layer within the stack.
            cache (torch.Tensor): Cache tensor of the input
                (#batch, time - 1, size), or None.

        Returns:
            tuple: 7-tuple mirroring ``forward_infer``'s outputs:
                (x, mask, False, next_ctx, next_ctx, False, layer_idx + 1).
                Positions 2 and 5 are constants on this path; the context
                tensor is returned twice (as current and next context).
        """
        nbatch = x.size(0)
        nblock = x.size(1)

        if past_ctx is not None:
            if next_ctx is None:
                # store all context vectors in one tensor
                next_ctx = past_ctx.new_zeros(
                    nbatch, nblock, self.total_layer_num, x.size(-1)
                )
            else:
                # Overwrite each block's first frame with the context vector
                # produced by this layer in the previous pass.
                x[:, :, 0] = past_ctx[:, :, layer_idx]

        # reshape ( nbatch, nblock, block_size + 2, dim )
        # -> ( nbatch * nblock, block_size + 2, dim )
        x = x.view(-1, x.size(-2), x.size(-1))
        if mask is not None:
            mask = mask.view(-1, mask.size(-2), mask.size(-1))

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention (with optional single-step cache)
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if cache is None:
            x_q = x
        else:
            # With a cache only the newest frame is used as the query.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]
        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
        if not self.normalize_before:
            x = self.norm1(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed-forward module
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        layer_idx += 1
        # reshape ( nbatch * nblock, block_size + 2, dim )
        # -> ( nbatch, nblock, block_size + 2, dim )
        x = x.view(nbatch, -1, x.size(-2), x.size(-1)).squeeze(1)
        if mask is not None:
            mask = mask.view(nbatch, -1, mask.size(-2), mask.size(-1)).squeeze(1)

        if next_ctx is not None and layer_idx < self.total_layer_num:
            # Record this layer's context: block 0 keeps its own last frame,
            # every later block receives the previous block's last frame.
            next_ctx[:, 0, layer_idx, :] = x[:, 0, -1, :]
            next_ctx[:, 1:, layer_idx, :] = x[:, 0:-1, -1, :]

        return x, mask, False, next_ctx, next_ctx, False, layer_idx

    def forward_infer(
        self,
        x,
        mask,
        past_ctx=None,
        next_ctx=None,
        is_short_segment=False,
        layer_idx=0,
        cache=None,
    ):
        """Compute encoded features (streaming inference path).

        Args:
            x (torch.Tensor): Input tensor (#batch, nblock, block_size + 2, size).
            mask (torch.Tensor): Mask tensor for the input.
            past_ctx (torch.Tensor): Context vectors from the previous
                segment (#batch, total_layer_num, size), or None for the
                first segment of an utterance.
            next_ctx (torch.Tensor): Context vectors for the next segment;
                must be None when ``layer_idx == 0`` (allocated here).
            is_short_segment (bool): If True, skip context propagation
                entirely and return next_ctx as None.
            layer_idx (int): Index of this layer within the stack.
            cache (torch.Tensor): Cache tensor of the input
                (#batch, time - 1, size), or None.

        Returns:
            tuple: (x, mask, True, past_ctx, next_ctx, is_short_segment,
            layer_idx + 1).
        """
        nbatch = x.size(0)
        nblock = x.size(1)

        # if layer_idx == 0, next_ctx has to be None
        if layer_idx == 0:
            assert next_ctx is None
            # One context slot per layer for the next segment.
            next_ctx = x.new_zeros(nbatch, self.total_layer_num, x.size(-1))

        # reshape ( nbatch, nblock, block_size + 2, dim )
        # -> ( nbatch * nblock, block_size + 2, dim )
        x = x.view(-1, x.size(-2), x.size(-1))
        if mask is not None:
            mask = mask.view(-1, mask.size(-2), mask.size(-1))

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention (with optional single-step cache)
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if cache is None:
            x_q = x
        else:
            # With a cache only the newest frame is used as the query.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]
        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
        if not self.normalize_before:
            x = self.norm1(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed-forward module
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        # reshape ( nbatch * nblock, block_size + 2, dim )
        # -> ( nbatch, nblock, block_size + 2, dim )
        x = x.view(nbatch, nblock, x.size(-2), x.size(-1))
        if mask is not None:
            mask = mask.view(nbatch, nblock, mask.size(-2), mask.size(-1))

        # Propagete context information (the last frame of each block)
        # to the first frame
        # of the next block
        if not is_short_segment:
            if past_ctx is None:
                # First block of an utterance
                x[:, 0, 0, :] = x[:, 0, -1, :]
            else:
                # Seed the first block with the context from the previous
                # segment at this layer.
                x[:, 0, 0, :] = past_ctx[:, layer_idx, :]
            if nblock > 1:
                x[:, 1:, 0, :] = x[:, 0:-1, -1, :]
            # Save the very last frame as context for the next segment.
            next_ctx[:, layer_idx, :] = x[:, -1, -1, :]
        else:
            next_ctx = None

        return x, mask, True, past_ctx, next_ctx, is_short_segment, layer_idx + 1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""ConvolutionModule definition."""
from torch import nn
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    Pointwise conv (GLU) -> depthwise conv -> batch norm + activation ->
    pointwise conv, all along the time axis.

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers (must be odd).
    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # An odd kernel size is required so that 'SAME' padding keeps the
        # sequence length unchanged.
        assert (kernel_size - 1) % 2 == 0
        same_pad = (kernel_size - 1) // 2

        self.pointwise_conv1 = nn.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=same_pad,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # Conv1d operates on (#batch, channels, time), so transpose in and out.
        x = x.transpose(1, 2)
        # GLU halves the doubled channel dimension back to `channels`.
        x = nn.functional.glu(self.pointwise_conv1(x), dim=1)
        x = self.activation(self.norm(self.depthwise_conv(x)))
        x = self.pointwise_conv2(x)
        return x.transpose(1, 2)
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder definition."""
import logging
import torch
from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule
from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer
from espnet.nets.pytorch_backend.nets_utils import get_activation
from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L
from espnet.nets.pytorch_backend.transformer.attention import (
LegacyRelPositionMultiHeadedAttention,
MultiHeadedAttention,
RelPositionMultiHeadedAttention,
)
from espnet.nets.pytorch_backend.transformer.embedding import (
LegacyRelPositionalEncoding,
PositionalEncoding,
RelPositionalEncoding,
ScaledPositionalEncoding,
)
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import (
Conv1dLinear,
MultiLayeredConv1d,
)
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
PositionwiseFeedForward,
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
class Encoder(torch.nn.Module):
    """Conformer encoder module.

    Args:
        idim (int): Input dimension.
        attention_dim (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        pos_enc_layer_type (str): Encoder positional encoding layer type.
        selfattention_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
        intermediate_layers (Union[List[int], None]): indices of intermediate CTC layer.
            indices start from 1.
            if not None, intermediate outputs are returned (which changes return type
            signature.)
        ctc_softmax: callable applied to intermediate outputs for the
            self-conditioning branch; enables conditioning when not None.
        conditioning_layer_dim (int): input dim of the conditioning linear
            layer (used only when ctc_softmax is given).
    """

    def __init__(
        self,
        idim,
        attention_dim=256,
        attention_heads=4,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        attention_dropout_rate=0.0,
        input_layer="conv2d",
        normalize_before=True,
        concat_after=False,
        positionwise_layer_type="linear",
        positionwise_conv_kernel_size=1,
        macaron_style=False,
        pos_enc_layer_type="abs_pos",
        selfattention_layer_type="selfattn",
        activation_type="swish",
        use_cnn_module=False,
        zero_triu=False,
        cnn_module_kernel=31,
        padding_idx=-1,
        stochastic_depth_rate=0.0,
        intermediate_layers=None,
        ctc_softmax=None,
        conditioning_layer_dim=None,
    ):
        """Construct an Encoder object."""
        super(Encoder, self).__init__()
        activation = get_activation(activation_type)

        # Positional encoding selection; relative variants must be paired
        # with the matching self-attention type (asserted here and below).
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            pos_enc_class = LegacyRelPositionalEncoding
            assert selfattention_layer_type == "legacy_rel_selfattn"
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        # Input layer: conv2d/vgg2l subsample time by 4; others keep it.
        self.conv_subsampling_factor = 1
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(idim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
            self.conv_subsampling_factor = 4
        elif input_layer == "vgg2l":
            self.embed = VGG2L(idim, attention_dim)
            self.conv_subsampling_factor = 4
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

        # self-attention module definition
        if selfattention_layer_type == "selfattn":
            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "rel_selfattn":
            logging.info("encoder self-attention layer type = relative self-attention")
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

        # feed-forward module definition
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)

        # Stochastic depth probability grows linearly with layer depth.
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / num_blocks,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

        self.intermediate_layers = intermediate_layers
        # Self-conditioned CTC: feed intermediate CTC posteriors back into
        # the encoder stream through a linear projection.
        self.use_conditioning = True if ctc_softmax is not None else False
        if self.use_conditioning:
            self.ctc_softmax = ctc_softmax
            self.conditioning_layer = torch.nn.Linear(
                conditioning_layer_dim, attention_dim
            )

    def forward(self, xs, masks):
        """Encode input sequence.

        Args:
            xs (torch.Tensor): Input tensor (#batch, time, idim).
            masks (torch.Tensor): Mask tensor (#batch, 1, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, attention_dim).
            torch.Tensor: Mask tensor (#batch, 1, time).
            When ``intermediate_layers`` is set, a third element
            (list of intermediate outputs) is also returned.
        """
        # Subsampling embeds consume (and update) the mask; the others don't.
        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)

        if self.intermediate_layers is None:
            xs, masks = self.encoders(xs, masks)
        else:
            # Run layer by layer so intermediate outputs can be captured.
            intermediate_outputs = []
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs, masks = encoder_layer(xs, masks)

                if (
                    self.intermediate_layers is not None
                    and layer_idx + 1 in self.intermediate_layers
                ):
                    # intermediate branches also require normalization.
                    encoder_output = xs
                    # With relative pos-enc, xs is a (tensor, pos_emb) tuple.
                    if isinstance(encoder_output, tuple):
                        encoder_output = encoder_output[0]

                    if self.normalize_before:
                        encoder_output = self.after_norm(encoder_output)

                    intermediate_outputs.append(encoder_output)

                    if self.use_conditioning:
                        intermediate_result = self.ctc_softmax(encoder_output)

                        if isinstance(xs, tuple):
                            x, pos_emb = xs[0], xs[1]
                            x = x + self.conditioning_layer(intermediate_result)
                            xs = (x, pos_emb)
                        else:
                            xs = xs + self.conditioning_layer(intermediate_result)

        if isinstance(xs, tuple):
            xs = xs[0]

        if self.normalize_before:
            xs = self.after_norm(xs)

        if self.intermediate_layers is not None:
            return xs, masks, intermediate_outputs
        return xs, masks
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder self-attention layer definition."""
import torch
from torch import nn
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvlutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            # Macaron style halves the residual contribution of each FFN.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, 1, time).
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            # Layer is skipped: pass the input through unchanged (but still
            # honour the cache concatenation contract).
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
                self.feed_forward_macaron(x)
            )
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            # With a cache, only the newest frame acts as the query.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        # Relative-position attention takes the positional embedding too.
        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
            self.feed_forward(x)
        )
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask
        return x, mask
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Swish() activation function for Conformer."""
import torch
class Swish(torch.nn.Module):
    """Swish activation, ``f(x) = x * sigmoid(x)``, used in the Conformer."""
    def forward(self, x):
        """Apply the Swish activation element-wise.

        :param torch.Tensor x: input tensor
        :return: tensor of the same shape with Swish applied
        :rtype: torch.Tensor
        """
        return x.sigmoid() * x
import logging
import numpy as np
import torch
import torch.nn.functional as F
from packaging.version import parse as V
from espnet.nets.pytorch_backend.nets_utils import to_device
class CTC(torch.nn.Module):
    """CTC module

    Wraps one of three CTC loss backends behind a single interface:
    PyTorch's builtin ``torch.nn.CTCLoss`` (with or without ``zero_infinity``)
    or the GTN-based loss.

    :param int odim: dimension of outputs
    :param int eprojs: number of encoder projection units
    :param float dropout_rate: dropout rate (0.0 ~ 1.0)
    :param str ctc_type: "builtin", "cudnnctc", or "gtnctc"
    :param bool reduce: reduce the CTC loss into a scalar
    """
    def __init__(self, odim, eprojs, dropout_rate, ctc_type="builtin", reduce=True):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.loss = None
        # Projects encoder hidden states (eprojs) to vocabulary logits (odim).
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.probs = None  # for visualization
        # In case of Pytorch >= 1.7.0, CTC will be always builtin
        self.ctc_type = ctc_type if V(torch.__version__) < V("1.7.0") else "builtin"
        if ctc_type != self.ctc_type:
            logging.warning(f"CTC was set to {self.ctc_type} due to PyTorch version.")
        if self.ctc_type == "builtin":
            reduction_type = "sum" if reduce else "none"
            # zero_infinity avoids inf losses when target is longer than input.
            self.ctc_loss = torch.nn.CTCLoss(
                reduction=reduction_type, zero_infinity=True
            )
        elif self.ctc_type == "cudnnctc":
            reduction_type = "sum" if reduce else "none"
            self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
        elif self.ctc_type == "gtnctc":
            from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction
            self.ctc_loss = GTNCTCLossFunction.apply
        else:
            raise ValueError(
                'ctc_type must be "builtin" or "gtnctc": {}'.format(self.ctc_type)
            )
        # Token id used to mark padding positions in the target sequences.
        self.ignore_id = -1
        self.reduce = reduce
    def loss_fn(self, th_pred, th_target, th_ilen, th_olen):
        """Compute the CTC loss with the configured backend.

        :param torch.Tensor th_pred: logits; for the torch backends the
            expected layout is (Tmax, B, odim) — see the transpose in forward()
        :param th_target: flattened target ids (torch backends) or list of
            id lists (gtnctc)
        :param th_ilen: input lengths
        :param th_olen: output (target) lengths
        :return: loss value
        """
        if self.ctc_type in ["builtin", "cudnnctc"]:
            th_pred = th_pred.log_softmax(2)
            # Use the deterministic CuDNN implementation of CTC loss to avoid
            # [issue#17798](https://github.com/pytorch/pytorch/issues/17798)
            with torch.backends.cudnn.flags(deterministic=True):
                loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
                # Batch-size average
                loss = loss / th_pred.size(1)
            return loss
        elif self.ctc_type == "gtnctc":
            targets = [t.tolist() for t in th_target]
            log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
            return self.ctc_loss(log_probs, targets, th_ilen, 0, "none")
        else:
            raise NotImplementedError
    def forward(self, hs_pad, hlens, ys_pad):
        """CTC forward

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        """
        # TODO(kan-bayashi): need to make more smart way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        # zero padding for hs
        ys_hat = self.ctc_lo(self.dropout(hs_pad))
        if self.ctc_type != "gtnctc":
            # torch.nn.CTCLoss expects (Tmax, B, odim).
            ys_hat = ys_hat.transpose(0, 1)
        if self.ctc_type == "builtin":
            olens = to_device(ys_hat, torch.LongTensor([len(s) for s in ys]))
            hlens = hlens.long()
            ys_pad = torch.cat(ys)  # without this the code breaks for asr_mix
            self.loss = self.loss_fn(ys_hat, ys_pad, hlens, olens)
        else:
            self.loss = None
            hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32))
            olens = torch.from_numpy(
                np.fromiter((x.size(0) for x in ys), dtype=np.int32)
            )
            # zero padding for ys
            ys_true = torch.cat(ys).cpu().int()  # batch x olen
            # get ctc loss
            # expected shape of seqLength x batchSize x alphabet_size
            dtype = ys_hat.dtype
            if self.ctc_type == "cudnnctc":
                # use GPU when using the cuDNN implementation
                ys_true = to_device(hs_pad, ys_true)
            if self.ctc_type == "gtnctc":
                # keep as list for gtn
                ys_true = ys
            self.loss = to_device(
                hs_pad, self.loss_fn(ys_hat, ys_true, hlens, olens)
            ).to(dtype=dtype)
        # get length info
        logging.info(
            self.__class__.__name__
            + " input lengths: "
            + "".join(str(hlens).split("\n"))
        )
        logging.info(
            self.__class__.__name__
            + " output lengths: "
            + "".join(str(olens).split("\n"))
        )
        if self.reduce:
            # NOTE: sum() is needed to keep loss as scalar for the builtin
            # "none" reduction path and for data-parallel gathering.
            self.loss = self.loss.sum()
            logging.info("ctc loss:" + str(float(self.loss)))
        return self.loss
    def softmax(self, hs_pad):
        """softmax of frame activations

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: log softmax applied 3d tensor (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        # Cached in self.probs for later visualization.
        self.probs = F.softmax(self.ctc_lo(hs_pad), dim=2)
        return self.probs
    def log_softmax(self, hs_pad):
        """log_softmax of frame activations

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: log softmax applied 3d tensor (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
    def argmax(self, hs_pad):
        """argmax of frame activations

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: argmax applied 2d tensor (B, Tmax)
        :rtype: torch.Tensor
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
    def forced_align(self, h, y, blank_id=0):
        """forced alignment.

        Viterbi alignment over the blank-interleaved label sequence.

        :param torch.Tensor h: hidden state sequence, 2d tensor (T, D)
        :param torch.Tensor y: id sequence tensor 1d tensor (L)
        :param int blank_id: blank symbol index
        :return: best alignment results
        :rtype: list
        """
        def interpolate_blank(label, blank_id=0):
            """Insert blank token between every two label token."""
            label = np.expand_dims(label, 1)
            blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
            label = np.concatenate([blanks, label], axis=1)
            label = label.reshape(-1)
            # label[0] is a blank, so this appends a trailing blank as well.
            label = np.append(label, label[0])
            return label
        lpz = self.log_softmax(h)
        lpz = lpz.squeeze(0)
        y_int = interpolate_blank(y, blank_id)
        logdelta = np.zeros((lpz.size(0), len(y_int))) - 100000000000.0  # log of zero
        state_path = (
            np.zeros((lpz.size(0), len(y_int)), dtype=np.int16) - 1
        )  # state path
        # Initialization: first frame may start in blank or the first label.
        logdelta[0, 0] = lpz[0][y_int[0]]
        logdelta[0, 1] = lpz[0][y_int[1]]
        for t in range(1, lpz.size(0)):
            for s in range(len(y_int)):
                # Standard CTC transition rule: a skip over the previous blank
                # (s-2) is forbidden at blanks and between repeated labels.
                if y_int[s] == blank_id or s < 2 or y_int[s] == y_int[s - 2]:
                    candidates = np.array([logdelta[t - 1, s], logdelta[t - 1, s - 1]])
                    prev_state = [s, s - 1]
                else:
                    candidates = np.array(
                        [
                            logdelta[t - 1, s],
                            logdelta[t - 1, s - 1],
                            logdelta[t - 1, s - 2],
                        ]
                    )
                    prev_state = [s, s - 1, s - 2]
                logdelta[t, s] = np.max(candidates) + lpz[t][y_int[s]]
                state_path[t, s] = prev_state[np.argmax(candidates)]
        # Backtrack from the better of the last label / trailing blank state.
        state_seq = -1 * np.ones((lpz.size(0), 1), dtype=np.int16)
        candidates = np.array(
            [logdelta[-1, len(y_int) - 1], logdelta[-1, len(y_int) - 2]]
        )
        prev_state = [len(y_int) - 1, len(y_int) - 2]
        state_seq[-1] = prev_state[np.argmax(candidates)]
        for t in range(lpz.size(0) - 2, -1, -1):
            state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
        output_state_seq = []
        for t in range(0, lpz.size(0)):
            output_state_seq.append(y_int[state_seq[t, 0]])
        return output_state_seq
def ctc_for(args, odim, reduce=True):
    """Return the CTC module(s) for the given args and output dimension.

    :param Namespace args: the program args
    :param int odim: the output dimension
    :param bool reduce: return the CTC loss in a scalar
    :return: a single CTC module when ``args.num_encs`` is 1, otherwise a
        ``torch.nn.ModuleList`` of CTC modules (one shared module when
        ``args.share_ctc`` is set, else one per encoder)
    :raises ValueError: if ``args.num_encs`` is less than 1
    """
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    if num_encs == 1:
        # compatible with single encoder asr mode
        return CTC(
            odim, args.eprojs, args.dropout_rate, ctc_type=args.ctc_type, reduce=reduce
        )
    elif num_encs > 1:
        # NOTE: multi-encoder mode expects args.dropout_rate to be a list,
        # one rate per encoder.
        ctcs_list = torch.nn.ModuleList()
        if args.share_ctc:
            # use dropout_rate of the first encoder
            ctc = CTC(
                odim,
                args.eprojs,
                args.dropout_rate[0],
                ctc_type=args.ctc_type,
                reduce=reduce,
            )
            ctcs_list.append(ctc)
        else:
            for idx in range(num_encs):
                ctc = CTC(
                    odim,
                    args.eprojs,
                    args.dropout_rate[idx],
                    ctc_type=args.ctc_type,
                    reduce=reduce,
                )
                ctcs_list.append(ctc)
        return ctcs_list
    else:
        # Only reachable for num_encs < 1; the original message wrongly said
        # "more than one" although a single encoder is valid.
        raise ValueError(
            "Number of encoders needs to be one or more. {}".format(num_encs)
        )
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""RNN sequence-to-sequence speech recognition model (pytorch)."""
import argparse
import logging
import math
import os
from itertools import groupby
import chainer
import numpy as np
import torch
from chainer import reporter
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.e2e_asr_common import label_smoothing_dist
from espnet.nets.pytorch_backend.ctc import ctc_for
from espnet.nets.pytorch_backend.frontends.feature_transform import ( # noqa: H301
feature_transform_for,
)
from espnet.nets.pytorch_backend.frontends.frontend import frontend_for
from espnet.nets.pytorch_backend.initialization import (
lecun_normal_init_parameters,
set_forget_bias_to_one,
)
from espnet.nets.pytorch_backend.nets_utils import (
get_subsample,
pad_list,
to_device,
to_torch_tensor,
)
from espnet.nets.pytorch_backend.rnn.argument import ( # noqa: H301
add_arguments_rnn_attention_common,
add_arguments_rnn_decoder_common,
add_arguments_rnn_encoder_common,
)
from espnet.nets.pytorch_backend.rnn.attentions import att_for
from espnet.nets.pytorch_backend.rnn.decoders import decoder_for
from espnet.nets.pytorch_backend.rnn.encoders import encoder_for
from espnet.nets.scorers.ctc import CTCPrefixScorer
from espnet.utils.fill_missing_args import fill_missing_args
CTC_LOSS_THRESHOLD = 10000
class Reporter(chainer.Chain):
    """A chainer reporter wrapper."""
    def report(self, loss_ctc, loss_att, acc, cer_ctc, cer, wer, mtl_loss):
        """Report every training metric of the current step."""
        # Fan each named metric out to the chainer reporter in a fixed order.
        metrics = (
            ("loss_ctc", loss_ctc),
            ("loss_att", loss_att),
            ("acc", acc),
            ("cer_ctc", cer_ctc),
            ("cer", cer),
            ("wer", wer),
        )
        for name, value in metrics:
            reporter.report({name: value}, self)
        logging.info("mtl loss:" + str(mtl_loss))
        reporter.report({"loss": mtl_loss}, self)
class E2E(ASRInterface, torch.nn.Module):
    """E2E module.

    Hybrid CTC/attention RNN sequence-to-sequence ASR model. The multi-task
    weight ``mtlalpha`` interpolates between pure attention (0) and pure
    CTC (1) training.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        E2E.encoder_add_arguments(parser)
        E2E.attention_add_arguments(parser)
        E2E.decoder_add_arguments(parser)
        return parser
    @staticmethod
    def encoder_add_arguments(parser):
        """Add arguments for the encoder."""
        group = parser.add_argument_group("E2E encoder setting")
        group = add_arguments_rnn_encoder_common(group)
        return parser
    @staticmethod
    def attention_add_arguments(parser):
        """Add arguments for the attention."""
        group = parser.add_argument_group("E2E attention setting")
        group = add_arguments_rnn_attention_common(group)
        return parser
    @staticmethod
    def decoder_add_arguments(parser):
        """Add arguments for the decoder."""
        group = parser.add_argument_group("E2E decoder setting")
        group = add_arguments_rnn_decoder_common(group)
        return parser
    def get_total_subsampling_factor(self):
        """Get total subsampling factor."""
        if isinstance(self.enc, torch.nn.ModuleList):
            # Multi-encoder case: use the first encoder's conv factor.
            return self.enc[0].conv_subsampling_factor * int(np.prod(self.subsample))
        else:
            return self.enc.conv_subsampling_factor * int(np.prod(self.subsample))
    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()
        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1
        # subsample info
        self.subsample = get_subsample(args, mode="asr", arch="rnn")
        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json
            )
        else:
            labeldist = None
        if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
            self.frontend = frontend_for(args, idim)
            # (idim - 1) * 2: presumably the STFT-feature dimension fed to the
            # transform — TODO(review) confirm against frontend_for.
            self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
            idim = args.n_mels
        else:
            self.frontend = None
        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim)
        # attention
        self.att = att_for(args)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)
        # weight initialization
        self.init_like_chainer()
        # options for beam search
        if args.report_cer or args.report_wer:
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
            }
            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None
        # Surrogate for log(0) used during decoding.
        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None
    def init_like_chainer(self):
        """Initialize weight like chainer.

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """
        lecun_normal_init_parameters(self)
        # exceptions
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
        # forget-bias = 1.0
        # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
        for i in range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[i].bias_ih)
    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        import editdistance
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens
        # 1. Encoder
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        # 2. CTC loss
        if self.mtlalpha == 0:
            self.loss_ctc = None
        else:
            self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad)
        # 3. attention loss
        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            self.loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
        self.acc = acc
        # 4. compute cer without beam search
        if self.mtlalpha == 0 or self.char_list is None:
            cer_ctc = None
        else:
            cers = []
            # Greedy CTC decoding: frame argmax then collapse repeats.
            y_hats = self.ctc.argmax(hs_pad).data
            for i, y in enumerate(y_hats):
                y_hat = [x[0] for x in groupby(y)]
                y_true = ys_pad[i]
                seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(self.space, " ")
                seq_hat_text = seq_hat_text.replace(self.blank, "")
                seq_true_text = "".join(seq_true).replace(self.space, " ")
                hyp_chars = seq_hat_text.replace(" ", "")
                ref_chars = seq_true_text.replace(" ", "")
                if len(ref_chars) > 0:
                    cers.append(
                        editdistance.eval(hyp_chars, ref_chars) / len(ref_chars)
                    )
            cer_ctc = sum(cers) / len(cers) if cers else None
        # 5. compute cer/wer
        if self.training or not (self.report_cer or self.report_wer):
            cer, wer = 0.0, 0.0
            # oracle_cer, oracle_wer = 0.0, 0.0
        else:
            if self.recog_args.ctc_weight > 0.0:
                lpz = self.ctc.log_softmax(hs_pad).data
            else:
                lpz = None
            word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
            nbest_hyps = self.dec.recognize_beam_batch(
                hs_pad,
                torch.tensor(hlens),
                lpz,
                self.recog_args,
                self.char_list,
                self.rnnlm,
            )
            # remove <sos> and <eos>
            y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps]
            for i, y_hat in enumerate(y_hats):
                y_true = ys_pad[i]
                seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ")
                seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
                seq_true_text = "".join(seq_true).replace(self.recog_args.space, " ")
                hyp_words = seq_hat_text.split()
                ref_words = seq_true_text.split()
                word_eds.append(editdistance.eval(hyp_words, ref_words))
                word_ref_lens.append(len(ref_words))
                hyp_chars = seq_hat_text.replace(" ", "")
                ref_chars = seq_true_text.replace(" ", "")
                char_eds.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))
            wer = (
                0.0
                if not self.report_wer
                else float(sum(word_eds)) / sum(word_ref_lens)
            )
            cer = (
                0.0
                if not self.report_cer
                else float(sum(char_eds)) / sum(char_ref_lens)
            )
        # Interpolate CTC and attention losses with weight alpha.
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = self.loss_att
            loss_att_data = float(self.loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = self.loss_ctc
            loss_att_data = None
            loss_ctc_data = float(self.loss_ctc)
        else:
            self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_att
            loss_att_data = float(self.loss_att)
            loss_ctc_data = float(self.loss_ctc)
        loss_data = float(self.loss)
        # Skip reporting when the loss diverged or became NaN.
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss
    def scorers(self):
        """Scorers."""
        return dict(decoder=self.dec, ctc=CTCPrefixScorer(self.ctc, self.eos))
    def encode(self, x):
        """Encode acoustic features.

        :param ndarray x: input acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        ilens = [x.shape[0]]
        # subsample frame
        x = x[:: self.subsample[0], :]
        p = next(self.parameters())
        h = torch.as_tensor(x, device=p.device, dtype=p.dtype)
        # make a utt list (1) to use the same interface for encoder
        hs = h.contiguous().unsqueeze(0)
        # 0. Frontend
        if self.frontend is not None:
            enhanced, hlens, mask = self.frontend(hs, ilens)
            hs, hlens = self.feature_transform(enhanced, hlens)
        else:
            hs, hlens = hs, ilens
        # 1. encoder
        hs, _, _ = self.enc(hs, hlens)
        return hs.squeeze(0)
    def recognize(self, x, recog_args, char_list, rnnlm=None):
        """E2E beam search.

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        hs = self.encode(x).unsqueeze(0)
        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(hs)[0]
        else:
            lpz = None
        # 2. Decoder
        # decode the first utterance
        y = self.dec.recognize_beam(hs[0], lpz, recog_args, char_list, rnnlm)
        return y
    def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
        """E2E batch beam search.

        :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
        # subsample frame
        xs = [xx[:: self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)
        # 0. Frontend
        if self.frontend is not None:
            enhanced, hlens, mask = self.frontend(xs_pad, ilens)
            hs_pad, hlens = self.feature_transform(enhanced, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens
        # 1. Encoder
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(hs_pad)
            normalize_score = False
        else:
            lpz = None
            normalize_score = True
        # 2. Decoder
        hlens = torch.tensor(list(map(int, hlens)))  # make sure hlens is tensor
        y = self.dec.recognize_beam_batch(
            hs_pad,
            hlens,
            lpz,
            recog_args,
            char_list,
            rnnlm,
            normalize_score=normalize_score,
        )
        # Restore the training mode that was active on entry.
        if prev:
            self.train()
        return y
    def enhance(self, xs):
        """Forward only in the frontend stage.

        :param ndarray xs: input acoustic feature (T, C, F)
        :return: enhanced feature
        :rtype: torch.Tensor
        """
        if self.frontend is None:
            raise RuntimeError("Frontend does't exist")
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)
        # subsample frame
        xs = [xx[:: self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)
        enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
        if prev:
            self.train()
        return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        self.eval()
        with torch.no_grad():
            # 0. Frontend
            if self.frontend is not None:
                hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
            else:
                hs_pad, hlens = xs_pad, ilens
            # 1. Encoder
            hpad, hlens, _ = self.enc(hs_pad, hlens)
            # 2. Decoder
            att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad)
        self.train()
        return att_ws
    def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad):
        """E2E CTC probability calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: CTC probability (B, Tmax, vocab)
        :rtype: float ndarray
        """
        probs = None
        # No CTC branch when mtlalpha == 0, so nothing to compute.
        if self.mtlalpha == 0:
            return probs
        self.eval()
        with torch.no_grad():
            # 0. Frontend
            if self.frontend is not None:
                hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
            else:
                hs_pad, hlens = xs_pad, ilens
            # 1. Encoder
            hpad, hlens, _ = self.enc(hs_pad, hlens)
            # 2. CTC probs
            probs = self.ctc.softmax(hpad).cpu().numpy()
        self.train()
        return probs
    def subsample_frames(self, x):
        """Subsample speech frames in the encoder."""
        # subsample frame
        x = x[:: self.subsample[0], :]
        ilen = [x.shape[0]]
        h = to_device(self, torch.from_numpy(np.array(x, dtype=np.float32)))
        h.contiguous()
        return h, ilen
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
Conformer speech recognition model (pytorch).
It is a fusion of `e2e_asr_transformer.py`
Refer to: https://arxiv.org/abs/2005.08100
"""
from espnet.nets.pytorch_backend.conformer.argument import ( # noqa: H301
add_arguments_conformer_common,
verify_rel_pos_type,
)
from espnet.nets.pytorch_backend.conformer.encoder import Encoder
from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E as E2ETransformer
class E2E(E2ETransformer):
    """E2E module.

    Conformer ASR model: reuses the Transformer E2E training/decoding logic
    and replaces the encoder with a Conformer encoder.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        E2ETransformer.add_arguments(parser)
        E2E.add_conformer_arguments(parser)
        return parser
    @staticmethod
    def add_conformer_arguments(parser):
        """Add arguments for conformer model."""
        group = parser.add_argument_group("conformer model specific setting")
        group = add_arguments_conformer_common(group)
        return parser
    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        :param int ignore_id: padding/ignore token id for targets
        """
        super().__init__(idim, odim, args, ignore_id)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate
        # Check the relative positional encoding type
        args = verify_rel_pos_type(args)
        # NOTE(review): self.intermediate_ctc_layers, self.ctc and
        # self.reset_parameters are presumably provided by
        # E2ETransformer.__init__ — they are not assigned in this class
        # before use; confirm against the parent implementation.
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            pos_enc_layer_type=args.transformer_encoder_pos_enc_layer_type,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            activation_type=args.transformer_encoder_activation_type,
            macaron_style=args.macaron_style,
            use_cnn_module=args.use_cnn_module,
            zero_triu=args.zero_triu,
            cnn_module_kernel=args.cnn_module_kernel,
            stochastic_depth_rate=args.stochastic_depth_rate,
            intermediate_layers=self.intermediate_ctc_layers,
            # Self-conditioned CTC: feed CTC posteriors back into the encoder.
            ctc_softmax=self.ctc.softmax if args.self_conditioning else None,
            conditioning_layer_dim=odim,
        )
        # Re-initialize parameters after swapping in the Conformer encoder.
        self.reset_parameters(args)
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Waseda University (Yosuke Higuchi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
Mask CTC based non-autoregressive speech recognition model (pytorch).
See https://arxiv.org/abs/2005.08700 for the detail.
"""
import logging
import math
from distutils.util import strtobool
from itertools import groupby
import numpy
import torch
from espnet.nets.pytorch_backend.conformer.argument import ( # noqa: H301
add_arguments_conformer_common,
)
from espnet.nets.pytorch_backend.conformer.encoder import Encoder
from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD
from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E as E2ETransformer
from espnet.nets.pytorch_backend.maskctc.add_mask_token import mask_uniform
from espnet.nets.pytorch_backend.maskctc.mask import square_mask
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask, th_accuracy
class E2E(E2ETransformer):
"""E2E module.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
@staticmethod
def add_arguments(parser):
"""Add arguments."""
E2ETransformer.add_arguments(parser)
E2E.add_maskctc_arguments(parser)
return parser
@staticmethod
def add_maskctc_arguments(parser):
"""Add arguments for maskctc model."""
group = parser.add_argument_group("maskctc specific setting")
group.add_argument(
"--maskctc-use-conformer-encoder",
default=False,
type=strtobool,
)
group = add_arguments_conformer_common(group)
return parser
def __init__(self, idim, odim, args, ignore_id=-1):
"""Construct an E2E object.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
odim += 1 # for the mask token
super().__init__(idim, odim, args, ignore_id)
assert 0.0 <= self.mtlalpha < 1.0, "mtlalpha should be [0.0, 1.0)"
self.mask_token = odim - 1
self.sos = odim - 2
self.eos = odim - 2
self.odim = odim
self.intermediate_ctc_weight = args.intermediate_ctc_weight
self.intermediate_ctc_layers = None
if args.intermediate_ctc_layer != "":
self.intermediate_ctc_layers = [
int(i) for i in args.intermediate_ctc_layer.split(",")
]
if args.maskctc_use_conformer_encoder:
if args.transformer_attn_dropout_rate is None:
args.transformer_attn_dropout_rate = args.conformer_dropout_rate
self.encoder = Encoder(
idim=idim,
attention_dim=args.adim,
attention_heads=args.aheads,
linear_units=args.eunits,
num_blocks=args.elayers,
input_layer=args.transformer_input_layer,
dropout_rate=args.dropout_rate,
positional_dropout_rate=args.dropout_rate,
attention_dropout_rate=args.transformer_attn_dropout_rate,
pos_enc_layer_type=args.transformer_encoder_pos_enc_layer_type,
selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
activation_type=args.transformer_encoder_activation_type,
macaron_style=args.macaron_style,
use_cnn_module=args.use_cnn_module,
cnn_module_kernel=args.cnn_module_kernel,
stochastic_depth_rate=args.stochastic_depth_rate,
intermediate_layers=self.intermediate_ctc_layers,
)
self.reset_parameters(args)
def forward(self, xs_pad, ilens, ys_pad):
"""E2E forward.
:param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
:param torch.Tensor ilens: batch of lengths of source sequences (B)
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
:return: ctc loss value
:rtype: torch.Tensor
:return: attention loss value
:rtype: torch.Tensor
:return: accuracy in attention decoder
:rtype: float
"""
# 1. forward encoder
xs_pad = xs_pad[:, : max(ilens)] # for data parallel
src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
if self.intermediate_ctc_layers:
hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
else:
hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
self.hs_pad = hs_pad
# 2. forward decoder
ys_in_pad, ys_out_pad = mask_uniform(
ys_pad, self.mask_token, self.eos, self.ignore_id
)
ys_mask = square_mask(ys_in_pad, self.eos)
pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
self.pred_pad = pred_pad
# 3. compute attention loss
loss_att = self.criterion(pred_pad, ys_out_pad)
self.acc = th_accuracy(
pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
)
# 4. compute ctc loss
loss_ctc, cer_ctc = None, None
loss_intermediate_ctc = 0.0
if self.mtlalpha > 0:
batch_size = xs_pad.size(0)
hs_len = hs_mask.view(batch_size, -1).sum(1)
loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
if self.error_calculator is not None:
ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
# for visualization
if not self.training:
self.ctc.softmax(hs_pad)
if self.intermediate_ctc_weight > 0 and self.intermediate_ctc_layers:
for hs_intermediate in hs_intermediates:
# assuming hs_intermediates and hs_pad has same length / padding
loss_inter = self.ctc(
hs_intermediate.view(batch_size, -1, self.adim), hs_len, ys_pad
)
loss_intermediate_ctc += loss_inter
loss_intermediate_ctc /= len(self.intermediate_ctc_layers)
# 5. compute cer/wer
if self.training or self.error_calculator is None or self.decoder is None:
cer, wer = None, None
else:
ys_hat = pred_pad.argmax(dim=-1)
cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
alpha = self.mtlalpha
if alpha == 0:
self.loss = loss_att
loss_att_data = float(loss_att)
loss_ctc_data = None
else:
self.loss = (
alpha * loss_ctc
+ self.intermediate_ctc_weight * loss_intermediate_ctc
+ (1 - alpha - self.intermediate_ctc_weight) * loss_att
)
loss_att_data = float(loss_att)
loss_ctc_data = float(loss_ctc)
loss_data = float(self.loss)
if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
self.reporter.report(
loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
)
else:
logging.warning("loss (=%f) is not correct", loss_data)
return self.loss
    def recognize(self, x, recog_args, char_list=None, rnnlm=None):
        """Recognize input speech with non-autoregressive Mask-CTC decoding.

        Greedy CTC first produces an initial hypothesis; low-confidence tokens
        are replaced by the mask token and then filled in iteratively by the
        decoder, highest-scoring predictions first.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argument Namespace containing options
            (uses ``maskctc_probability_threshold`` and ``maskctc_n_iterations``)
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module (unused here)
        :return: decoding result (single 1-best hypothesis)
        :rtype: list
        """

        def num2str(char_list, mask_token, mask_char="_"):
            # Build an id-sequence -> string converter that renders masked
            # positions as ``mask_char`` for readable logging.
            def f(yl):
                cl = [char_list[y] if y != mask_token else mask_char for y in yl]
                return "".join(cl).replace("<space>", " ")

            return f

        n2s = num2str(char_list, self.mask_token)
        self.eval()
        h = self.encode(x).unsqueeze(0)  # (1, T', adim)
        input_len = h.squeeze(0)
        logging.info("input lengths: " + str(input_len.size(0)))

        # greedy ctc outputs
        ctc_probs, ctc_ids = torch.exp(self.ctc.log_softmax(h)).max(dim=-1)
        # collapse frame-level repeats; y_hat still contains blanks (id 0)
        y_hat = torch.stack([x[0] for x in groupby(ctc_ids[0])])
        # positions of non-blank tokens within y_hat
        y_idx = torch.nonzero(y_hat != 0).squeeze(-1)

        # calculate token-level ctc probabilities by taking
        # the maximum probability of consecutive frames with
        # the same ctc symbols
        probs_hat = []
        cnt = 0
        for i, y in enumerate(y_hat.tolist()):
            probs_hat.append(-1)
            while cnt < ctc_ids.shape[1] and y == ctc_ids[0][cnt]:
                if probs_hat[i] < ctc_probs[0][cnt]:
                    probs_hat[i] = ctc_probs[0][cnt].item()
                cnt += 1
        # NOTE(review): uses the bare ``numpy`` name (not ``np``) — confirm the
        # enclosing file imports ``numpy`` directly.
        probs_hat = torch.from_numpy(numpy.array(probs_hat))

        # mask ctc outputs based on ctc probabilities
        p_thres = recog_args.maskctc_probability_threshold
        mask_idx = torch.nonzero(probs_hat[y_idx] < p_thres).squeeze(-1)
        confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
        mask_num = len(mask_idx)
        # start from an all-mask sequence, then copy in the confident tokens
        y_in = torch.zeros(1, len(y_idx), dtype=torch.long) + self.mask_token
        y_in[0][confident_idx] = y_hat[y_idx][confident_idx]
        logging.info("ctc:{}".format(n2s(y_in[0].tolist())))

        # iterative decoding
        if not mask_num == 0:
            K = recog_args.maskctc_n_iterations
            # K <= 0 falls through to filling every mask in one final pass
            num_iter = K if mask_num >= K and K > 0 else mask_num

            for t in range(num_iter - 1):
                pred, _ = self.decoder(y_in, None, h, None)
                pred_score, pred_id = pred[0][mask_idx].max(dim=-1)
                # commit only the top mask_num // num_iter predictions per pass
                cand = torch.topk(pred_score, mask_num // num_iter, -1)[1]
                y_in[0][mask_idx[cand]] = pred_id[cand]
                mask_idx = torch.nonzero(y_in[0] == self.mask_token).squeeze(-1)
                logging.info("msk:{}".format(n2s(y_in[0].tolist())))

            # predict leftover masks (|masks| < mask_num // num_iter)
            pred, pred_mask = self.decoder(y_in, None, h, None)
            y_in[0][mask_idx] = pred[0][mask_idx].argmax(dim=-1)
            logging.info("msk:{}".format(n2s(y_in[0].tolist())))

        ret = y_in.tolist()[0]
        hyp = {"score": 0.0, "yseq": [self.sos] + ret + [self.eos]}

        return [hyp]
#!/usr/bin/env python3
"""
This script is used to construct End-to-End models of multi-speaker ASR.
Copyright 2017 Johns Hopkins University (Shinji Watanabe)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import argparse
import logging
import math
import os
import sys
from itertools import groupby
import numpy as np
import torch
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.e2e_asr_common import get_vgg2l_odim, label_smoothing_dist
from espnet.nets.pytorch_backend.ctc import ctc_for
from espnet.nets.pytorch_backend.e2e_asr import E2E as E2EASR
from espnet.nets.pytorch_backend.e2e_asr import Reporter
from espnet.nets.pytorch_backend.frontends.feature_transform import ( # noqa: H301
feature_transform_for,
)
from espnet.nets.pytorch_backend.frontends.frontend import frontend_for
from espnet.nets.pytorch_backend.initialization import (
lecun_normal_init_parameters,
set_forget_bias_to_one,
)
from espnet.nets.pytorch_backend.nets_utils import (
get_subsample,
make_pad_mask,
pad_list,
to_device,
to_torch_tensor,
)
from espnet.nets.pytorch_backend.rnn.attentions import att_for
from espnet.nets.pytorch_backend.rnn.decoders import decoder_for
from espnet.nets.pytorch_backend.rnn.encoders import RNNP, VGG2L
from espnet.nets.pytorch_backend.rnn.encoders import encoder_for as encoder_for_single
# Losses at or above this magnitude (or NaN) are treated as divergence and are
# not sent to the reporter (see E2E.forward); a warning is logged instead.
CTC_LOSS_THRESHOLD = 10000
class PIT(object):
    """Permutation Invariant Training (PIT) module.

    Precomputes every speaker permutation and, for each permutation, the flat
    indices of the corresponding hypothesis/reference loss pairs, so the best
    assignment can be found with a single gather + min.

    :parameter int num_spkrs: number of speakers for PIT process (2 or 3)
    """

    def __init__(self, num_spkrs):
        """Build permutation tables for ``num_spkrs`` speakers."""
        self.num_spkrs = num_spkrs

        # All orderings of [0..num_spkrs-1], enumerated by DFS (note: for 3
        # speakers this order is NOT lexicographic):
        # [[0, 1], [1, 0]] or
        # [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 1, 0], [2, 0, 1]]
        self.perm_choices = []
        self.permutationDFS(np.arange(num_spkrs, dtype=np.int64), 0)

        # Flat indices into the (num_spkrs^2,) loss vector, one row base
        # i*num_spkrs per hypothesis plus the permuted reference index:
        # [[0, 3], [1, 2]] or
        # [[0, 4, 8], [0, 5, 7], [1, 3, 8], [1, 5, 6], [2, 4, 6], [2, 3, 7]]
        row_base = (np.arange(num_spkrs, dtype=np.int64) * num_spkrs).reshape(
            1, num_spkrs
        )
        self.loss_perm_idx = (row_base + np.array(self.perm_choices)).tolist()

    def min_pit_sample(self, loss):
        """Compute the PIT loss for one sample.

        :param 1-D torch.Tensor loss: losses of all hypothesis/reference
            pairs for one sample, i.e. [h1r1, h1r2, h2r1, h2r2] or
            [h1r1, h1r2, h1r3, h2r1, h2r2, h2r3, h3r1, h3r2, h3r3]
        :return minimum loss of best permutation
        :rtype torch.Tensor (1)
        :return the best permutation
        :rtype List: len=2
        """
        averaged = (
            torch.stack([loss[idx].sum() for idx in self.loss_perm_idx])
            / self.num_spkrs
        )
        perm_loss, best = torch.min(averaged, 0)
        return perm_loss, self.perm_choices[best]

    def pit_process(self, losses):
        """Compute the PIT loss for a batch.

        :param torch.Tensor losses: losses (B, 1|4|9)
        :return minimum losses of a batch with best permutation
        :rtype torch.Tensor (B)
        :return the best permutation
        :rtype torch.LongTensor (B, 1|2|3)
        """
        per_sample = [self.min_pit_sample(losses[b]) for b in range(losses.size(0))]
        loss_perm = torch.stack([pl for pl, _ in per_sample], dim=0).to(
            losses.device
        )  # (B)
        permutation = (
            torch.tensor([pm for _, pm in per_sample]).long().to(losses.device)
        )
        return torch.mean(loss_perm), permutation

    def permutationDFS(self, source, start):
        """Collect permutations of ``source`` into ``self.perm_choices`` via DFS.

        e.g. [[1, 2], [2, 1]] or
             [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 2, 1], [3, 1, 2]]

        :param np.ndarray source: (num_spkrs, 1), e.g. [1, 2, ..., N]
        :param int start: the start point to permute
        """
        n = len(source)
        if start == n - 1:  # reach final state
            self.perm_choices.append(source.tolist())
        for i in range(start, n):
            # place element i at position ``start``, recurse, then undo swap
            source[start], source[i] = source[i], source[start]
            self.permutationDFS(source, start + 1)
            source[start], source[i] = source[i], source[start]
class E2E(ASRInterface, torch.nn.Module):
    """E2E module for multi-speaker (mixture) speech recognition.

    Combines a (possibly multi-speaker) encoder, per-speaker CTC, and an
    attention decoder; CTC losses over all hypothesis/reference pairings are
    resolved with PIT (permutation invariant training).

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        # Reuse the single-speaker E2E argument groups, plus the
        # multi-speaker encoder options below.
        E2EASR.encoder_add_arguments(parser)
        E2E.encoder_mix_add_arguments(parser)
        E2EASR.attention_add_arguments(parser)
        E2EASR.decoder_add_arguments(parser)
        return parser

    @staticmethod
    def encoder_mix_add_arguments(parser):
        """Add arguments for multi-speaker encoder."""
        group = parser.add_argument_group("E2E encoder setting for multi-speaker")
        # asr-mix encoder
        group.add_argument(
            "--spa",
            action="store_true",
            help="Enable speaker parallel attention "
            "for multi-speaker speech recognition task.",
        )
        group.add_argument(
            "--elayers-sd",
            default=4,
            type=int,
            help="Number of speaker differentiate encoder layers"
            "for multi-speaker speech recognition task.",
        )
        return parser

    def get_total_subsampling_factor(self):
        """Get total subsampling factor."""
        # NOTE(review): relies on the encoder exposing
        # ``conv_subsampling_factor``; EncoderMix in this file does not define
        # it — confirm against the encoder actually constructed.
        return self.enc.conv_subsampling_factor * int(np.prod(self.subsample))

    def __init__(self, idim, odim, args):
        """Initialize multi-speaker E2E module."""
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)
        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()
        self.num_spkrs = args.num_spkrs
        self.spa = args.spa
        self.pit = PIT(self.num_spkrs)

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        self.subsample = get_subsample(args, mode="asr", arch="rnn_mix")

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json
            )
        else:
            labeldist = None

        if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc; reduce=False keeps per-sample losses, which PIT needs to
        # compare permutations
        self.ctc = ctc_for(args, odim, reduce=False)
        # attention: one attention per speaker when speaker-parallel
        # attention (--spa) is enabled, otherwise a single shared one
        num_att = self.num_spkrs if args.spa else 1
        self.att = att_for(args, num_att)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if "report_cer" in vars(args) and (args.report_cer or args.report_wer):
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None

    def init_like_chainer(self):
        """Initialize weight like chainer.

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """
        lecun_normal_init_parameters(self)
        # exceptions
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
        # forget-bias = 1.0
        # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
        for i in range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[i].bias_ih)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, num_spkrs, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        import editdistance

        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            if isinstance(hs_pad, list):
                # frontend already separated the mixture into per-speaker streams
                hlens_n = [None] * self.num_spkrs
                for i in range(self.num_spkrs):
                    hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
                hlens = hlens_n
            else:
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        if not isinstance(
            hs_pad, list
        ):  # single-channel input xs_pad (single- or multi-speaker)
            hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        else:  # multi-channel multi-speaker input xs_pad
            for i in range(self.num_spkrs):
                hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

        # 2. CTC loss
        if self.mtlalpha == 0:
            loss_ctc, min_perm = None, None
        else:
            if not isinstance(hs_pad, list):  # single-speaker input xs_pad
                loss_ctc = torch.mean(self.ctc(hs_pad, hlens, ys_pad))
            else:  # multi-speaker input xs_pad
                ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
                # CTC losses for every hypothesis/reference pairing, flattened
                # so index i is (h_{i//num_spkrs}, r_{i%num_spkrs}); PIT then
                # selects the permutation with minimum total loss per sample.
                loss_ctc_perm = torch.stack(
                    [
                        self.ctc(
                            hs_pad[i // self.num_spkrs],
                            hlens[i // self.num_spkrs],
                            ys_pad[i % self.num_spkrs],
                        )
                        for i in range(self.num_spkrs**2)
                    ],
                    dim=1,
                )  # (B, num_spkrs^2)
                loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm)
            logging.info("ctc loss:" + str(float(loss_ctc)))

        # 3. attention loss
        if self.mtlalpha == 1:
            loss_att = None
            acc = None
        else:
            if not isinstance(hs_pad, list):  # single-speaker input xs_pad
                loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
            else:
                # reorder the references by the best CTC permutation so the
                # decoder is trained on the matching speaker assignment
                for i in range(ys_pad.size(1)):  # B
                    ys_pad[:, i] = ys_pad[min_perm[i], i]
                rslt = [
                    self.dec(hs_pad[i], hlens[i], ys_pad[i], strm_idx=i)
                    for i in range(self.num_spkrs)
                ]
                # average attention loss / accuracy over speakers
                loss_att = sum([r[0] for r in rslt]) / float(len(rslt))
                acc = sum([r[1] for r in rslt]) / float(len(rslt))
        self.acc = acc

        # 4. compute cer without beam search
        if self.mtlalpha == 0 or self.char_list is None:
            cer_ctc = None
        else:
            # NOTE(review): this branch indexes hs_pad[ns]/ys_pad[ns], i.e. it
            # assumes the multi-speaker (list) layout from step 2 — confirm it
            # is not reached with single-speaker tensors.
            cers = []
            for ns in range(self.num_spkrs):
                y_hats = self.ctc.argmax(hs_pad[ns]).data
                for i, y in enumerate(y_hats):
                    y_hat = [x[0] for x in groupby(y)]
                    y_true = ys_pad[ns][i]
                    seq_hat = [
                        self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                    ]
                    seq_true = [
                        self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                    ]
                    seq_hat_text = "".join(seq_hat).replace(self.space, " ")
                    seq_hat_text = seq_hat_text.replace(self.blank, "")
                    seq_true_text = "".join(seq_true).replace(self.space, " ")
                    hyp_chars = seq_hat_text.replace(" ", "")
                    ref_chars = seq_true_text.replace(" ", "")
                    if len(ref_chars) > 0:
                        cers.append(
                            editdistance.eval(hyp_chars, ref_chars) / len(ref_chars)
                        )
            cer_ctc = sum(cers) / len(cers) if cers else None

        # 5. compute cer/wer
        if (
            self.training
            or not (self.report_cer or self.report_wer)
            or not isinstance(hs_pad, list)
        ):
            cer, wer = 0.0, 0.0
        else:
            if self.recog_args.ctc_weight > 0.0:
                lpz = [
                    self.ctc.log_softmax(hs_pad[i]).data for i in range(self.num_spkrs)
                ]
            else:
                lpz = None

            word_eds, char_eds, word_ref_lens, char_ref_lens = [], [], [], []
            # NOTE(review): ``lpz[i]`` below raises TypeError when lpz is None
            # (ctc_weight == 0.0) — confirm that configuration never reaches
            # this branch.
            nbest_hyps = [
                self.dec.recognize_beam_batch(
                    hs_pad[i],
                    torch.tensor(hlens[i]),
                    lpz[i],
                    self.recog_args,
                    self.char_list,
                    self.rnnlm,
                    strm_idx=i,
                )
                for i in range(self.num_spkrs)
            ]
            # remove <sos> and <eos>
            y_hats = [
                [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps[i]]
                for i in range(self.num_spkrs)
            ]
            for i in range(len(y_hats[0])):
                hyp_words = []
                hyp_chars = []
                ref_words = []
                ref_chars = []
                for ns in range(self.num_spkrs):
                    y_hat = y_hats[ns][i]
                    y_true = ys_pad[ns][i]

                    seq_hat = [
                        self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                    ]
                    seq_true = [
                        self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                    ]
                    seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ")
                    seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
                    seq_true_text = "".join(seq_true).replace(
                        self.recog_args.space, " "
                    )
                    hyp_words.append(seq_hat_text.split())
                    ref_words.append(seq_true_text.split())
                    hyp_chars.append(seq_hat_text.replace(" ", ""))
                    ref_chars.append(seq_true_text.replace(" ", ""))

                # edit distances for every hypothesis/reference pairing;
                # PIT picks the permutation with the minimum distance
                tmp_word_ed = [
                    editdistance.eval(
                        hyp_words[ns // self.num_spkrs], ref_words[ns % self.num_spkrs]
                    )
                    for ns in range(self.num_spkrs**2)
                ]  # h1r1,h1r2,h2r1,h2r2
                tmp_char_ed = [
                    editdistance.eval(
                        hyp_chars[ns // self.num_spkrs], ref_chars[ns % self.num_spkrs]
                    )
                    for ns in range(self.num_spkrs**2)
                ]  # h1r1,h1r2,h2r1,h2r2

                word_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_word_ed))[0])
                word_ref_lens.append(len(sum(ref_words, [])))
                char_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_char_ed))[0])
                char_ref_lens.append(len("".join(ref_chars)))
            wer = (
                0.0
                if not self.report_wer
                else float(sum(word_eds)) / sum(word_ref_lens)
            )
            cer = (
                0.0
                if not self.report_cer
                else float(sum(char_eds)) / sum(char_ref_lens)
            )

        # combine CTC and attention losses with the multitask weight alpha
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = loss_ctc
            loss_att_data = None
            loss_ctc_data = float(loss_ctc)
        else:
            self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        # skip reporting diverged/NaN losses
        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss

    def recognize(self, x, recog_args, char_list, rnnlm=None):
        """E2E beam search.

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results (one list per speaker)
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens = [x.shape[0]]

        # subsample frame
        x = x[:: self.subsample[0], :]
        h = to_device(self, to_torch_tensor(x).float())
        # make a utt list (1) to use the same interface for encoder
        hs = h.contiguous().unsqueeze(0)

        # 0. Frontend
        if self.frontend is not None:
            hs, hlens, mask = self.frontend(hs, ilens)
            hlens_n = [None] * self.num_spkrs
            for i in range(self.num_spkrs):
                hs[i], hlens_n[i] = self.feature_transform(hs[i], hlens)
            hlens = hlens_n
        else:
            hs, hlens = hs, ilens

        # 1. Encoder
        if not isinstance(hs, list):  # single-channel multi-speaker input x
            hs, hlens, _ = self.enc(hs, hlens)
        else:  # multi-channel multi-speaker input x
            for i in range(self.num_spkrs):
                hs[i], hlens[i], _ = self.enc(hs[i], hlens[i])

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = [self.ctc.log_softmax(i)[0] for i in hs]
        else:
            lpz = None

        # 2. decoder
        # decode the first utterance
        # NOTE(review): ``lpz[i]`` raises TypeError when lpz is None
        # (ctc_weight == 0.0) — confirm that path is never used here.
        y = [
            self.dec.recognize_beam(
                hs[i][0], lpz[i], recog_args, char_list, rnnlm, strm_idx=i
            )
            for i in range(self.num_spkrs)
        ]

        if prev:
            self.train()
        return y

    def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
        """E2E batch beam search.

        :param ndarray xs: list of input acoustic features, each (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results (one list per speaker)
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frame
        xs = [xx[:: self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)

        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(xs_pad, ilens)
            hlens_n = [None] * self.num_spkrs
            for i in range(self.num_spkrs):
                hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
            hlens = hlens_n
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        if not isinstance(hs_pad, list):  # single-channel multi-speaker input x
            hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        else:  # multi-channel multi-speaker input x
            for i in range(self.num_spkrs):
                hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = [self.ctc.log_softmax(hs_pad[i]) for i in range(self.num_spkrs)]
            normalize_score = False
        else:
            lpz = None
            normalize_score = True

        # 2. decoder
        # NOTE(review): ``lpz[i]`` raises TypeError when lpz is None
        # (ctc_weight == 0.0) — confirm that path is never used here.
        y = [
            self.dec.recognize_beam_batch(
                hs_pad[i],
                hlens[i],
                lpz[i],
                recog_args,
                char_list,
                rnnlm,
                normalize_score=normalize_score,
                strm_idx=i,
            )
            for i in range(self.num_spkrs)
        ]

        if prev:
            self.train()
        return y

    def enhance(self, xs):
        """Forward only the frontend stage.

        :param ndarray xs: input acoustic feature (T, C, F)
        :return: enhanced features, masks, and input lengths
        """
        if self.frontend is None:
            raise RuntimeError("Frontend doesn't exist")
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frame
        xs = [xx[:: self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)
        enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
        if prev:
            self.train()

        if isinstance(enhanced, (tuple, list)):
            # multi-speaker output: one enhanced stream and mask per speaker
            enhanced = list(enhanced)
            mask = list(mask)
            for idx in range(len(enhanced)):  # number of speakers
                enhanced[idx] = enhanced[idx].cpu().numpy()
                mask[idx] = mask[idx].cpu().numpy()
            return enhanced, mask, ilens
        return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, num_spkrs, Lmax)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        with torch.no_grad():
            # 0. Frontend
            if self.frontend is not None:
                hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
                hlens_n = [None] * self.num_spkrs
                for i in range(self.num_spkrs):
                    hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
                hlens = hlens_n
            else:
                hs_pad, hlens = xs_pad, ilens

            # 1. Encoder
            if not isinstance(hs_pad, list):  # single-channel multi-speaker input x
                hs_pad, hlens, _ = self.enc(hs_pad, hlens)
            else:  # multi-channel multi-speaker input x
                for i in range(self.num_spkrs):
                    hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

            # Permutation: align references to hypotheses with PIT over CTC
            # losses before computing decoder attentions
            ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
            if self.num_spkrs <= 3:
                loss_ctc = torch.stack(
                    [
                        self.ctc(
                            hs_pad[i // self.num_spkrs],
                            hlens[i // self.num_spkrs],
                            ys_pad[i % self.num_spkrs],
                        )
                        for i in range(self.num_spkrs**2)
                    ],
                    1,
                )  # (B, num_spkrs^2)
                loss_ctc, min_perm = self.pit.pit_process(loss_ctc)
            for i in range(ys_pad.size(1)):  # B
                ys_pad[:, i] = ys_pad[min_perm[i], i]

            # 2. Decoder
            att_ws = [
                self.dec.calculate_all_attentions(
                    hs_pad[i], hlens[i], ys_pad[i], strm_idx=i
                )
                for i in range(self.num_spkrs)
            ]

        return att_ws
class EncoderMix(torch.nn.Module):
    """Encoder module for the case of multi-speaker mixture speech.

    Structure: a shared mixture encoder (VGG), followed by one
    speaker-differentiating (SD) RNN stack per speaker, followed by a shared
    recognition RNN stack applied to each separated stream.

    :param str etype: type of encoder network
    :param int idim: number of dimensions of encoder network
    :param int elayers_sd:
        number of layers of speaker differentiate part in encoder network
    :param int elayers_rec:
        number of layers of shared recognition part in encoder network
    :param int eunits: number of lstm units of encoder network
    :param int eprojs: number of projection units of encoder network
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param int in_channel: number of input channels
    :param int num_spkrs: number of speakers
    """

    def __init__(
        self,
        etype,
        idim,
        elayers_sd,
        elayers_rec,
        eunits,
        eprojs,
        subsample,
        dropout,
        num_spkrs=2,
        in_channel=1,
    ):
        """Initialize the encoder of single-channel multi-speaker ASR."""
        super(EncoderMix, self).__init__()
        # e.g. "vggblstmp" -> "blstm"
        typ = etype.lstrip("vgg").rstrip("p")
        if typ not in ["lstm", "gru", "blstm", "bgru"]:
            logging.error("Error: need to specify an appropriate encoder architecture")

        # only the "vgg...p" (VGG + projection-RNN) family is supported;
        # any other etype logs an error and exits
        if etype.startswith("vgg"):
            if etype[-1] == "p":
                self.enc_mix = torch.nn.ModuleList([VGG2L(in_channel)])
                # one SD stack per speaker; each consumes the VGG output and
                # uses the first elayers_sd+1 subsampling factors
                self.enc_sd = torch.nn.ModuleList(
                    [
                        torch.nn.ModuleList(
                            [
                                RNNP(
                                    get_vgg2l_odim(idim, in_channel=in_channel),
                                    elayers_sd,
                                    eunits,
                                    eprojs,
                                    subsample[: elayers_sd + 1],
                                    dropout,
                                    typ=typ,
                                )
                            ]
                        )
                        for i in range(num_spkrs)
                    ]
                )
                # shared recognition stack, applied to every speaker stream;
                # uses the remaining subsampling factors
                self.enc_rec = torch.nn.ModuleList(
                    [
                        RNNP(
                            eprojs,
                            elayers_rec,
                            eunits,
                            eprojs,
                            subsample[elayers_sd:],
                            dropout,
                            typ=typ,
                        )
                    ]
                )
                logging.info("Use CNN-VGG + B" + typ.upper() + "P for encoder")
            else:
                logging.error(
                    f"Error: need to specify an appropriate encoder architecture. "
                    f"Illegal name {etype}"
                )
                sys.exit()
        else:
            logging.error(
                f"Error: need to specify an appropriate encoder architecture. "
                f"Illegal name {etype}"
            )
            sys.exit()

        self.num_spkrs = num_spkrs

    def forward(self, xs_pad, ilens):
        """Encodermix forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: list: batch of hidden state sequences [num_spkrs x (B, Tmax, eprojs)]
        :rtype: torch.Tensor
        """
        # mixture encoder (shared across speakers)
        for module in self.enc_mix:
            xs_pad, ilens, _ = module(xs_pad, ilens)

        # SD and Rec encoder
        xs_pad_sd = [xs_pad for i in range(self.num_spkrs)]
        ilens_sd = [ilens for i in range(self.num_spkrs)]
        for ns in range(self.num_spkrs):
            # Encoder_SD: speaker differentiate encoder
            for module in self.enc_sd[ns]:
                xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns])
            # Encoder_Rec: recognition encoder
            for module in self.enc_rec:
                xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns])

        # make mask to remove bias value in padded part
        # NOTE(review): the mask is built from ilens_sd[0], assuming all
        # speaker streams share the same lengths — confirm.
        mask = to_device(xs_pad, make_pad_mask(ilens_sd[0]).unsqueeze(-1))

        return [x.masked_fill(mask, 0.0) for x in xs_pad_sd], ilens_sd, None
def encoder_for(args, idim, subsample):
    """Construct the encoder.

    With a frontend, the mixture has already been separated into one stream
    per speaker, so an ordinary single-speaker encoder suffices; otherwise a
    multi-speaker EncoderMix is built.
    """
    # use getattr to keep compatibility with configs lacking ``use_frontend``
    if getattr(args, "use_frontend", False):
        return encoder_for_single(args, idim, subsample)
    return EncoderMix(
        args.etype,
        idim,
        args.elayers_sd,
        args.elayers,
        args.eunits,
        args.eprojs,
        subsample,
        args.dropout_rate,
        args.num_spkrs,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment