Commit 60a2c57a authored by sunzhq2, committed by xuxo

update conformer

parent 4a699441
import numpy as np
import torch
class SegmentStreamingE2E(object):
"""SegmentStreamingE2E constructor.
:param E2E e2e: E2E ASR object
:param recog_args: arguments for "recognize" method of E2E
"""
def __init__(self, e2e, recog_args, rnnlm=None):
self._e2e = e2e
self._recog_args = recog_args
self._char_list = e2e.char_list
self._rnnlm = rnnlm
self._e2e.eval()
self._blank_idx_in_char_list = -1
for idx in range(len(self._char_list)):
if self._char_list[idx] == self._e2e.blank:
self._blank_idx_in_char_list = idx
break
self._subsampling_factor = np.prod(e2e.subsample)
self._activates = 0
self._blank_dur = 0
self._previous_input = []
self._previous_encoder_recurrent_state = None
self._encoder_states = []
self._ctc_posteriors = []
assert (
self._recog_args.batchsize <= 1
), "SegmentStreamingE2E works only with batch size <= 1"
assert (
"b" not in self._e2e.etype
), "SegmentStreamingE2E works only with uni-directional encoders"
def accept_input(self, x):
"""Call this method each time a new batch of input is available."""
self._previous_input.extend(x)
h, ilen = self._e2e.subsample_frames(x)
# Run encoder and apply greedy search on CTC softmax output
h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
h.unsqueeze(0), ilen, self._previous_encoder_recurrent_state
)
z = self._e2e.ctc.argmax(h).squeeze(0)
if self._activates == 0 and z[0] != self._blank_idx_in_char_list:
self._activates = 1
# Rerun encoder with zero state at onset of detection
tail_len = self._subsampling_factor * (
self._recog_args.streaming_onset_margin + 1
)
h, ilen = self._e2e.subsample_frames(
np.reshape(
self._previous_input[-tail_len:], [-1, len(self._previous_input[0])]
)
)
h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
h.unsqueeze(0), ilen, None
)
hyp = None
if self._activates == 1:
self._encoder_states.extend(h.squeeze(0))
self._ctc_posteriors.extend(self._e2e.ctc.log_softmax(h).squeeze(0))
if z[0] == self._blank_idx_in_char_list:
self._blank_dur += 1
else:
self._blank_dur = 0
if self._blank_dur >= self._recog_args.streaming_min_blank_dur:
seg_len = (
len(self._encoder_states)
- self._blank_dur
+ self._recog_args.streaming_offset_margin
)
if seg_len > 0:
# Run decoder with a detected segment
h = torch.cat(self._encoder_states[:seg_len], dim=0).view(
-1, self._encoder_states[0].size(0)
)
if self._recog_args.ctc_weight > 0.0:
lpz = torch.cat(self._ctc_posteriors[:seg_len], dim=0).view(
-1, self._ctc_posteriors[0].size(0)
)
if self._recog_args.batchsize > 0:
lpz = lpz.unsqueeze(0)
normalize_score = False
else:
lpz = None
normalize_score = True
if self._recog_args.batchsize == 0:
hyp = self._e2e.dec.recognize_beam(
h, lpz, self._recog_args, self._char_list, self._rnnlm
)
else:
hlens = torch.tensor([h.shape[0]])
hyp = self._e2e.dec.recognize_beam_batch(
h.unsqueeze(0),
hlens,
lpz,
self._recog_args,
self._char_list,
self._rnnlm,
normalize_score=normalize_score,
)[0]
self._activates = 0
self._blank_dur = 0
tail_len = (
self._subsampling_factor
* self._recog_args.streaming_onset_margin
)
self._previous_input = self._previous_input[-tail_len:]
self._encoder_states = []
self._ctc_posteriors = []
return hyp
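# Hedged usage sketch (added for illustration; not part of the original commit):
# how SegmentStreamingE2E might be driven chunk by chunk. `e2e`, `recog_args`,
# and `chunks` are assumed to come from a trained ESPnet model and its
# recognition config; all names below are illustrative.
def run_segment_streaming(e2e, recog_args, chunks, rnnlm=None):
    """Feed feature chunks to SegmentStreamingE2E and collect hypotheses."""
    asr = SegmentStreamingE2E(e2e, recog_args, rnnlm=rnnlm)
    hyps = []
    for chunk in chunks:  # each chunk: numpy array (num_frames, feat_dim)
        hyp = asr.accept_input(chunk)
        if hyp is not None:  # a segment was closed and decoded
            hyps.append(hyp)
    return hyps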
import torch
# TODO(pzelasko): Currently allows half-streaming only;
# needs streaming attention decoder implementation
class WindowStreamingE2E(object):
"""WindowStreamingE2E constructor.
:param E2E e2e: E2E ASR object
:param recog_args: arguments for "recognize" method of E2E
"""
def __init__(self, e2e, recog_args, rnnlm=None):
self._e2e = e2e
self._recog_args = recog_args
self._char_list = e2e.char_list
self._rnnlm = rnnlm
self._e2e.eval()
self._offset = 0
self._previous_encoder_recurrent_state = None
self._encoder_states = []
self._ctc_posteriors = []
self._last_recognition = None
assert (
self._recog_args.ctc_weight > 0.0
), "WindowStreamingE2E works only with combined CTC and attention decoders."
def accept_input(self, x):
"""Call this method each time a new batch of input is available."""
h, ilen = self._e2e.subsample_frames(x)
# Streaming encoder
h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
h.unsqueeze(0), ilen, self._previous_encoder_recurrent_state
)
self._encoder_states.append(h.squeeze(0))
# CTC posteriors for the incoming audio
self._ctc_posteriors.append(self._e2e.ctc.log_softmax(h).squeeze(0))
def _input_window_for_decoder(self, use_all=False):
if use_all:
return (
torch.cat(self._encoder_states, dim=0),
torch.cat(self._ctc_posteriors, dim=0),
)
def select_unprocessed_windows(window_tensors):
last_offset = self._offset
offset_traversed = 0
selected_windows = []
for es in window_tensors:
if offset_traversed > last_offset:
selected_windows.append(es)
continue
offset_traversed += es.size(1)
return torch.cat(selected_windows, dim=0)
return (
select_unprocessed_windows(self._encoder_states),
select_unprocessed_windows(self._ctc_posteriors),
)
def decode_with_attention_offline(self):
"""Run the attention decoder offline.
Works even if the previous layers (encoder and CTC decoder) were
being run in the online mode.
This method should be run after all the audio has been consumed.
This is used mostly to compare the results between offline
and online implementation of the previous layers.
"""
h, lpz = self._input_window_for_decoder(use_all=True)
return self._e2e.dec.recognize_beam(
h, lpz, self._recog_args, self._char_list, self._rnnlm
)
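# Hedged usage sketch (added for illustration): WindowStreamingE2E runs the
# encoder and CTC online but decodes with attention offline, so the decoder is
# invoked once after all chunks are fed. `e2e`, `recog_args`, and `chunks` are
# assumed to come from a trained ESPnet setup.
def run_window_streaming(e2e, recog_args, chunks, rnnlm=None):
    """Stream encoder/CTC over chunks, then decode with attention offline."""
    asr = WindowStreamingE2E(e2e, recog_args, rnnlm=rnnlm)
    for chunk in chunks:  # each chunk: tensor (num_frames, feat_dim)
        asr.accept_input(chunk)
    return asr.decode_with_attention_offline()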
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""CBHG related modules."""
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
class CBHGLoss(torch.nn.Module):
"""Loss function module for CBHG."""
def __init__(self, use_masking=True):
"""Initialize CBHG loss module.
Args:
use_masking (bool): Whether to mask padded part in loss calculation.
"""
super(CBHGLoss, self).__init__()
self.use_masking = use_masking
def forward(self, cbhg_outs, spcs, olens):
"""Calculate forward propagation.
Args:
cbhg_outs (Tensor): Batch of CBHG outputs (B, Lmax, spc_dim).
spcs (Tensor): Batch of groundtruth of spectrogram (B, Lmax, spc_dim).
olens (LongTensor): Batch of the lengths of each sequence (B,).
Returns:
Tensor: L1 loss value
Tensor: Mean square error loss value.
"""
# perform masking for padded values
if self.use_masking:
mask = make_non_pad_mask(olens).unsqueeze(-1).to(spcs.device)
spcs = spcs.masked_select(mask)
cbhg_outs = cbhg_outs.masked_select(mask)
# calculate loss
cbhg_l1_loss = F.l1_loss(cbhg_outs, spcs)
cbhg_mse_loss = F.mse_loss(cbhg_outs, spcs)
return cbhg_l1_loss, cbhg_mse_loss
class CBHG(torch.nn.Module):
"""CBHG module to convert log Mel-filterbanks to linear spectrogram.
This is a module of CBHG introduced
in `Tacotron: Towards End-to-End Speech Synthesis`_.
The CBHG converts the sequence of log Mel-filterbanks into linear spectrogram.
.. _`Tacotron: Towards End-to-End Speech Synthesis`:
https://arxiv.org/abs/1703.10135
"""
def __init__(
self,
idim,
odim,
conv_bank_layers=8,
conv_bank_chans=128,
conv_proj_filts=3,
conv_proj_chans=256,
highway_layers=4,
highway_units=128,
gru_units=256,
):
"""Initialize CBHG module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
conv_bank_layers (int, optional): The number of convolution bank layers.
conv_bank_chans (int, optional): The number of channels in convolution bank.
conv_proj_filts (int, optional):
Kernel size of convolutional projection layer.
conv_proj_chans (int, optional):
The number of channels in convolutional projection layer.
highway_layers (int, optional): The number of highway network layers.
highway_units (int, optional): The number of highway network units.
gru_units (int, optional): The number of GRU units (for both directions).
"""
super(CBHG, self).__init__()
self.idim = idim
self.odim = odim
self.conv_bank_layers = conv_bank_layers
self.conv_bank_chans = conv_bank_chans
self.conv_proj_filts = conv_proj_filts
self.conv_proj_chans = conv_proj_chans
self.highway_layers = highway_layers
self.highway_units = highway_units
self.gru_units = gru_units
# define 1d convolution bank
self.conv_bank = torch.nn.ModuleList()
for k in range(1, self.conv_bank_layers + 1):
if k % 2 != 0:
padding = (k - 1) // 2
else:
padding = ((k - 1) // 2, (k - 1) // 2 + 1)
self.conv_bank += [
torch.nn.Sequential(
torch.nn.ConstantPad1d(padding, 0.0),
torch.nn.Conv1d(
idim, self.conv_bank_chans, k, stride=1, padding=0, bias=True
),
torch.nn.BatchNorm1d(self.conv_bank_chans),
torch.nn.ReLU(),
)
]
# define max pooling (need padding for one-side to keep same length)
self.max_pool = torch.nn.Sequential(
torch.nn.ConstantPad1d((0, 1), 0.0), torch.nn.MaxPool1d(2, stride=1)
)
# define 1d convolution projection
self.projections = torch.nn.Sequential(
torch.nn.Conv1d(
self.conv_bank_chans * self.conv_bank_layers,
self.conv_proj_chans,
self.conv_proj_filts,
stride=1,
padding=(self.conv_proj_filts - 1) // 2,
bias=True,
),
torch.nn.BatchNorm1d(self.conv_proj_chans),
torch.nn.ReLU(),
torch.nn.Conv1d(
self.conv_proj_chans,
self.idim,
self.conv_proj_filts,
stride=1,
padding=(self.conv_proj_filts - 1) // 2,
bias=True,
),
torch.nn.BatchNorm1d(self.idim),
)
# define highway network
self.highways = torch.nn.ModuleList()
self.highways += [torch.nn.Linear(idim, self.highway_units)]
for _ in range(self.highway_layers):
self.highways += [HighwayNet(self.highway_units)]
# define bidirectional GRU
self.gru = torch.nn.GRU(
self.highway_units,
gru_units // 2,
num_layers=1,
batch_first=True,
bidirectional=True,
)
# define final projection
self.output = torch.nn.Linear(gru_units, odim, bias=True)
def forward(self, xs, ilens):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of the padded sequences of inputs (B, Tmax, idim).
ilens (LongTensor): Batch of lengths of each input sequence (B,).
Return:
Tensor: Batch of the padded sequence of outputs (B, Tmax, odim).
LongTensor: Batch of lengths of each output sequence (B,).
"""
xs = xs.transpose(1, 2) # (B, idim, Tmax)
convs = []
for k in range(self.conv_bank_layers):
convs += [self.conv_bank[k](xs)]
convs = torch.cat(convs, dim=1) # (B, #CH * #BANK, Tmax)
convs = self.max_pool(convs)
convs = self.projections(convs).transpose(1, 2) # (B, Tmax, idim)
xs = xs.transpose(1, 2) + convs
# + 1 for dimension adjustment layer
for i in range(self.highway_layers + 1):
xs = self.highways[i](xs)
# sort by length
xs, ilens, sort_idx = self._sort_by_length(xs, ilens)
# total_length needs for DataParallel
# (see https://github.com/pytorch/pytorch/pull/6327)
total_length = xs.size(1)
if not isinstance(ilens, torch.Tensor):
ilens = torch.tensor(ilens)
xs = pack_padded_sequence(xs, ilens.cpu(), batch_first=True)
self.gru.flatten_parameters()
xs, _ = self.gru(xs)
xs, ilens = pad_packed_sequence(xs, batch_first=True, total_length=total_length)
# revert sorting by length
xs, ilens = self._revert_sort_by_length(xs, ilens, sort_idx)
xs = self.output(xs) # (B, Tmax, odim)
return xs, ilens
def inference(self, x):
"""Inference.
Args:
x (Tensor): The sequences of inputs (T, idim).
Return:
Tensor: The sequence of outputs (T, odim).
"""
assert len(x.size()) == 2
xs = x.unsqueeze(0)
ilens = x.new([x.size(0)]).long()
return self.forward(xs, ilens)[0][0]
def _sort_by_length(self, xs, ilens):
sort_ilens, sort_idx = ilens.sort(0, descending=True)
return xs[sort_idx], ilens[sort_idx], sort_idx
def _revert_sort_by_length(self, xs, ilens, sort_idx):
_, revert_idx = sort_idx.sort(0)
return xs[revert_idx], ilens[revert_idx]
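# Tiny shape check for CBHG (added for illustration, not upstream code): convert
# a fake log-Mel sequence to a linear-spectrogram-sized output via inference().
# The 80/513 dimensions mimic a common Tacotron setup and are assumptions.
def _demo_cbhg():
    cbhg = CBHG(idim=80, odim=513).eval()
    with torch.no_grad():
        linear = cbhg.inference(torch.randn(100, 80))  # (T, idim) -> (T, odim)
    assert linear.shape == (100, 513)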
class HighwayNet(torch.nn.Module):
"""Highway Network module.
This is a module of Highway Network introduced in `Highway Networks`_.
.. _`Highway Networks`: https://arxiv.org/abs/1505.00387
"""
def __init__(self, idim):
"""Initialize Highway Network module.
Args:
idim (int): Dimension of the inputs.
"""
super(HighwayNet, self).__init__()
self.idim = idim
self.projection = torch.nn.Sequential(
torch.nn.Linear(idim, idim), torch.nn.ReLU()
)
self.gate = torch.nn.Sequential(torch.nn.Linear(idim, idim), torch.nn.Sigmoid())
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of inputs (B, ..., idim).
Returns:
Tensor: Batch of outputs, which are the same shape as inputs (B, ..., idim).
"""
proj = self.projection(x)
gate = self.gate(x)
return proj * gate + x * (1.0 - gate)
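# Illustrative check (added): a highway layer blends a projected path with the
# identity path through a learned gate, y = H(x) * T(x) + x * (1 - T(x)), so
# the output shape always matches the input shape.
def _demo_highway():
    net = HighwayNet(idim=16)
    x = torch.randn(4, 10, 16)
    assert net(x).shape == x.shape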
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Tacotron2 decoder related modules."""
import torch
import torch.nn.functional as F
from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA
def decoder_init(m):
"""Initialize decoder parameters."""
if isinstance(m, torch.nn.Conv1d):
torch.nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh"))
class ZoneOutCell(torch.nn.Module):
"""ZoneOut Cell module.
This is a module of zoneout described in
`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`_.
This code is modified from `eladhoffer/seq2seq.pytorch`_.
Examples:
>>> lstm = torch.nn.LSTMCell(16, 32)
>>> lstm = ZoneOutCell(lstm, 0.5)
.. _`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`:
https://arxiv.org/abs/1606.01305
.. _`eladhoffer/seq2seq.pytorch`:
https://github.com/eladhoffer/seq2seq.pytorch
"""
def __init__(self, cell, zoneout_rate=0.1):
"""Initialize zone out cell module.
Args:
cell (torch.nn.Module): Pytorch recurrent cell module
e.g. `torch.nn.Module.LSTMCell`.
zoneout_rate (float, optional): Probability of zoneout from 0.0 to 1.0.
"""
super(ZoneOutCell, self).__init__()
self.cell = cell
self.hidden_size = cell.hidden_size
self.zoneout_rate = zoneout_rate
if zoneout_rate > 1.0 or zoneout_rate < 0.0:
raise ValueError(
"zoneout probability must be in the range from 0.0 to 1.0."
)
def forward(self, inputs, hidden):
"""Calculate forward propagation.
Args:
inputs (Tensor): Batch of input tensor (B, input_size).
hidden (tuple):
- Tensor: Batch of initial hidden states (B, hidden_size).
- Tensor: Batch of initial cell states (B, hidden_size).
Returns:
tuple:
- Tensor: Batch of next hidden states (B, hidden_size).
- Tensor: Batch of next cell states (B, hidden_size).
"""
next_hidden = self.cell(inputs, hidden)
next_hidden = self._zoneout(hidden, next_hidden, self.zoneout_rate)
return next_hidden
def _zoneout(self, h, next_h, prob):
# apply recursively
if isinstance(h, tuple):
num_h = len(h)
if not isinstance(prob, tuple):
prob = tuple([prob] * num_h)
return tuple(
[self._zoneout(h[i], next_h[i], prob[i]) for i in range(num_h)]
)
if self.training:
mask = h.new(*h.size()).bernoulli_(prob)
return mask * h + (1 - mask) * next_h
else:
return prob * h + (1 - prob) * next_h
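# Illustrative sketch (added): wrapping an LSTMCell in ZoneOutCell. In training
# mode a Bernoulli mask keeps each state unit from the previous step with
# probability `zoneout_rate`; in eval mode the states are interpolated instead.
def _demo_zoneout():
    cell = ZoneOutCell(torch.nn.LSTMCell(16, 32), zoneout_rate=0.1)
    x = torch.randn(4, 16)
    h = (torch.zeros(4, 32), torch.zeros(4, 32))
    h = cell(x, h)  # (next_hidden, next_cell), each (4, 32)
    assert h[0].shape == (4, 32) and h[1].shape == (4, 32)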
class Prenet(torch.nn.Module):
"""Prenet module for decoder of Spectrogram prediction network.
This is a module of Prenet in the decoder of Spectrogram prediction network,
which is described in `Natural TTS
Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_.
The Prenet performs nonlinear conversion
of inputs before feeding them to the auto-regressive LSTM,
which helps to learn diagonal attentions.
Note:
This module always applies dropout even in evaluation.
See the detail in `Natural TTS Synthesis by
Conditioning WaveNet on Mel Spectrogram Predictions`_.
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
def __init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5):
"""Initialize prenet module.
Args:
idim (int): Dimension of the inputs.
n_layers (int, optional): The number of prenet layers.
n_units (int, optional): The number of prenet units.
dropout_rate (float, optional): Dropout rate.
"""
super(Prenet, self).__init__()
self.dropout_rate = dropout_rate
self.prenet = torch.nn.ModuleList()
for layer in range(n_layers):
n_inputs = idim if layer == 0 else n_units
self.prenet += [
torch.nn.Sequential(torch.nn.Linear(n_inputs, n_units), torch.nn.ReLU())
]
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Batch of input tensors (B, ..., idim).
Returns:
Tensor: Batch of output tensors (B, ..., odim).
"""
for i in range(len(self.prenet)):
# we make this part non-deterministic. See the above note.
x = F.dropout(self.prenet[i](x), self.dropout_rate)
return x
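# Illustrative check (added): per the note above, Prenet dropout stays active
# even under eval()/no_grad, so two forward passes on the same input generally
# differ. With p=0.5 over 256 units, identical outputs are vanishingly unlikely.
def _demo_prenet():
    prenet = Prenet(idim=80).eval()
    x = torch.randn(4, 80)
    with torch.no_grad():
        y1, y2 = prenet(x), prenet(x)
    assert not torch.allclose(y1, y2)  # dropout is always on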
class Postnet(torch.nn.Module):
"""Postnet module for Spectrogram prediction network.
This is a module of Postnet in Spectrogram prediction network,
which is described in `Natural TTS Synthesis by
Conditioning WaveNet on Mel Spectrogram Predictions`_.
The Postnet refines the predicted
Mel-filterbank of the decoder,
which helps to compensate for the fine structure of the spectrogram.
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
def __init__(
self,
idim,
odim,
n_layers=5,
n_chans=512,
n_filts=5,
dropout_rate=0.5,
use_batch_norm=True,
):
"""Initialize postnet module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
n_layers (int, optional): The number of layers.
n_chans (int, optional): The number of filter channels.
n_filts (int, optional): Kernel size of each filter.
use_batch_norm (bool, optional): Whether to use batch normalization.
dropout_rate (float, optional): Dropout rate.
"""
super(Postnet, self).__init__()
self.postnet = torch.nn.ModuleList()
for layer in range(n_layers - 1):
ichans = odim if layer == 0 else n_chans
ochans = odim if layer == n_layers - 1 else n_chans
if use_batch_norm:
self.postnet += [
torch.nn.Sequential(
torch.nn.Conv1d(
ichans,
ochans,
n_filts,
stride=1,
padding=(n_filts - 1) // 2,
bias=False,
),
torch.nn.BatchNorm1d(ochans),
torch.nn.Tanh(),
torch.nn.Dropout(dropout_rate),
)
]
else:
self.postnet += [
torch.nn.Sequential(
torch.nn.Conv1d(
ichans,
ochans,
n_filts,
stride=1,
padding=(n_filts - 1) // 2,
bias=False,
),
torch.nn.Tanh(),
torch.nn.Dropout(dropout_rate),
)
]
ichans = n_chans if n_layers != 1 else odim
if use_batch_norm:
self.postnet += [
torch.nn.Sequential(
torch.nn.Conv1d(
ichans,
odim,
n_filts,
stride=1,
padding=(n_filts - 1) // 2,
bias=False,
),
torch.nn.BatchNorm1d(odim),
torch.nn.Dropout(dropout_rate),
)
]
else:
self.postnet += [
torch.nn.Sequential(
torch.nn.Conv1d(
ichans,
odim,
n_filts,
stride=1,
padding=(n_filts - 1) // 2,
bias=False,
),
torch.nn.Dropout(dropout_rate),
)
]
def forward(self, xs):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of the sequences of padded input tensors (B, idim, Tmax).
Returns:
Tensor: Batch of padded output tensors (B, odim, Tmax).
"""
for i in range(len(self.postnet)):
xs = self.postnet[i](xs)
return xs
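# Illustrative check (added): Postnet keeps the (B, odim, Tmax) layout so its
# output can be added to the before-postnet prediction as a residual
# refinement, as done in the Decoder below. Dimensions here are assumptions.
def _demo_postnet():
    postnet = Postnet(idim=80, odim=80).eval()
    before = torch.randn(2, 80, 50)  # (B, odim, Tmax)
    with torch.no_grad():
        after = before + postnet(before)
    assert after.shape == before.shape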
class Decoder(torch.nn.Module):
"""Decoder module of Spectrogram prediction network.
This is a module of decoder of Spectrogram prediction network in Tacotron2,
which is described in `Natural TTS
Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_.
The decoder generates the sequence of
features from the sequence of the hidden states.
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
def __init__(
self,
idim,
odim,
att,
dlayers=2,
dunits=1024,
prenet_layers=2,
prenet_units=256,
postnet_layers=5,
postnet_chans=512,
postnet_filts=5,
output_activation_fn=None,
cumulate_att_w=True,
use_batch_norm=True,
use_concate=True,
dropout_rate=0.5,
zoneout_rate=0.1,
reduction_factor=1,
):
"""Initialize Tacotron2 decoder module.
Args:
idim (int): Dimension of the inputs.
odim (int): Dimension of the outputs.
att (torch.nn.Module): Instance of attention class.
dlayers (int, optional): The number of decoder lstm layers.
dunits (int, optional): The number of decoder lstm units.
prenet_layers (int, optional): The number of prenet layers.
prenet_units (int, optional): The number of prenet units.
postnet_layers (int, optional): The number of postnet layers.
postnet_filts (int, optional): Kernel size of postnet filters.
postnet_chans (int, optional): The number of postnet filter channels.
output_activation_fn (torch.nn.Module, optional):
Activation function for outputs.
cumulate_att_w (bool, optional):
Whether to cumulate previous attention weight.
use_batch_norm (bool, optional): Whether to use batch normalization.
use_concate (bool, optional): Whether to concatenate encoder embedding
with decoder lstm outputs.
dropout_rate (float, optional): Dropout rate.
zoneout_rate (float, optional): Zoneout rate.
reduction_factor (int, optional): Reduction factor.
"""
super(Decoder, self).__init__()
# store the hyperparameters
self.idim = idim
self.odim = odim
self.att = att
self.output_activation_fn = output_activation_fn
self.cumulate_att_w = cumulate_att_w
self.use_concate = use_concate
self.reduction_factor = reduction_factor
# check attention type
if isinstance(self.att, AttForwardTA):
self.use_att_extra_inputs = True
else:
self.use_att_extra_inputs = False
# define lstm network
prenet_units = prenet_units if prenet_layers != 0 else odim
self.lstm = torch.nn.ModuleList()
for layer in range(dlayers):
iunits = idim + prenet_units if layer == 0 else dunits
lstm = torch.nn.LSTMCell(iunits, dunits)
if zoneout_rate > 0.0:
lstm = ZoneOutCell(lstm, zoneout_rate)
self.lstm += [lstm]
# define prenet
if prenet_layers > 0:
self.prenet = Prenet(
idim=odim,
n_layers=prenet_layers,
n_units=prenet_units,
dropout_rate=dropout_rate,
)
else:
self.prenet = None
# define postnet
if postnet_layers > 0:
self.postnet = Postnet(
idim=idim,
odim=odim,
n_layers=postnet_layers,
n_chans=postnet_chans,
n_filts=postnet_filts,
use_batch_norm=use_batch_norm,
dropout_rate=dropout_rate,
)
else:
self.postnet = None
# define projection layers
iunits = idim + dunits if use_concate else dunits
self.feat_out = torch.nn.Linear(iunits, odim * reduction_factor, bias=False)
self.prob_out = torch.nn.Linear(iunits, reduction_factor)
# initialize
self.apply(decoder_init)
def _zero_state(self, hs):
init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size)
return init_hs
def forward(self, hs, hlens, ys):
"""Calculate forward propagation.
Args:
hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
hlens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor):
Batch of the sequences of padded target features (B, Lmax, odim).
Returns:
Tensor: Batch of output tensors after postnet (B, Lmax, odim).
Tensor: Batch of output tensors before postnet (B, Lmax, odim).
Tensor: Batch of logits of stop prediction (B, Lmax).
Tensor: Batch of attention weights (B, Lmax, Tmax).
Note:
This computation is performed in teacher-forcing manner.
"""
# thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
if self.reduction_factor > 1:
ys = ys[:, self.reduction_factor - 1 :: self.reduction_factor]
# length list should be list of int
hlens = list(map(int, hlens))
# initialize hidden states of decoder
c_list = [self._zero_state(hs)]
z_list = [self._zero_state(hs)]
for _ in range(1, len(self.lstm)):
c_list += [self._zero_state(hs)]
z_list += [self._zero_state(hs)]
prev_out = hs.new_zeros(hs.size(0), self.odim)
# initialize attention
prev_att_w = None
self.att.reset()
# loop for an output sequence
outs, logits, att_ws = [], [], []
for y in ys.transpose(0, 1):
if self.use_att_extra_inputs:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
else:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
xs = torch.cat([att_c, prenet_out], dim=1)
z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
for i in range(1, len(self.lstm)):
z_list[i], c_list[i] = self.lstm[i](
z_list[i - 1], (z_list[i], c_list[i])
)
zcs = (
torch.cat([z_list[-1], att_c], dim=1)
if self.use_concate
else z_list[-1]
)
outs += [self.feat_out(zcs).view(hs.size(0), self.odim, -1)]
logits += [self.prob_out(zcs)]
att_ws += [att_w]
prev_out = y # teacher forcing
if self.cumulate_att_w and prev_att_w is not None:
prev_att_w = prev_att_w + att_w # Note: error when use +=
else:
prev_att_w = att_w
logits = torch.cat(logits, dim=1) # (B, Lmax)
before_outs = torch.cat(outs, dim=2) # (B, odim, Lmax)
att_ws = torch.stack(att_ws, dim=1) # (B, Lmax, Tmax)
if self.reduction_factor > 1:
before_outs = before_outs.view(
before_outs.size(0), self.odim, -1
) # (B, odim, Lmax)
if self.postnet is not None:
after_outs = before_outs + self.postnet(before_outs) # (B, odim, Lmax)
else:
after_outs = before_outs
before_outs = before_outs.transpose(2, 1) # (B, Lmax, odim)
after_outs = after_outs.transpose(2, 1) # (B, Lmax, odim)
# apply activation function for scaling
if self.output_activation_fn is not None:
before_outs = self.output_activation_fn(before_outs)
after_outs = self.output_activation_fn(after_outs)
return after_outs, before_outs, logits, att_ws
def inference(
self,
h,
threshold=0.5,
minlenratio=0.0,
maxlenratio=10.0,
use_att_constraint=False,
backward_window=None,
forward_window=None,
):
"""Generate the sequence of features given the sequences of characters.
Args:
h (Tensor): Input sequence of encoder hidden states (T, C).
threshold (float, optional): Threshold to stop generation.
minlenratio (float, optional): Minimum length ratio.
If set to 1.0 and the length of input is 10,
the minimum length of outputs will be 10 * 1 = 10.
maxlenratio (float, optional): Maximum length ratio.
If set to 10 and the length of input is 10,
the maximum length of outputs will be 10 * 10 = 100.
use_att_constraint (bool):
Whether to apply attention constraint introduced in `Deep Voice 3`_.
backward_window (int): Backward window size in attention constraint.
forward_window (int): Forward window size in attention constraint.
Returns:
Tensor: Output sequence of features (L, odim).
Tensor: Output sequence of stop probabilities (L,).
Tensor: Attention weights (L, T).
Note:
This computation is performed in auto-regressive manner.
.. _`Deep Voice 3`: https://arxiv.org/abs/1710.07654
"""
# setup
assert len(h.size()) == 2
hs = h.unsqueeze(0)
ilens = [h.size(0)]
maxlen = int(h.size(0) * maxlenratio)
minlen = int(h.size(0) * minlenratio)
# initialize hidden states of decoder
c_list = [self._zero_state(hs)]
z_list = [self._zero_state(hs)]
for _ in range(1, len(self.lstm)):
c_list += [self._zero_state(hs)]
z_list += [self._zero_state(hs)]
prev_out = hs.new_zeros(1, self.odim)
# initialize attention
prev_att_w = None
self.att.reset()
# setup for attention constraint
if use_att_constraint:
last_attended_idx = 0
else:
last_attended_idx = None
# loop for an output sequence
idx = 0
outs, att_ws, probs = [], [], []
while True:
# update index
idx += self.reduction_factor
# decoder calculation
if self.use_att_extra_inputs:
att_c, att_w = self.att(
hs,
ilens,
z_list[0],
prev_att_w,
prev_out,
last_attended_idx=last_attended_idx,
backward_window=backward_window,
forward_window=forward_window,
)
else:
att_c, att_w = self.att(
hs,
ilens,
z_list[0],
prev_att_w,
last_attended_idx=last_attended_idx,
backward_window=backward_window,
forward_window=forward_window,
)
att_ws += [att_w]
prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
xs = torch.cat([att_c, prenet_out], dim=1)
z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
for i in range(1, len(self.lstm)):
z_list[i], c_list[i] = self.lstm[i](
z_list[i - 1], (z_list[i], c_list[i])
)
zcs = (
torch.cat([z_list[-1], att_c], dim=1)
if self.use_concate
else z_list[-1]
)
outs += [self.feat_out(zcs).view(1, self.odim, -1)] # [(1, odim, r), ...]
probs += [torch.sigmoid(self.prob_out(zcs))[0]] # [(r), ...]
if self.output_activation_fn is not None:
prev_out = self.output_activation_fn(outs[-1][:, :, -1]) # (1, odim)
else:
prev_out = outs[-1][:, :, -1] # (1, odim)
if self.cumulate_att_w and prev_att_w is not None:
prev_att_w = prev_att_w + att_w # Note: error when use +=
else:
prev_att_w = att_w
if use_att_constraint:
last_attended_idx = int(att_w.argmax())
# check whether to finish generation
if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
# check minimum length
if idx < minlen:
continue
outs = torch.cat(outs, dim=2) # (1, odim, L)
if self.postnet is not None:
outs = outs + self.postnet(outs) # (1, odim, L)
outs = outs.transpose(2, 1).squeeze(0) # (L, odim)
probs = torch.cat(probs, dim=0)
att_ws = torch.cat(att_ws, dim=0)
break
if self.output_activation_fn is not None:
outs = self.output_activation_fn(outs)
return outs, probs, att_ws
def calculate_all_attentions(self, hs, hlens, ys):
"""Calculate all of the attention weights.
Args:
hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
hlens (LongTensor): Batch of lengths of each input batch (B,).
ys (Tensor):
Batch of the sequences of padded target features (B, Lmax, odim).
Returns:
numpy.ndarray: Batch of attention weights (B, Lmax, Tmax).
Note:
This computation is performed in teacher-forcing manner.
"""
# thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
if self.reduction_factor > 1:
ys = ys[:, self.reduction_factor - 1 :: self.reduction_factor]
# length list should be list of int
hlens = list(map(int, hlens))
# initialize hidden states of decoder
c_list = [self._zero_state(hs)]
z_list = [self._zero_state(hs)]
for _ in range(1, len(self.lstm)):
c_list += [self._zero_state(hs)]
z_list += [self._zero_state(hs)]
prev_out = hs.new_zeros(hs.size(0), self.odim)
# initialize attention
prev_att_w = None
self.att.reset()
# loop for an output sequence
att_ws = []
for y in ys.transpose(0, 1):
if self.use_att_extra_inputs:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
else:
att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
att_ws += [att_w]
prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
xs = torch.cat([att_c, prenet_out], dim=1)
z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
for i in range(1, len(self.lstm)):
z_list[i], c_list[i] = self.lstm[i](
z_list[i - 1], (z_list[i], c_list[i])
)
prev_out = y # teacher forcing
if self.cumulate_att_w and prev_att_w is not None:
prev_att_w = prev_att_w + att_w # Note: error when use +=
else:
prev_att_w = att_w
att_ws = torch.stack(att_ws, dim=1) # (B, Lmax, Tmax)
return att_ws
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Tacotron2 encoder related modules."""
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
def encoder_init(m):
"""Initialize encoder parameters."""
if isinstance(m, torch.nn.Conv1d):
torch.nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu"))
class Encoder(torch.nn.Module):
"""Encoder module of Spectrogram prediction network.
This is a module of encoder of Spectrogram prediction network in Tacotron2,
which is described in `Natural TTS Synthesis by Conditioning WaveNet on Mel
Spectrogram Predictions`_. This is the encoder which converts either a sequence
of characters or acoustic features into the sequence of hidden states.
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
def __init__(
self,
idim,
input_layer="embed",
embed_dim=512,
elayers=1,
eunits=512,
econv_layers=3,
econv_chans=512,
econv_filts=5,
use_batch_norm=True,
use_residual=False,
dropout_rate=0.5,
padding_idx=0,
):
"""Initialize Tacotron2 encoder module.
Args:
idim (int): Dimension of the inputs.
input_layer (str): Input layer type.
embed_dim (int, optional): Dimension of character embedding.
elayers (int, optional): The number of encoder blstm layers.
eunits (int, optional): The number of encoder blstm units.
econv_layers (int, optional): The number of encoder conv layers.
econv_filts (int, optional): Kernel size of encoder conv filters.
econv_chans (int, optional): The number of encoder conv filter channels.
use_batch_norm (bool, optional): Whether to use batch normalization.
use_residual (bool, optional): Whether to use residual connection.
dropout_rate (float, optional): Dropout rate.
padding_idx (int, optional): Padding index for the character embedding.
"""
super(Encoder, self).__init__()
# store the hyperparameters
self.idim = idim
self.use_residual = use_residual
# define network layer modules
if input_layer == "linear":
self.embed = torch.nn.Linear(idim, econv_chans)
elif input_layer == "embed":
self.embed = torch.nn.Embedding(idim, embed_dim, padding_idx=padding_idx)
else:
raise ValueError("unknown input_layer: " + input_layer)
if econv_layers > 0:
self.convs = torch.nn.ModuleList()
for layer in range(econv_layers):
ichans = (
embed_dim if layer == 0 and input_layer == "embed" else econv_chans
)
if use_batch_norm:
self.convs += [
torch.nn.Sequential(
torch.nn.Conv1d(
ichans,
econv_chans,
econv_filts,
stride=1,
padding=(econv_filts - 1) // 2,
bias=False,
),
torch.nn.BatchNorm1d(econv_chans),
torch.nn.ReLU(),
torch.nn.Dropout(dropout_rate),
)
]
else:
self.convs += [
torch.nn.Sequential(
torch.nn.Conv1d(
ichans,
econv_chans,
econv_filts,
stride=1,
padding=(econv_filts - 1) // 2,
bias=False,
),
torch.nn.ReLU(),
torch.nn.Dropout(dropout_rate),
)
]
else:
self.convs = None
if elayers > 0:
iunits = econv_chans if econv_layers != 0 else embed_dim
self.blstm = torch.nn.LSTM(
iunits, eunits // 2, elayers, batch_first=True, bidirectional=True
)
else:
self.blstm = None
# initialize
self.apply(encoder_init)
def forward(self, xs, ilens=None):
"""Calculate forward propagation.
Args:
xs (Tensor): Batch of the padded sequence. Either character ids (B, Tmax)
or acoustic feature (B, Tmax, idim * encoder_reduction_factor). Padded
value should be 0.
ilens (LongTensor): Batch of lengths of each input batch (B,).
Returns:
Tensor: Batch of the sequences of encoder states (B, Tmax, eunits).
LongTensor: Batch of lengths of each sequence (B,).
"""
xs = self.embed(xs).transpose(1, 2)
if self.convs is not None:
for i in range(len(self.convs)):
if self.use_residual:
xs = xs + self.convs[i](xs)
else:
xs = self.convs[i](xs)
if self.blstm is None:
return xs.transpose(1, 2)
if not isinstance(ilens, torch.Tensor):
ilens = torch.tensor(ilens)
xs = pack_padded_sequence(xs.transpose(1, 2), ilens.cpu(), batch_first=True)
self.blstm.flatten_parameters()
xs, _ = self.blstm(xs) # (B, Tmax, C)
xs, hlens = pad_packed_sequence(xs, batch_first=True)
return xs, hlens
def inference(self, x):
"""Inference.
Args:
x (Tensor): The sequence of character ids (T,)
or acoustic feature (T, idim * encoder_reduction_factor).
Returns:
Tensor: The sequence of encoder states (T, eunits).
"""
xs = x.unsqueeze(0)
ilens = torch.tensor([x.size(0)])
return self.forward(xs, ilens)[0][0]
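# Illustrative check (added): encoding a batch of padded character-id
# sequences. The vocabulary size of 40 and the lengths below are assumptions;
# lengths are pre-sorted as pack_padded_sequence expects.
def _demo_encoder():
    enc = Encoder(idim=40).eval()
    xs = torch.randint(1, 40, (2, 7))  # (B, Tmax) character ids, 0 = pad
    xs[1, 5:] = 0
    ilens = torch.tensor([7, 5])
    with torch.no_grad():
        hs, hlens = enc(xs, ilens)
    assert hs.shape == (2, 7, 512)  # (B, Tmax, eunits)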
"""Transducer model arguments."""
import ast
from argparse import _ArgumentGroup
from distutils.util import strtobool
def add_encoder_general_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
"""Define general arguments for encoder."""
group.add_argument(
"--etype",
default="blstmp",
type=str,
choices=[
"custom",
"lstm",
"blstm",
"lstmp",
"blstmp",
"vgglstmp",
"vggblstmp",
"vgglstm",
"vggblstm",
"gru",
"bgru",
"grup",
"bgrup",
"vgggrup",
"vggbgrup",
"vgggru",
"vggbgru",
],
help="Type of encoder network architecture",
)
group.add_argument(
"--dropout-rate",
default=0.0,
type=float,
help="Dropout rate for the encoder",
)
return group
def add_rnn_encoder_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
"""Define arguments for RNN encoder."""
group.add_argument(
"--elayers",
default=4,
type=int,
help="Number of encoder layers (for shared recognition part "
"in multi-speaker asr mode)",
)
group.add_argument(
"--eunits",
"-u",
default=300,
type=int,
help="Number of encoder hidden units",
)
group.add_argument(
"--eprojs", default=320, type=int, help="Number of encoder projection units"
)
group.add_argument(
"--subsample",
default="1",
type=str,
help="Subsample input frames x_y_z means subsample every x frame "
"at 1st layer, every y frame at 2nd layer etc.",
)
return group
def add_custom_encoder_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
"""Define arguments for Custom encoder."""
group.add_argument(
"--enc-block-arch",
type=eval,
action="append",
default=None,
help="Encoder architecture definition by blocks",
)
group.add_argument(
"--enc-block-repeat",
default=1,
type=int,
help="Repeat N times the provided encoder blocks if N > 1",
)
group.add_argument(
"--custom-enc-input-layer",
type=str,
default="conv2d",
choices=["conv2d", "vgg2l", "linear", "embed"],
help="Custom encoder input layer type",
)
group.add_argument(
"--custom-enc-input-dropout-rate",
type=float,
default=0.0,
help="Dropout rate of custom encoder input layer",
)
group.add_argument(
"--custom-enc-input-pos-enc-dropout-rate",
type=float,
default=0.0,
help="Dropout rate of positional encoding in custom encoder input layer",
)
group.add_argument(
"--custom-enc-positional-encoding-type",
type=str,
default="abs_pos",
choices=["abs_pos", "scaled_abs_pos", "rel_pos"],
help="Custom encoder positional encoding layer type",
)
group.add_argument(
"--custom-enc-self-attn-type",
type=str,
default="self_attn",
choices=["self_attn", "rel_self_attn"],
help="Custom encoder self-attention type",
)
group.add_argument(
"--custom-enc-pw-activation-type",
type=str,
default="relu",
choices=["relu", "hardtanh", "selu", "swish"],
help="Custom encoder pointwise activation type",
)
group.add_argument(
"--custom-enc-conv-mod-activation-type",
type=str,
default="swish",
choices=["relu", "hardtanh", "selu", "swish"],
help="Custom encoder convolutional module activation type",
)
return group
def add_decoder_general_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
"""Define general arguments for encoder."""
group.add_argument(
"--dtype",
default="lstm",
type=str,
choices=["lstm", "gru", "custom"],
help="Type of decoder to use",
)
group.add_argument(
"--dropout-rate-decoder",
default=0.0,
type=float,
help="Dropout rate for the decoder",
)
group.add_argument(
"--dropout-rate-embed-decoder",
default=0.0,
type=float,
help="Dropout rate for the decoder embedding layer",
)
return group
def add_rnn_decoder_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
"""Define arguments for RNN decoder."""
group.add_argument(
"--dec-embed-dim",
default=320,
type=int,
help="Number of decoder embeddings dimensions",
)
group.add_argument(
"--dlayers", default=1, type=int, help="Number of decoder layers"
)
group.add_argument(
"--dunits", default=320, type=int, help="Number of decoder hidden units"
)
return group
def add_custom_decoder_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
"""Define arguments for Custom decoder."""
group.add_argument(
"--dec-block-arch",
type=eval,
action="append",
default=None,
help="Custom decoder blocks definition",
)
group.add_argument(
"--dec-block-repeat",
default=1,
type=int,
help="Repeat N times the provided decoder blocks if N > 1",
)
group.add_argument(
"--custom-dec-input-layer",
type=str,
default="embed",
choices=["linear", "embed"],
help="Custom decoder input layer type",
)
group.add_argument(
"--custom-dec-pw-activation-type",
type=str,
default="relu",
choices=["relu", "hardtanh", "selu", "swish"],
help="Custom decoder pointwise activation type",
)
return group
def add_custom_training_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
"""Define arguments for training with Custom architecture."""
group.add_argument(
"--optimizer-warmup-steps",
default=25000,
type=int,
help="Optimizer warmup steps",
)
group.add_argument(
"--noam-lr",
default=10.0,
type=float,
help="Initial value of learning rate",
)
group.add_argument(
"--noam-adim",
default=0,
type=int,
help="Most dominant attention dimension for scheduler.",
)
group.add_argument(
"--transformer-warmup-steps",
type=int,
help="Optimizer warmup steps. The parameter is deprecated, "
"please use --optimizer-warmup-steps instead.",
dest="optimizer_warmup_steps",
)
group.add_argument(
"--transformer-lr",
type=float,
help="Initial value of learning rate. The parameter is deprecated, "
"please use --noam-lr instead.",
dest="noam_lr",
)
group.add_argument(
"--adim",
type=int,
help="Most dominant attention dimension for scheduler. "
"The parameter is deprecated, please use --noam-adim instead.",
dest="noam_adim",
)
return group
def add_transducer_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
"""Define general arguments for Transducer model."""
group.add_argument(
"--transducer-weight",
default=1.0,
type=float,
help="Weight of main Transducer loss.",
)
group.add_argument(
"--joint-dim",
default=320,
type=int,
help="Number of dimensions in joint space",
)
group.add_argument(
"--joint-activation-type",
type=str,
default="tanh",
choices=["relu", "tanh", "swish"],
help="Joint network activation type",
)
group.add_argument(
"--score-norm",
type=strtobool,
nargs="?",
default=True,
help="Normalize Transducer scores by length",
)
group.add_argument(
"--fastemit-lambda",
default=0.0,
type=float,
help="Regularization parameter for FastEmit (https://arxiv.org/abs/2010.11148)",
)
return group
def add_auxiliary_task_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
"""Add arguments for auxiliary task."""
group.add_argument(
"--use-ctc-loss",
type=strtobool,
nargs="?",
default=False,
help="Whether to compute auxiliary CTC loss.",
)
group.add_argument(
"--ctc-loss-weight",
default=0.5,
type=float,
help="Weight of auxiliary CTC loss.",
)
group.add_argument(
"--ctc-loss-dropout-rate",
default=0.0,
type=float,
help="Dropout rate for auxiliary CTC.",
)
group.add_argument(
"--use-lm-loss",
type=strtobool,
nargs="?",
default=False,
help="Whether to compute auxiliary LM loss (label smoothing).",
)
group.add_argument(
"--lm-loss-weight",
default=0.5,
type=float,
help="Weight of auxiliary LM loss.",
)
group.add_argument(
"--lm-loss-smoothing-rate",
default=0.0,
type=float,
help="Smoothing rate for LM loss. If > 0, label smoothing is enabled.",
)
group.add_argument(
"--use-aux-transducer-loss",
type=strtobool,
nargs="?",
default=False,
help="Whether to compute auxiliary Transducer loss.",
)
group.add_argument(
"--aux-transducer-loss-weight",
default=0.2,
type=float,
help="Weight of auxiliary Transducer loss.",
)
group.add_argument(
"--aux-transducer-loss-enc-output-layers",
default=None,
type=ast.literal_eval,
help="List of intermediate encoder layers for auxiliary "
"transducer loss computation.",
)
group.add_argument(
"--aux-transducer-loss-mlp-dim",
default=320,
type=int,
help="Multilayer perceptron hidden dimension for auxiliary Transducer loss.",
)
group.add_argument(
"--aux-transducer-loss-mlp-dropout-rate",
default=0.0,
type=float,
help="Multilayer perceptron dropout rate for auxiliary Transducer loss.",
)
group.add_argument(
"--use-symm-kl-div-loss",
type=strtobool,
nargs="?",
default=False,
help="Whether to compute symmetric KL divergence loss.",
)
group.add_argument(
"--symm-kl-div-loss-weight",
default=0.2,
type=float,
help="Weight of symmetric KL divergence loss.",
)
return group
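# Illustrative usage (added): the helpers above attach Transducer options to an
# argparse group. This mirrors how an ESPnet training script might wire them
# up; the group title and flag values here are assumptions.
def _demo_parser():
    import argparse
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group("transducer model setting")
    group = add_encoder_general_arguments(group)
    group = add_transducer_arguments(group)
    args = parser.parse_args(["--etype", "custom", "--transducer-weight", "0.8"])
    assert args.etype == "custom" and args.transducer_weight == 0.8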
"""Set of methods to create custom architecture."""
from typing import Any, Dict, List, Tuple, Union
import torch
from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule
from espnet.nets.pytorch_backend.conformer.encoder_layer import (
EncoderLayer as ConformerEncoderLayer,
)
from espnet.nets.pytorch_backend.nets_utils import get_activation
from espnet.nets.pytorch_backend.transducer.conv1d_nets import CausalConv1d, Conv1d
from espnet.nets.pytorch_backend.transducer.transformer_decoder_layer import (
TransformerDecoderLayer,
)
from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L
from espnet.nets.pytorch_backend.transformer.attention import (
MultiHeadedAttention,
RelPositionMultiHeadedAttention,
)
from espnet.nets.pytorch_backend.transformer.embedding import (
PositionalEncoding,
RelPositionalEncoding,
ScaledPositionalEncoding,
)
from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
PositionwiseFeedForward,
)
from espnet.nets.pytorch_backend.transformer.repeat import MultiSequential
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
def verify_block_arguments(
net_part: str,
block: Dict[str, Any],
num_block: int,
) -> Tuple[int, int]:
"""Verify block arguments are valid.
Args:
net_part: Network part, either 'encoder' or 'decoder'.
block: Block parameters.
num_block: Block ID.
Return:
block_io: Input and output dimension of the block.
"""
block_type = block.get("type")
if block_type is None:
raise ValueError(
"Block %d in %s doesn't a type assigned.", (num_block, net_part)
)
if block_type == "transformer":
arguments = {"d_hidden", "d_ff", "heads"}
elif block_type == "conformer":
arguments = {
"d_hidden",
"d_ff",
"heads",
"macaron_style",
"use_conv_mod",
}
if net_part == "decoder":
raise ValueError("Decoder does not support 'conformer'.")
if block.get("use_conv_mod", None) is True and "conv_mod_kernel" not in block:
raise ValueError(
"Block %d: 'use_conv_mod' is True but "
" 'conv_mod_kernel' is not specified" % num_block
)
elif block_type == "causal-conv1d":
arguments = {"idim", "odim", "kernel_size"}
if net_part == "encoder":
raise ValueError("Encoder does not support 'causal-conv1d'.")
elif block_type == "conv1d":
arguments = {"idim", "odim", "kernel_size"}
if net_part == "decoder":
raise ValueError("Decoder does not support 'conv1d.'")
else:
raise NotImplementedError(
"Wrong type. Currently supported: "
"causal-conv1d, conformer, conv-nd or transformer."
)
if not arguments.issubset(block):
raise ValueError(
"%s in %s in position %d: Expected block arguments : %s."
" See tutorial page for more information."
% (block_type, net_part, num_block, arguments)
)
if block_type in ("transformer", "conformer"):
block_io = (block["d_hidden"], block["d_hidden"])
else:
block_io = (block["idim"], block["odim"])
return block_io
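# Illustrative check (added): a valid conformer block definition for the
# encoder and the (input, output) dimensions verify_block_arguments returns.
# The sizes are arbitrary assumptions.
def _demo_verify():
    block = {
        "type": "conformer",
        "d_hidden": 256,
        "d_ff": 1024,
        "heads": 4,
        "macaron_style": True,
        "use_conv_mod": True,
        "conv_mod_kernel": 31,
    }
    assert verify_block_arguments("encoder", block, 1) == (256, 256)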
def prepare_input_layer(
input_layer_type: str,
feats_dim: int,
blocks: List[Dict[str, Any]],
dropout_rate: float,
pos_enc_dropout_rate: float,
) -> Dict[str, Any]:
"""Prepare input layer arguments.
Args:
input_layer_type: Input layer type.
feats_dim: Dimension of input features.
blocks: Blocks parameters for network part.
dropout_rate: Dropout rate for input layer.
pos_enc_dropout_rate: Dropout rate for input layer pos. enc.
Return:
input_block: Input block parameters.
"""
input_block = {}
first_block_type = blocks[0].get("type", None)
if first_block_type == "causal-conv1d":
input_block["type"] = "c-embed"
else:
input_block["type"] = input_layer_type
input_block["dropout-rate"] = dropout_rate
input_block["pos-dropout-rate"] = pos_enc_dropout_rate
input_block["idim"] = feats_dim
if first_block_type in ("transformer", "conformer"):
input_block["odim"] = blocks[0].get("d_hidden", 0)
else:
input_block["odim"] = blocks[0].get("idim", 0)
return input_block
def prepare_body_model(
net_part: str,
blocks: List[Dict[str, Any]],
) -> Tuple[int]:
"""Prepare model body blocks.
Args:
net_part: Network part, either 'encoder' or 'decoder'.
blocks: Blocks parameters for network part.
Return:
: Network output dimension.
"""
cmp_io = [
verify_block_arguments(net_part, b, (i + 1)) for i, b in enumerate(blocks)
]
if {"transformer", "conformer"} <= {b["type"] for b in blocks}:
raise NotImplementedError(
net_part + ": transformer and conformer blocks "
"can't be used together in the same net part."
)
for i in range(1, len(cmp_io)):
if cmp_io[(i - 1)][1] != cmp_io[i][0]:
raise ValueError(
"Output/Input mismatch between blocks %d and %d in %s"
% (i, (i + 1), net_part)
)
return cmp_io[-1][1]
def get_pos_enc_and_att_class(
net_part: str, pos_enc_type: str, self_attn_type: str
) -> Tuple[
Union[PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding],
Union[MultiHeadedAttention, RelPositionMultiHeadedAttention],
]:
"""Get positional encoding and self attention module class.
Args:
net_part: Network part, either 'encoder' or 'decoder'.
pos_enc_type: Positional encoding type.
self_attn_type: Self-attention type.
Return:
pos_enc_class: Positional encoding class.
self_attn_class: Self-attention class.
"""
if pos_enc_type == "abs_pos":
pos_enc_class = PositionalEncoding
elif pos_enc_type == "scaled_abs_pos":
pos_enc_class = ScaledPositionalEncoding
elif pos_enc_type == "rel_pos":
if net_part == "encoder" and self_attn_type != "rel_self_attn":
raise ValueError("'rel_pos' is only compatible with 'rel_self_attn'")
pos_enc_class = RelPositionalEncoding
else:
raise NotImplementedError(
"pos_enc_type should be either 'abs_pos', 'scaled_abs_pos' or 'rel_pos'"
)
if self_attn_type == "rel_self_attn":
self_attn_class = RelPositionMultiHeadedAttention
else:
self_attn_class = MultiHeadedAttention
return pos_enc_class, self_attn_class
def build_input_layer(
block: Dict[str, Any],
pos_enc_class: torch.nn.Module,
padding_idx: int,
) -> Tuple[Union[Conv2dSubsampling, VGG2L, torch.nn.Sequential], int]:
"""Build input layer.
Args:
block: Architecture definition of input layer.
pos_enc_class: Positional encoding class.
padding_idx: Padding symbol ID for embedding layer (if provided).
Returns:
: Input layer module.
subsampling_factor: Subsampling factor.
"""
input_type = block["type"]
idim = block["idim"]
odim = block["odim"]
dropout_rate = block["dropout-rate"]
pos_dropout_rate = block["pos-dropout-rate"]
if pos_enc_class.__name__ == "RelPositionalEncoding":
pos_enc_class_subsampling = pos_enc_class(odim, pos_dropout_rate)
else:
pos_enc_class_subsampling = None
if input_type == "linear":
return (
torch.nn.Sequential(
torch.nn.Linear(idim, odim),
torch.nn.LayerNorm(odim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(odim, pos_dropout_rate),
),
1,
)
elif input_type == "conv2d":
return Conv2dSubsampling(idim, odim, dropout_rate, pos_enc_class_subsampling), 4
elif input_type == "vgg2l":
return VGG2L(idim, odim, pos_enc_class_subsampling), 4
elif input_type == "embed":
return (
torch.nn.Sequential(
torch.nn.Embedding(idim, odim, padding_idx=padding_idx),
pos_enc_class(odim, pos_dropout_rate),
),
1,
)
elif input_type == "c-embed":
return (
torch.nn.Sequential(
torch.nn.Embedding(idim, odim, padding_idx=padding_idx),
torch.nn.Dropout(dropout_rate),
),
1,
)
else:
raise NotImplementedError(
"Invalid input layer: %s. Supported: linear, conv2d, vgg2l and embed"
% input_type
)
def build_transformer_block(
net_part: str,
block: Dict[str, Any],
pw_layer_type: str,
pw_activation_type: str,
) -> Union[EncoderLayer, TransformerDecoderLayer]:
"""Build function for transformer block.
Args:
net_part: Network part, either 'encoder' or 'decoder'.
block: Transformer block parameters.
pw_layer_type: Positionwise layer type.
pw_activation_type: Positionwise activation type.
Returns:
: Function to create transformer (encoder or decoder) block.
"""
d_hidden = block["d_hidden"]
dropout_rate = block.get("dropout-rate", 0.0)
pos_dropout_rate = block.get("pos-dropout-rate", 0.0)
att_dropout_rate = block.get("att-dropout-rate", 0.0)
if pw_layer_type != "linear":
raise NotImplementedError(
"Transformer block only supports linear pointwise layer."
)
if net_part == "encoder":
transformer_layer_class = EncoderLayer
elif net_part == "decoder":
transformer_layer_class = TransformerDecoderLayer
return lambda: transformer_layer_class(
d_hidden,
MultiHeadedAttention(block["heads"], d_hidden, att_dropout_rate),
PositionwiseFeedForward(
d_hidden,
block["d_ff"],
pos_dropout_rate,
get_activation(pw_activation_type),
),
dropout_rate,
)
def build_conformer_block(
block: Dict[str, Any],
self_attn_class: torch.nn.Module,
pw_layer_type: str,
pw_activation_type: str,
conv_mod_activation_type: str,
) -> ConformerEncoderLayer:
"""Build function for conformer block.
Args:
block: Conformer block parameters.
self_attn_class: Self-attention module class.
pw_layer_type: Positionwise layer type.
pw_activation_type: Positionwise activation type.
conv_mod_activation_type: Convolutional module activation type.
Returns:
: Function to create conformer (encoder) block.
"""
d_hidden = block["d_hidden"]
d_ff = block["d_ff"]
dropout_rate = block.get("dropout-rate", 0.0)
pos_dropout_rate = block.get("pos-dropout-rate", 0.0)
att_dropout_rate = block.get("att-dropout-rate", 0.0)
macaron_style = block["macaron_style"]
use_conv_mod = block["use_conv_mod"]
if pw_layer_type == "linear":
pw_layer = PositionwiseFeedForward
pw_layer_args = (
d_hidden,
d_ff,
pos_dropout_rate,
get_activation(pw_activation_type),
)
else:
raise NotImplementedError("Conformer block only supports linear yet.")
if macaron_style:
macaron_net = PositionwiseFeedForward
macaron_net_args = (
d_hidden,
d_ff,
pos_dropout_rate,
get_activation(pw_activation_type),
)
if use_conv_mod:
conv_mod = ConvolutionModule
conv_mod_args = (
d_hidden,
block["conv_mod_kernel"],
get_activation(conv_mod_activation_type),
)
return lambda: ConformerEncoderLayer(
d_hidden,
self_attn_class(block["heads"], d_hidden, att_dropout_rate),
pw_layer(*pw_layer_args),
macaron_net(*macaron_net_args) if macaron_style else None,
conv_mod(*conv_mod_args) if use_conv_mod else None,
dropout_rate,
)
def build_conv1d_block(block: Dict[str, Any], block_type: str) -> CausalConv1d:
"""Build function for causal conv1d block.
Args:
block: CausalConv1d or Conv1D block parameters.
Returns:
: Function to create conv1d (encoder) or causal conv1d (decoder) block.
"""
if block_type == "conv1d":
conv_class = Conv1d
else:
conv_class = CausalConv1d
stride = block.get("stride", 1)
dilation = block.get("dilation", 1)
groups = block.get("groups", 1)
bias = block.get("bias", True)
use_batch_norm = block.get("use-batch-norm", False)
use_relu = block.get("use-relu", False)
dropout_rate = block.get("dropout-rate", 0.0)
return lambda: conv_class(
block["idim"],
block["odim"],
block["kernel_size"],
stride=stride,
dilation=dilation,
groups=groups,
bias=bias,
relu=use_relu,
batch_norm=use_batch_norm,
dropout_rate=dropout_rate,
)
def build_blocks(
net_part: str,
idim: int,
input_layer_type: str,
blocks: List[Dict[str, Any]],
repeat_block: int = 0,
self_attn_type: str = "self_attn",
positional_encoding_type: str = "abs_pos",
positionwise_layer_type: str = "linear",
positionwise_activation_type: str = "relu",
conv_mod_activation_type: str = "relu",
input_layer_dropout_rate: float = 0.0,
input_layer_pos_enc_dropout_rate: float = 0.0,
padding_idx: int = -1,
) -> Tuple[
Union[Conv2dSubsampling, VGG2L, torch.nn.Sequential], MultiSequential, int, int
]:
"""Build custom model blocks.
Args:
net_part: Network part, either 'encoder' or 'decoder'.
idim: Input dimension.
input_layer_type: Input layer type.
blocks: Blocks parameters for network part.
repeat_block: Number of times provided blocks are repeated.
self_attn_type: Self-attention module type.
positional_encoding_type: Positional encoding layer type.
positionwise_layer_type: Positionwise layer type.
positionwise_activation_type: Positionwise activation type.
conv_mod_activation_type: Convolutional module activation type.
input_layer_dropout_rate: Dropout rate for input layer.
input_layer_pos_enc_dropout_rate: Dropout rate for input layer pos. enc.
padding_idx: Padding symbol ID for embedding layer.
Returns:
in_layer: Input layer
all_blocks: Encoder/Decoder network.
out_dim: Network output dimension.
conv_subsampling_factor: Subsampling factor in frontend CNN.
"""
fn_modules = []
pos_enc_class, self_attn_class = get_pos_enc_and_att_class(
net_part, positional_encoding_type, self_attn_type
)
input_block = prepare_input_layer(
input_layer_type,
idim,
blocks,
input_layer_dropout_rate,
input_layer_pos_enc_dropout_rate,
)
out_dim = prepare_body_model(net_part, blocks)
input_layer, conv_subsampling_factor = build_input_layer(
input_block,
pos_enc_class,
padding_idx,
)
for i in range(len(blocks)):
block_type = blocks[i]["type"]
if block_type in ("causal-conv1d", "conv1d"):
module = build_conv1d_block(blocks[i], block_type)
elif block_type == "conformer":
module = build_conformer_block(
blocks[i],
self_attn_class,
positionwise_layer_type,
positionwise_activation_type,
conv_mod_activation_type,
)
elif block_type == "transformer":
module = build_transformer_block(
net_part,
blocks[i],
positionwise_layer_type,
positionwise_activation_type,
)
fn_modules.append(module)
if repeat_block > 1:
fn_modules = fn_modules * repeat_block
return (
input_layer,
MultiSequential(*[fn() for fn in fn_modules]),
out_dim,
conv_subsampling_factor,
)
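# Hedged usage sketch (added): building a small conformer encoder body with
# build_blocks. It relies on the espnet imports at the top of this file; the
# 80-dim input and the block sizes are illustrative assumptions.
def _demo_build_blocks():
    blocks = [
        {
            "type": "conformer",
            "d_hidden": 256,
            "d_ff": 1024,
            "heads": 4,
            "macaron_style": True,
            "use_conv_mod": True,
            "conv_mod_kernel": 31,
        }
    ]
    return build_blocks(
        "encoder",
        80,
        "conv2d",
        blocks,
        repeat_block=2,
        self_attn_type="rel_self_attn",
        positional_encoding_type="rel_pos",
        conv_mod_activation_type="swish",
    )  # -> (input_layer, body, out_dim=256, conv_subsampling_factor=4)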
"""Convolution networks definition for custom archictecture."""
from typing import Optional, Tuple, Union
import torch
class Conv1d(torch.nn.Module):
"""1D convolution module for custom encoder.
Args:
idim: Input dimension.
odim: Output dimension.
kernel_size: Size of the convolving kernel.
stride: Stride of the convolution.
dilation: Spacing between the kernel points.
groups: Number of blocked connections from input channels to output channels.
bias: Whether to add a learnable bias to the output.
batch_norm: Whether to use batch normalization after convolution.
relu: Whether to use a ReLU activation after convolution.
dropout_rate: Dropout rate.
"""
def __init__(
self,
idim: int,
odim: int,
kernel_size: Union[int, Tuple],
stride: Union[int, Tuple] = 1,
dilation: Union[int, Tuple] = 1,
groups: Union[int, Tuple] = 1,
bias: bool = True,
batch_norm: bool = False,
relu: bool = True,
dropout_rate: float = 0.0,
):
"""Construct a Conv1d module object."""
super().__init__()
self.conv = torch.nn.Conv1d(
idim,
odim,
kernel_size,
stride=stride,
dilation=dilation,
groups=groups,
bias=bias,
)
self.dropout = torch.nn.Dropout(p=dropout_rate)
if relu:
self.relu_func = torch.nn.ReLU()
if batch_norm:
self.bn = torch.nn.BatchNorm1d(odim)
self.relu = relu
self.batch_norm = batch_norm
self.padding = dilation * (kernel_size - 1)
self.stride = stride
self.out_pos = torch.nn.Linear(idim, odim)
def forward(
self,
sequence: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
mask: torch.Tensor,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]:
"""Forward ConvEncoderLayer module object.
Args:
sequence: Input sequences.
(B, T, D_in)
or (B, T, D_in), (B, 2 * (T - 1), D_att)
mask: Mask of input sequences. (B, 1, T)
Returns:
sequence: Output sequences.
(B, sub(T), D_out)
or (B, sub(T), D_out), (B, 2 * (sub(T) - 1), D_att)
mask: Mask of output sequences. (B, 1, sub(T))
"""
if isinstance(sequence, tuple):
sequence, pos_embed = sequence[0], sequence[1]
else:
sequence, pos_embed = sequence, None
sequence = sequence.transpose(1, 2)
sequence = self.conv(sequence)
if self.batch_norm:
sequence = self.bn(sequence)
sequence = self.dropout(sequence)
if self.relu:
sequence = self.relu_func(sequence)
sequence = sequence.transpose(1, 2)
mask = self.create_new_mask(mask)
if pos_embed is not None:
pos_embed = self.create_new_pos_embed(pos_embed)
return (sequence, pos_embed), mask
return sequence, mask
def create_new_mask(self, mask: torch.Tensor) -> torch.Tensor:
"""Create new mask.
Args:
mask: Mask of input sequences. (B, 1, T)
Returns:
mask: Mask of output sequences. (B, 1, sub(T))
"""
if mask is None:
return mask
if self.padding != 0:
mask = mask[:, :, : -self.padding]
mask = mask[:, :, :: self.stride]
return mask
def create_new_pos_embed(self, pos_embed: torch.Tensor) -> torch.Tensor:
"""Create new positional embedding vector.
Args:
pos_embed: Input sequences positional embedding.
(B, 2 * (T - 1), D_att)
        Returns:
pos_embed: Output sequences positional embedding.
(B, 2 * (sub(T) - 1), D_att)
"""
pos_embed_positive = pos_embed[:, : pos_embed.size(1) // 2 + 1, :]
pos_embed_negative = pos_embed[:, pos_embed.size(1) // 2 :, :]
if self.padding != 0:
pos_embed_positive = pos_embed_positive[:, : -self.padding, :]
pos_embed_negative = pos_embed_negative[:, : -self.padding, :]
pos_embed_positive = pos_embed_positive[:, :: self.stride, :]
pos_embed_negative = pos_embed_negative[:, :: self.stride, :]
pos_embed = torch.cat([pos_embed_positive, pos_embed_negative[:, 1:, :]], dim=1)
return self.out_pos(pos_embed)
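# --- Illustrative sketch (not part of the original file): how create_new_mask
# shortens the time axis. With kernel_size=3 and dilation=1, padding equals 2
# frames, which are dropped before striding.
def _demo_conv1d_mask_subsampling() -> torch.Tensor:
    """Hypothetical helper mirroring Conv1d.create_new_mask on plain tensors."""
    mask = torch.ones(1, 1, 10, dtype=torch.bool)  # (B, 1, T)
    padding, stride = 2, 2
    # Same slicing as create_new_mask: trim the padded tail, then subsample.
    return mask[:, :, :-padding][:, :, ::stride]  # (1, 1, 4)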
class CausalConv1d(torch.nn.Module):
"""1D causal convolution module for custom decoder.
Args:
idim: Input dimension.
odim: Output dimension.
kernel_size: Size of the convolving kernel.
stride: Stride of the convolution.
dilation: Spacing between the kernel points.
groups: Number of blocked connections from input channels to output channels.
bias: Whether to add a learnable bias to the output.
batch_norm: Whether to apply batch normalization.
relu: Whether to pass final output through ReLU activation.
dropout_rate: Dropout rate.
"""
def __init__(
self,
idim: int,
odim: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
batch_norm: bool = False,
relu: bool = True,
dropout_rate: float = 0.0,
):
"""Construct a CausalConv1d object."""
super().__init__()
self.padding = (kernel_size - 1) * dilation
self.causal_conv1d = torch.nn.Conv1d(
idim,
odim,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
dilation=dilation,
groups=groups,
bias=bias,
)
self.dropout = torch.nn.Dropout(p=dropout_rate)
if batch_norm:
self.bn = torch.nn.BatchNorm1d(odim)
if relu:
self.relu_func = torch.nn.ReLU()
self.batch_norm = batch_norm
self.relu = relu
def forward(
self,
sequence: torch.Tensor,
mask: torch.Tensor,
cache: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward CausalConv1d for custom decoder.
Args:
sequence: CausalConv1d input sequences. (B, U, D_in)
mask: Mask of CausalConv1d input sequences. (B, 1, U)
Returns:
sequence: CausalConv1d output sequences. (B, sub(U), D_out)
mask: Mask of CausalConv1d output sequences. (B, 1, sub(U))
"""
sequence = sequence.transpose(1, 2)
sequence = self.causal_conv1d(sequence)
if self.padding != 0:
sequence = sequence[:, :, : -self.padding]
if self.batch_norm:
sequence = self.bn(sequence)
sequence = self.dropout(sequence)
if self.relu:
sequence = self.relu_func(sequence)
sequence = sequence.transpose(1, 2)
return sequence, mask
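# --- Illustrative sketch (not part of the original file): the causal variant
# pads on both sides, then trims `padding` frames on the right so each output
# step only sees current and past inputs and the label length U is preserved.
def _demo_causal_conv1d_shape() -> torch.Size:
    conv = CausalConv1d(idim=4, odim=4, kernel_size=3)
    labels = torch.randn(2, 7, 4)  # (B, U, D_in)
    out, _ = conv(labels, mask=None)
    return out.shape  # torch.Size([2, 7, 4]): label length unchanged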
"""Custom decoder definition for Transducer model."""
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from espnet.nets.pytorch_backend.transducer.blocks import build_blocks
from espnet.nets.pytorch_backend.transducer.utils import (
check_batch_states,
check_state,
pad_sequence,
)
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.transducer_decoder_interface import (
ExtendedHypothesis,
Hypothesis,
TransducerDecoderInterface,
)
class CustomDecoder(TransducerDecoderInterface, torch.nn.Module):
"""Custom decoder module for Transducer model.
Args:
odim: Output dimension.
dec_arch: Decoder block architecture (type and parameters).
input_layer: Input layer type.
repeat_block: Number of times dec_arch is repeated.
joint_activation_type: Type of activation for joint network.
positional_encoding_type: Positional encoding type.
positionwise_layer_type: Positionwise layer type.
positionwise_activation_type: Positionwise activation type.
input_layer_dropout_rate: Dropout rate for input layer.
blank_id: Blank symbol ID.
"""
def __init__(
self,
odim: int,
dec_arch: List,
input_layer: str = "embed",
repeat_block: int = 0,
joint_activation_type: str = "tanh",
positional_encoding_type: str = "abs_pos",
positionwise_layer_type: str = "linear",
positionwise_activation_type: str = "relu",
input_layer_dropout_rate: float = 0.0,
blank_id: int = 0,
):
"""Construct a CustomDecoder object."""
torch.nn.Module.__init__(self)
self.embed, self.decoders, ddim, _ = build_blocks(
"decoder",
odim,
input_layer,
dec_arch,
repeat_block=repeat_block,
positional_encoding_type=positional_encoding_type,
positionwise_layer_type=positionwise_layer_type,
positionwise_activation_type=positionwise_activation_type,
input_layer_dropout_rate=input_layer_dropout_rate,
padding_idx=blank_id,
)
self.after_norm = LayerNorm(ddim)
self.dlayers = len(self.decoders)
self.dunits = ddim
self.odim = odim
self.blank_id = blank_id
def set_device(self, device: torch.device):
"""Set GPU device to use.
Args:
device: Device ID.
"""
self.device = device
def init_state(
self,
batch_size: Optional[int] = None,
) -> List[Optional[torch.Tensor]]:
"""Initialize decoder states.
Args:
batch_size: Batch size.
Returns:
state: Initial decoder hidden states. [N x None]
"""
state = [None] * self.dlayers
return state
def forward(
self, dec_input: torch.Tensor, dec_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Encode label ID sequences.
Args:
dec_input: Label ID sequences. (B, U)
dec_mask: Label mask sequences. (B, U)
        Returns:
dec_output: Decoder output sequences. (B, U, D_dec)
dec_output_mask: Mask of decoder output sequences. (B, U)
"""
dec_input = self.embed(dec_input)
dec_output, dec_mask = self.decoders(dec_input, dec_mask)
dec_output = self.after_norm(dec_output)
return dec_output, dec_mask
def score(
self, hyp: Hypothesis, cache: Dict[str, Any]
) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
"""One-step forward hypothesis.
Args:
hyp: Hypothesis.
cache: Pairs of (dec_out, dec_state) for each label sequence. (key)
Returns:
dec_out: Decoder output sequence. (1, D_dec)
dec_state: Decoder hidden states. [N x (1, U, D_dec)]
lm_label: Label ID for LM. (1,)
"""
labels = torch.tensor([hyp.yseq], device=self.device)
lm_label = labels[:, -1]
str_labels = "_".join(list(map(str, hyp.yseq)))
if str_labels in cache:
dec_out, dec_state = cache[str_labels]
else:
dec_out_mask = subsequent_mask(len(hyp.yseq)).unsqueeze_(0)
new_state = check_state(hyp.dec_state, (labels.size(1) - 1), self.blank_id)
dec_out = self.embed(labels)
dec_state = []
for s, decoder in zip(new_state, self.decoders):
dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
dec_state.append(dec_out)
dec_out = self.after_norm(dec_out[:, -1])
cache[str_labels] = (dec_out, dec_state)
return dec_out[0], dec_state, lm_label
def batch_score(
self,
hyps: Union[List[Hypothesis], List[ExtendedHypothesis]],
dec_states: List[Optional[torch.Tensor]],
cache: Dict[str, Any],
use_lm: bool,
) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
"""One-step forward hypotheses.
Args:
hyps: Hypotheses.
dec_states: Decoder hidden states. [N x (B, U, D_dec)]
cache: Pairs of (h_dec, dec_states) for each label sequences. (keys)
use_lm: Whether to compute label ID sequences for LM.
Returns:
dec_out: Decoder output sequences. (B, D_dec)
dec_states: Decoder hidden states. [N x (B, U, D_dec)]
lm_labels: Label ID sequences for LM. (B,)
"""
final_batch = len(hyps)
process = []
done = [None] * final_batch
for i, hyp in enumerate(hyps):
str_labels = "_".join(list(map(str, hyp.yseq)))
if str_labels in cache:
done[i] = cache[str_labels]
else:
process.append((str_labels, hyp.yseq, hyp.dec_state))
if process:
labels = pad_sequence([p[1] for p in process], self.blank_id)
            labels = torch.tensor(labels, dtype=torch.long, device=self.device)
p_dec_states = self.create_batch_states(
self.init_state(),
[p[2] for p in process],
labels,
)
dec_out = self.embed(labels)
dec_out_mask = (
subsequent_mask(labels.size(-1))
.unsqueeze_(0)
.expand(len(process), -1, -1)
)
new_states = []
for s, decoder in zip(p_dec_states, self.decoders):
dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
new_states.append(dec_out)
dec_out = self.after_norm(dec_out[:, -1])
j = 0
for i in range(final_batch):
if done[i] is None:
state = self.select_state(new_states, j)
done[i] = (dec_out[j], state)
cache[process[j][0]] = (dec_out[j], state)
j += 1
dec_out = torch.stack([d[0] for d in done])
dec_states = self.create_batch_states(
dec_states, [d[1] for d in done], [[0] + h.yseq for h in hyps]
)
if use_lm:
            lm_labels = torch.tensor(
                [hyp.yseq[-1] for hyp in hyps], dtype=torch.long, device=self.device
            )
return dec_out, dec_states, lm_labels
return dec_out, dec_states, None
def select_state(
self, states: List[Optional[torch.Tensor]], idx: int
) -> List[Optional[torch.Tensor]]:
"""Get specified ID state from decoder hidden states.
Args:
states: Decoder hidden states. [N x (B, U, D_dec)]
idx: State ID to extract.
Returns:
state_idx: Decoder hidden state for given ID. [N x (1, U, D_dec)]
"""
if states[0] is None:
return states
state_idx = [states[layer][idx] for layer in range(self.dlayers)]
return state_idx
def create_batch_states(
self,
states: List[Optional[torch.Tensor]],
new_states: List[Optional[torch.Tensor]],
check_list: List[List[int]],
) -> List[Optional[torch.Tensor]]:
"""Create decoder hidden states sequences.
Args:
states: Decoder hidden states. [N x (B, U, D_dec)]
new_states: Decoder hidden states. [B x [N x (1, U, D_dec)]]
check_list: Label ID sequences.
Returns:
states: New decoder hidden states. [N x (B, U, D_dec)]
"""
if new_states[0][0] is None:
return states
max_len = max(len(elem) for elem in check_list) - 1
for layer in range(self.dlayers):
states[layer] = check_batch_states(
[s[layer] for s in new_states], max_len, self.blank_id
)
return states
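# --- Note (illustrative, not part of the original file): score() and
# batch_score() cache decoder outputs under a label-ID string key, e.g.
# yseq [0, 5, 12] -> "0_5_12", so hypotheses sharing a full label prefix are
# decoded only once per beam search pass.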
"""Cutom encoder definition for transducer models."""
from typing import List, Tuple, Union
import torch
from espnet.nets.pytorch_backend.transducer.blocks import build_blocks
from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
class CustomEncoder(torch.nn.Module):
"""Custom encoder module for transducer models.
Args:
idim: Input dimension.
enc_arch: Encoder block architecture (type and parameters).
input_layer: Input layer type.
        repeat_block: Number of times enc_arch is repeated.
self_attn_type: Self-attention type.
positional_encoding_type: Positional encoding type.
positionwise_layer_type: Positionwise layer type.
positionwise_activation_type: Positionwise activation type.
conv_mod_activation_type: Convolutional module activation type.
aux_enc_output_layers: Layer IDs for auxiliary encoder output sequences.
input_layer_dropout_rate: Dropout rate for input layer.
input_layer_pos_enc_dropout_rate: Dropout rate for input layer pos. enc.
padding_idx: Padding symbol ID for embedding layer.
"""
def __init__(
self,
idim: int,
enc_arch: List,
input_layer: str = "linear",
repeat_block: int = 1,
self_attn_type: str = "selfattn",
positional_encoding_type: str = "abs_pos",
positionwise_layer_type: str = "linear",
positionwise_activation_type: str = "relu",
conv_mod_activation_type: str = "relu",
aux_enc_output_layers: List = [],
input_layer_dropout_rate: float = 0.0,
input_layer_pos_enc_dropout_rate: float = 0.0,
padding_idx: int = -1,
):
"""Construct an CustomEncoder object."""
super().__init__()
(
self.embed,
self.encoders,
self.enc_out,
self.conv_subsampling_factor,
) = build_blocks(
"encoder",
idim,
input_layer,
enc_arch,
repeat_block=repeat_block,
self_attn_type=self_attn_type,
positional_encoding_type=positional_encoding_type,
positionwise_layer_type=positionwise_layer_type,
positionwise_activation_type=positionwise_activation_type,
conv_mod_activation_type=conv_mod_activation_type,
input_layer_dropout_rate=input_layer_dropout_rate,
input_layer_pos_enc_dropout_rate=input_layer_pos_enc_dropout_rate,
padding_idx=padding_idx,
)
self.after_norm = LayerNorm(self.enc_out)
self.n_blocks = len(enc_arch) * repeat_block
self.aux_enc_output_layers = aux_enc_output_layers
def forward(
self,
feats: torch.Tensor,
mask: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]:
"""Encode feature sequences.
Args:
feats: Feature sequences. (B, F, D_feats)
            mask: Feature mask sequences. (B, 1, F)
Returns:
enc_out: Encoder output sequences. (B, T, D_enc) with/without
Auxiliary encoder output sequences. (B, T, D_enc_aux)
enc_out_mask: Mask for encoder output sequences. (B, 1, T) with/without
                Mask for auxiliary encoder output sequences. (B, 1, T_aux)
"""
if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
enc_out, mask = self.embed(feats, mask)
else:
enc_out = self.embed(feats)
if self.aux_enc_output_layers:
aux_custom_outputs = []
aux_custom_lens = []
for b in range(self.n_blocks):
enc_out, mask = self.encoders[b](enc_out, mask)
if b in self.aux_enc_output_layers:
if isinstance(enc_out, tuple):
aux_custom_output = enc_out[0]
else:
aux_custom_output = enc_out
aux_custom_outputs.append(self.after_norm(aux_custom_output))
aux_custom_lens.append(mask)
else:
enc_out, mask = self.encoders(enc_out, mask)
if isinstance(enc_out, tuple):
enc_out = enc_out[0]
enc_out = self.after_norm(enc_out)
if self.aux_enc_output_layers:
return (enc_out, aux_custom_outputs), (mask, aux_custom_lens)
return enc_out, mask
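# --- Note (illustrative, not part of the original file): with, e.g.,
# aux_enc_output_layers=[2, 5], the block-wise loop above also returns the
# normalized outputs of blocks 2 and 5, which feed the auxiliary Transducer
# losses in TransducerTasks.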
"""CER/WER computation for Transducer model."""
from typing import List, Tuple, Union
import torch
from espnet.nets.beam_search_transducer import BeamSearchTransducer
from espnet.nets.pytorch_backend.transducer.custom_decoder import CustomDecoder
from espnet.nets.pytorch_backend.transducer.joint_network import JointNetwork
from espnet.nets.pytorch_backend.transducer.rnn_decoder import RNNDecoder
class ErrorCalculator(object):
"""CER and WER computation for Transducer model.
Args:
decoder: Decoder module.
joint_network: Joint network module.
token_list: Set of unique labels.
sym_space: Space symbol.
sym_blank: Blank symbol.
report_cer: Whether to compute CER.
report_wer: Whether to compute WER.
"""
def __init__(
self,
decoder: Union[RNNDecoder, CustomDecoder],
joint_network: JointNetwork,
        token_list: List[str],
sym_space: str,
sym_blank: str,
report_cer: bool = False,
report_wer: bool = False,
):
"""Construct an ErrorCalculator object for Transducer model."""
super().__init__()
self.beam_search = BeamSearchTransducer(
decoder=decoder,
joint_network=joint_network,
beam_size=2,
search_type="default",
)
self.decoder = decoder
self.token_list = token_list
self.space = sym_space
self.blank = sym_blank
self.report_cer = report_cer
self.report_wer = report_wer
def __call__(
self, enc_out: torch.Tensor, target: torch.Tensor
) -> Tuple[float, float]:
"""Calculate sentence-level CER/WER score for hypotheses sequences.
Args:
enc_out: Encoder output sequences. (B, T, D_enc)
target: Target label ID sequences. (B, L)
Returns:
cer: Sentence-level CER score.
wer: Sentence-level WER score.
"""
cer, wer = None, None
batchsize = int(enc_out.size(0))
batch_nbest = []
enc_out = enc_out.to(next(self.decoder.parameters()).device)
for b in range(batchsize):
nbest_hyps = self.beam_search(enc_out[b])
batch_nbest.append(nbest_hyps[-1])
batch_nbest = [nbest_hyp.yseq[1:] for nbest_hyp in batch_nbest]
hyps, refs = self.convert_to_char(batch_nbest, target.cpu())
if self.report_cer:
cer = self.calculate_cer(hyps, refs)
if self.report_wer:
wer = self.calculate_wer(hyps, refs)
return cer, wer
    def convert_to_char(
        self, hyps: List, refs: torch.Tensor
    ) -> Tuple[List, List]:
"""Convert label ID sequences to character.
Args:
hyps: Hypotheses sequences. (B, L)
refs: References sequences. (B, L)
Returns:
char_hyps: Character list of hypotheses.
            char_refs: Character list of references.
"""
char_hyps, char_refs = [], []
for i, hyp in enumerate(hyps):
hyp_i = [self.token_list[int(h)] for h in hyp]
ref_i = [self.token_list[int(r)] for r in refs[i]]
char_hyp = "".join(hyp_i).replace(self.space, " ")
char_hyp = char_hyp.replace(self.blank, "")
char_ref = "".join(ref_i).replace(self.space, " ")
char_hyps.append(char_hyp)
char_refs.append(char_ref)
return char_hyps, char_refs
    def calculate_cer(self, hyps: List[str], refs: List[str]) -> float:
"""Calculate sentence-level CER score.
Args:
hyps: Hypotheses sequences. (B, L)
refs: References sequences. (B, L)
Returns:
: Average sentence-level CER score.
"""
import editdistance
distances, lens = [], []
for i, hyp in enumerate(hyps):
char_hyp = hyp.replace(" ", "")
char_ref = refs[i].replace(" ", "")
distances.append(editdistance.eval(char_hyp, char_ref))
lens.append(len(char_ref))
return float(sum(distances)) / sum(lens)
    def calculate_wer(self, hyps: List[str], refs: List[str]) -> float:
"""Calculate sentence-level WER score.
Args:
hyps: Hypotheses sequences. (B, L)
refs: References sequences. (B, L)
Returns:
: Average sentence-level WER score.
"""
import editdistance
distances, lens = [], []
for i, hyp in enumerate(hyps):
word_hyp = hyp.split()
word_ref = refs[i].split()
distances.append(editdistance.eval(word_hyp, word_ref))
lens.append(len(word_ref))
return float(sum(distances)) / sum(lens)
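# --- Illustrative sketch (not part of the original file): the CER computation
# above on toy strings, with spaces removed before taking the edit distance.
def _demo_cer() -> float:
    import editdistance
    hyps, refs = ["a b c", "a c"], ["a b c", "a b c"]
    dists = [
        editdistance.eval(h.replace(" ", ""), r.replace(" ", ""))
        for h, r in zip(hyps, refs)
    ]
    # (0 + 1) / (3 + 3) ~= 0.167
    return float(sum(dists)) / sum(len(r.replace(" ", "")) for r in refs)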
"""Parameter initialization for Transducer model."""
import math
from argparse import Namespace
import torch
from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one
def initializer(model: torch.nn.Module, args: Namespace):
"""Initialize Transducer model.
Args:
model: Transducer model.
args: Namespace containing model options.
"""
for name, p in model.named_parameters():
if any(x in name for x in ["enc.", "dec.", "transducer_tasks."]):
if p.dim() == 1:
# bias
p.data.zero_()
elif p.dim() == 2:
# linear weight
n = p.size(1)
stdv = 1.0 / math.sqrt(n)
p.data.normal_(0, stdv)
elif p.dim() in (3, 4):
# conv weight
n = p.size(1)
for k in p.size()[2:]:
n *= k
stdv = 1.0 / math.sqrt(n)
p.data.normal_(0, stdv)
if args.dtype != "custom":
model.dec.embed.weight.data.normal_(0, 1)
for i in range(model.dec.dlayers):
set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_ih_l0"))
set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_hh_l0"))
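# --- Worked example (illustrative): for a linear weight of shape (512, 256),
# n = 256, so stdv = 1 / sqrt(256) = 0.0625 and the weight is drawn from a
# zero-mean normal with standard deviation 0.0625. For a conv weight of shape
# (64, 32, 3), n = 32 * 3 = 96.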
"""Transducer joint network implementation."""
import torch
from espnet.nets.pytorch_backend.nets_utils import get_activation
class JointNetwork(torch.nn.Module):
"""Transducer joint network module.
Args:
joint_output_size: Joint network output dimension
encoder_output_size: Encoder output dimension.
decoder_output_size: Decoder output dimension.
joint_space_size: Dimension of joint space.
joint_activation_type: Type of activation for joint network.
"""
def __init__(
self,
joint_output_size: int,
encoder_output_size: int,
decoder_output_size: int,
joint_space_size: int,
        joint_activation_type: str,
):
"""Joint network initializer."""
super().__init__()
self.lin_enc = torch.nn.Linear(encoder_output_size, joint_space_size)
self.lin_dec = torch.nn.Linear(
decoder_output_size, joint_space_size, bias=False
)
self.lin_out = torch.nn.Linear(joint_space_size, joint_output_size)
self.joint_activation = get_activation(joint_activation_type)
def forward(
self,
enc_out: torch.Tensor,
dec_out: torch.Tensor,
is_aux: bool = False,
quantization: bool = False,
) -> torch.Tensor:
"""Joint computation of encoder and decoder hidden state sequences.
Args:
enc_out: Expanded encoder output state sequences (B, T, 1, D_enc)
dec_out: Expanded decoder output state sequences (B, 1, U, D_dec)
            is_aux: Whether auxiliary tasks are used.
quantization: Whether dynamic quantization is used.
Returns:
joint_out: Joint output state sequences. (B, T, U, D_out)
"""
if is_aux:
joint_out = self.joint_activation(enc_out + self.lin_dec(dec_out))
elif quantization:
joint_out = self.joint_activation(
self.lin_enc(enc_out.unsqueeze(0)) + self.lin_dec(dec_out.unsqueeze(0))
)
return self.lin_out(joint_out)[0]
else:
joint_out = self.joint_activation(
self.lin_enc(enc_out) + self.lin_dec(dec_out)
)
joint_out = self.lin_out(joint_out)
return joint_out
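# --- Illustrative sketch (not part of the original file): the joint network
# broadcasts the expanded encoder and decoder states into a (B, T, U, V) grid.
def _demo_joint_shapes() -> torch.Size:
    net = JointNetwork(
        joint_output_size=30,
        encoder_output_size=8,
        decoder_output_size=8,
        joint_space_size=16,
        joint_activation_type="tanh",
    )
    enc_out = torch.randn(2, 5, 1, 8)  # (B, T, 1, D_enc)
    dec_out = torch.randn(2, 1, 3, 8)  # (B, 1, U, D_dec)
    return net(enc_out, dec_out).shape  # torch.Size([2, 5, 3, 30])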
"""RNN decoder definition for Transducer model."""
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from espnet.nets.transducer_decoder_interface import (
ExtendedHypothesis,
Hypothesis,
TransducerDecoderInterface,
)
class RNNDecoder(TransducerDecoderInterface, torch.nn.Module):
"""RNN decoder module for Transducer model.
Args:
odim: Output dimension.
dtype: Decoder units type.
dlayers: Number of decoder layers.
        dunits: Number of decoder units per layer.
embed_dim: Embedding layer dimension.
dropout_rate: Dropout rate for decoder layers.
dropout_rate_embed: Dropout rate for embedding layer.
blank_id: Blank symbol ID.
"""
def __init__(
self,
odim: int,
dtype: str,
dlayers: int,
dunits: int,
embed_dim: int,
dropout_rate: float = 0.0,
dropout_rate_embed: float = 0.0,
blank_id: int = 0,
):
"""Transducer initializer."""
super().__init__()
self.embed = torch.nn.Embedding(odim, embed_dim, padding_idx=blank_id)
self.dropout_embed = torch.nn.Dropout(p=dropout_rate_embed)
dec_net = torch.nn.LSTM if dtype == "lstm" else torch.nn.GRU
self.decoder = torch.nn.ModuleList(
[dec_net(embed_dim, dunits, 1, batch_first=True)]
)
self.dropout_dec = torch.nn.Dropout(p=dropout_rate)
for _ in range(1, dlayers):
self.decoder += [dec_net(dunits, dunits, 1, batch_first=True)]
self.dlayers = dlayers
self.dunits = dunits
self.dtype = dtype
self.odim = odim
self.ignore_id = -1
self.blank_id = blank_id
self.multi_gpus = torch.cuda.device_count() > 1
def set_device(self, device: torch.device):
"""Set GPU device to use.
Args:
device: Device ID.
"""
self.device = device
def init_state(
self, batch_size: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Initialize decoder states.
Args:
batch_size: Batch size.
Returns:
: Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
"""
h_n = torch.zeros(
self.dlayers,
batch_size,
self.dunits,
device=self.device,
)
if self.dtype == "lstm":
c_n = torch.zeros(
self.dlayers,
batch_size,
self.dunits,
device=self.device,
)
return (h_n, c_n)
return (h_n, None)
def rnn_forward(
self,
sequence: torch.Tensor,
state: Tuple[torch.Tensor, Optional[torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""Encode source label sequences.
Args:
sequence: RNN input sequences. (B, D_emb)
state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
Returns:
sequence: RNN output sequences. (B, D_dec)
            (h_next, c_next): Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
"""
h_prev, c_prev = state
h_next, c_next = self.init_state(sequence.size(0))
for layer in range(self.dlayers):
if self.dtype == "lstm":
(
sequence,
(
h_next[layer : layer + 1],
c_next[layer : layer + 1],
),
) = self.decoder[layer](
sequence, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1])
)
else:
sequence, h_next[layer : layer + 1] = self.decoder[layer](
sequence, hx=h_prev[layer : layer + 1]
)
sequence = self.dropout_dec(sequence)
return sequence, (h_next, c_next)
def forward(self, labels: torch.Tensor) -> torch.Tensor:
"""Encode source label sequences.
Args:
labels: Label ID sequences. (B, L)
Returns:
            dec_out: Decoder output sequences. (B, U, D_dec)
"""
init_state = self.init_state(labels.size(0))
dec_embed = self.dropout_embed(self.embed(labels))
dec_out, _ = self.rnn_forward(dec_embed, init_state)
return dec_out
def score(
self, hyp: Hypothesis, cache: Dict[str, Any]
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
"""One-step forward hypothesis.
Args:
hyp: Hypothesis.
cache: Pairs of (dec_out, state) for each label sequence. (key)
Returns:
dec_out: Decoder output sequence. (1, D_dec)
new_state: Decoder hidden states. ((N, 1, D_dec), (N, 1, D_dec))
label: Label ID for LM. (1,)
"""
label = torch.full((1, 1), hyp.yseq[-1], dtype=torch.long, device=self.device)
str_labels = "_".join(list(map(str, hyp.yseq)))
if str_labels in cache:
dec_out, dec_state = cache[str_labels]
else:
dec_emb = self.embed(label)
dec_out, dec_state = self.rnn_forward(dec_emb, hyp.dec_state)
cache[str_labels] = (dec_out, dec_state)
return dec_out[0][0], dec_state, label[0]
def batch_score(
self,
hyps: Union[List[Hypothesis], List[ExtendedHypothesis]],
dec_states: Tuple[torch.Tensor, Optional[torch.Tensor]],
cache: Dict[str, Any],
use_lm: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
"""One-step forward hypotheses.
Args:
hyps: Hypotheses.
            dec_states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
cache: Pairs of (dec_out, dec_states) for each label sequences. (keys)
use_lm: Whether to compute label ID sequences for LM.
Returns:
dec_out: Decoder output sequences. (B, D_dec)
dec_states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
lm_labels: Label ID sequences for LM. (B,)
"""
final_batch = len(hyps)
process = []
done = [None] * final_batch
for i, hyp in enumerate(hyps):
str_labels = "_".join(list(map(str, hyp.yseq)))
if str_labels in cache:
done[i] = cache[str_labels]
else:
process.append((str_labels, hyp.yseq[-1], hyp.dec_state))
if process:
            labels = torch.tensor(
                [[p[1]] for p in process], dtype=torch.long, device=self.device
            )
p_dec_states = self.create_batch_states(
self.init_state(labels.size(0)), [p[2] for p in process]
)
dec_emb = self.embed(labels)
dec_out, new_states = self.rnn_forward(dec_emb, p_dec_states)
j = 0
for i in range(final_batch):
if done[i] is None:
state = self.select_state(new_states, j)
done[i] = (dec_out[j], state)
cache[process[j][0]] = (dec_out[j], state)
j += 1
dec_out = torch.cat([d[0] for d in done], dim=0)
dec_states = self.create_batch_states(dec_states, [d[1] for d in done])
if use_lm:
            lm_labels = torch.tensor(
                [h.yseq[-1] for h in hyps], dtype=torch.long, device=self.device
            )
return dec_out, dec_states, lm_labels
return dec_out, dec_states, None
def select_state(
self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Get specified ID state from decoder hidden states.
Args:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
idx: State ID to extract.
Returns:
: Decoder hidden state for given ID.
((N, 1, D_dec), (N, 1, D_dec))
"""
return (
states[0][:, idx : idx + 1, :],
states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
)
def create_batch_states(
self,
states: Tuple[torch.Tensor, Optional[torch.Tensor]],
new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
check_list: Optional[List] = None,
) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""Create decoder hidden states.
Args:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec))]
Returns:
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
"""
return (
torch.cat([s[0] for s in new_states], dim=1),
torch.cat([s[1] for s in new_states], dim=1)
if self.dtype == "lstm"
else None,
)
"""RNN encoder implementation for Transducer model.
These classes are based on the ones in espnet.nets.pytorch_backend.rnn.encoders,
and modified to output intermediate representation based given list of layers as input.
To do so, RNN class rely on a stack of 1-layer LSTM instead of a multi-layer LSTM.
The additional outputs are intended to be used with Transducer auxiliary tasks.
"""
from argparse import Namespace
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from espnet.nets.e2e_asr_common import get_vgg2l_odim
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask, to_device
class RNNP(torch.nn.Module):
"""RNN with projection layer module.
Args:
idim: Input dimension.
rnn_type: RNNP units type.
elayers: Number of RNNP layers.
eunits: Number of units ((2 * eunits) if bidirectional).
eprojs: Number of projection units.
subsample: Subsampling rate per layer.
dropout_rate: Dropout rate for RNNP layers.
aux_output_layers: Layer IDs for auxiliary RNNP output sequences.
"""
def __init__(
self,
idim: int,
rnn_type: str,
elayers: int,
eunits: int,
eprojs: int,
subsample: np.ndarray,
dropout_rate: float,
aux_output_layers: List = [],
):
"""Initialize RNNP module."""
super().__init__()
bidir = rnn_type[0] == "b"
for i in range(elayers):
if i == 0:
input_dim = idim
else:
input_dim = eprojs
rnn_layer = torch.nn.LSTM if "lstm" in rnn_type else torch.nn.GRU
rnn = rnn_layer(
input_dim, eunits, num_layers=1, bidirectional=bidir, batch_first=True
)
setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
if bidir:
setattr(self, "bt%d" % i, torch.nn.Linear(2 * eunits, eprojs))
else:
setattr(self, "bt%d" % i, torch.nn.Linear(eunits, eprojs))
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.elayers = elayers
self.eunits = eunits
self.subsample = subsample
self.rnn_type = rnn_type
self.bidir = bidir
self.aux_output_layers = aux_output_layers
def forward(
self,
rnn_input: torch.Tensor,
rnn_len: torch.Tensor,
prev_states: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
"""RNNP forward.
Args:
rnn_input: RNN input sequences. (B, T, D_in)
rnn_len: RNN input sequences lengths. (B,)
prev_states: RNN hidden states. [N x (B, T, D_proj)]
Returns:
            rnn_output: RNN output sequences. (B, T, D_proj)
with or without intermediate RNN output sequences.
((B, T, D_proj), [N x (B, T, D_proj)])
rnn_len: RNN output sequences lengths. (B,)
current_states: RNN hidden states. [N x (B, T, D_proj)]
"""
aux_rnn_outputs = []
aux_rnn_lens = []
current_states = []
for layer in range(self.elayers):
if not isinstance(rnn_len, torch.Tensor):
rnn_len = torch.tensor(rnn_len)
pack_rnn_input = pack_padded_sequence(
rnn_input, rnn_len.cpu(), batch_first=True
)
rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
if isinstance(rnn, (torch.nn.LSTM, torch.nn.GRU)):
rnn.flatten_parameters()
if prev_states is not None and rnn.bidirectional:
prev_states = reset_backward_rnn_state(prev_states)
pack_rnn_output, states = rnn(
pack_rnn_input, hx=None if prev_states is None else prev_states[layer]
)
current_states.append(states)
pad_rnn_output, rnn_len = pad_packed_sequence(
pack_rnn_output, batch_first=True
)
sub = self.subsample[layer + 1]
if sub > 1:
pad_rnn_output = pad_rnn_output[:, ::sub]
rnn_len = torch.tensor([int(i + 1) // sub for i in rnn_len])
projection_layer = getattr(self, "bt%d" % layer)
proj_rnn_output = projection_layer(
pad_rnn_output.contiguous().view(-1, pad_rnn_output.size(2))
)
rnn_output = proj_rnn_output.view(
pad_rnn_output.size(0), pad_rnn_output.size(1), -1
)
if layer in self.aux_output_layers:
aux_rnn_outputs.append(rnn_output)
aux_rnn_lens.append(rnn_len)
if layer < self.elayers - 1:
rnn_output = torch.tanh(self.dropout(rnn_output))
rnn_input = rnn_output
if aux_rnn_outputs:
return (
(rnn_output, aux_rnn_outputs),
(rnn_len, aux_rnn_lens),
current_states,
)
else:
return rnn_output, rnn_len, current_states
class RNN(torch.nn.Module):
"""RNN module.
Args:
idim: Input dimension.
rnn_type: RNN units type.
elayers: Number of RNN layers.
eunits: Number of units ((2 * eunits) if bidirectional)
eprojs: Number of final projection units.
dropout_rate: Dropout rate for RNN layers.
aux_output_layers: List of layer IDs for auxiliary RNN output sequences.
"""
def __init__(
self,
idim: int,
rnn_type: str,
elayers: int,
eunits: int,
eprojs: int,
dropout_rate: float,
aux_output_layers: List = [],
):
"""Initialize RNN module."""
super().__init__()
bidir = rnn_type[0] == "b"
for i in range(elayers):
if i == 0:
input_dim = idim
else:
input_dim = eunits
rnn_layer = torch.nn.LSTM if "lstm" in rnn_type else torch.nn.GRU
rnn = rnn_layer(
input_dim, eunits, num_layers=1, bidirectional=bidir, batch_first=True
)
setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.elayers = elayers
self.eunits = eunits
self.eprojs = eprojs
self.rnn_type = rnn_type
self.bidir = bidir
self.l_last = torch.nn.Linear(eunits, eprojs)
self.aux_output_layers = aux_output_layers
def forward(
self,
rnn_input: torch.Tensor,
rnn_len: torch.Tensor,
prev_states: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
"""RNN forward.
Args:
rnn_input: RNN input sequences. (B, T, D_in)
rnn_len: RNN input sequences lengths. (B,)
prev_states: RNN hidden states. [N x (B, T, D_proj)]
Returns:
            rnn_output: RNN output sequences. (B, T, D_proj)
with or without intermediate RNN output sequences.
((B, T, D_proj), [N x (B, T, D_proj)])
rnn_len: RNN output sequences lengths. (B,)
current_states: RNN hidden states. [N x (B, T, D_proj)]
"""
aux_rnn_outputs = []
aux_rnn_lens = []
current_states = []
for layer in range(self.elayers):
if not isinstance(rnn_len, torch.Tensor):
rnn_len = torch.tensor(rnn_len)
pack_rnn_input = pack_padded_sequence(
rnn_input, rnn_len.cpu(), batch_first=True
)
rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
if isinstance(rnn, (torch.nn.LSTM, torch.nn.GRU)):
rnn.flatten_parameters()
if prev_states is not None and rnn.bidirectional:
prev_states = reset_backward_rnn_state(prev_states)
pack_rnn_output, states = rnn(
pack_rnn_input, hx=None if prev_states is None else prev_states[layer]
)
current_states.append(states)
rnn_output, rnn_len = pad_packed_sequence(pack_rnn_output, batch_first=True)
if self.bidir:
rnn_output = (
rnn_output[:, :, : self.eunits] + rnn_output[:, :, self.eunits :]
)
if layer in self.aux_output_layers:
aux_proj_rnn_output = torch.tanh(
self.l_last(rnn_output.contiguous().view(-1, rnn_output.size(2)))
)
aux_rnn_output = aux_proj_rnn_output.view(
rnn_output.size(0), rnn_output.size(1), -1
)
aux_rnn_outputs.append(aux_rnn_output)
aux_rnn_lens.append(rnn_len)
if layer < self.elayers - 1:
rnn_input = self.dropout(rnn_output)
proj_rnn_output = torch.tanh(
self.l_last(rnn_output.contiguous().view(-1, rnn_output.size(2)))
)
rnn_output = proj_rnn_output.view(rnn_output.size(0), rnn_output.size(1), -1)
if aux_rnn_outputs:
return (
(rnn_output, aux_rnn_outputs),
(rnn_len, aux_rnn_lens),
current_states,
)
else:
return rnn_output, rnn_len, current_states
def reset_backward_rnn_state(
states: Union[torch.Tensor, List[Optional[torch.Tensor]]]
) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]:
"""Set backward BRNN states to zeroes.
Args:
states: Encoder hidden states.
Returns:
states: Encoder hidden states with backward set to zero.
"""
if isinstance(states, list):
for state in states:
state[1::2] = 0.0
else:
states[1::2] = 0.0
return states
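# --- Illustrative sketch (not part of the original file): for a bidirectional
# stack built from 1-layer RNNs, hidden states have shape (2, B, D); rows
# 1::2 hold the backward direction and are zeroed between streaming chunks.
def _demo_reset_backward() -> torch.Tensor:
    state = torch.ones(2, 3, 4)
    state[1::2] = 0.0  # same in-place update as reset_backward_rnn_state
    return state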
class VGG2L(torch.nn.Module):
"""VGG-like module.
Args:
in_channel: number of input channels
"""
def __init__(self, in_channel: int = 1):
"""Initialize VGG-like module."""
super(VGG2L, self).__init__()
# CNN layer (VGG motivated)
self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)
self.in_channel = in_channel
def forward(
self, feats: torch.Tensor, feats_len: torch.Tensor, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""VGG2L forward.
Args:
feats: Feature sequences. (B, F, D_feats)
feats_len: Feature sequences lengths. (B, )
Returns:
vgg_out: VGG2L output sequences. (B, F // 4, 128 * D_feats // 4)
vgg_out_len: VGG2L output sequences lengths. (B,)
"""
feats = feats.view(
feats.size(0),
feats.size(1),
self.in_channel,
feats.size(2) // self.in_channel,
).transpose(1, 2)
vgg1 = F.relu(self.conv1_1(feats))
vgg1 = F.relu(self.conv1_2(vgg1))
vgg1 = F.max_pool2d(vgg1, 2, stride=2, ceil_mode=True)
vgg2 = F.relu(self.conv2_1(vgg1))
vgg2 = F.relu(self.conv2_2(vgg2))
vgg2 = F.max_pool2d(vgg2, 2, stride=2, ceil_mode=True)
vgg_out = vgg2.transpose(1, 2)
vgg_out = vgg_out.contiguous().view(
vgg_out.size(0), vgg_out.size(1), vgg_out.size(2) * vgg_out.size(3)
)
if torch.is_tensor(feats_len):
feats_len = feats_len.cpu().numpy()
else:
feats_len = np.array(feats_len, dtype=np.float32)
vgg1_len = np.array(np.ceil(feats_len / 2), dtype=np.int64)
vgg_out_len = np.array(
np.ceil(np.array(vgg1_len, dtype=np.float32) / 2), dtype=np.int64
).tolist()
return vgg_out, vgg_out_len, None
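# --- Worked example (illustrative): the two ceil-mode max-pools halve the
# time axis twice, e.g. feats_len 83 -> ceil(83 / 2) = 42 -> ceil(42 / 2) = 21,
# matching the conv_subsampling_factor of 4 set by the Encoder below.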
class Encoder(torch.nn.Module):
"""Encoder module.
Args:
idim: Input dimension.
etype: Encoder units type.
elayers: Number of encoder layers.
eunits: Number of encoder units per layer.
eprojs: Number of projection units per layer.
subsample: Subsampling rate per layer.
dropout_rate: Dropout rate for encoder layers.
        aux_enc_output_layers: Layer IDs for auxiliary encoder output sequences.
"""
def __init__(
self,
idim: int,
etype: str,
elayers: int,
eunits: int,
eprojs: int,
subsample: np.ndarray,
dropout_rate: float = 0.0,
aux_enc_output_layers: List = [],
):
"""Initialize Encoder module."""
super(Encoder, self).__init__()
rnn_type = etype.lstrip("vgg").rstrip("p")
in_channel = 1
if etype.startswith("vgg"):
if etype[-1] == "p":
self.enc = torch.nn.ModuleList(
[
VGG2L(in_channel),
RNNP(
get_vgg2l_odim(idim, in_channel=in_channel),
rnn_type,
elayers,
eunits,
eprojs,
subsample,
dropout_rate=dropout_rate,
aux_output_layers=aux_enc_output_layers,
),
]
)
else:
self.enc = torch.nn.ModuleList(
[
VGG2L(in_channel),
RNN(
get_vgg2l_odim(idim, in_channel=in_channel),
rnn_type,
elayers,
eunits,
eprojs,
dropout_rate=dropout_rate,
aux_output_layers=aux_enc_output_layers,
),
]
)
self.conv_subsampling_factor = 4
else:
if etype[-1] == "p":
self.enc = torch.nn.ModuleList(
[
RNNP(
idim,
rnn_type,
elayers,
eunits,
eprojs,
subsample,
dropout_rate=dropout_rate,
aux_output_layers=aux_enc_output_layers,
)
]
)
else:
self.enc = torch.nn.ModuleList(
[
RNN(
idim,
rnn_type,
elayers,
eunits,
eprojs,
dropout_rate=dropout_rate,
aux_output_layers=aux_enc_output_layers,
)
]
)
self.conv_subsampling_factor = 1
def forward(
self,
feats: torch.Tensor,
feats_len: torch.Tensor,
prev_states: Optional[List[torch.Tensor]] = None,
):
"""Forward encoder.
Args:
feats: Feature sequences. (B, F, D_feats)
feats_len: Feature sequences lengths. (B,)
prev_states: Previous encoder hidden states. [N x (B, T, D_enc)]
Returns:
enc_out: Encoder output sequences. (B, T, D_enc)
with or without encoder intermediate output sequences.
((B, T, D_enc), [N x (B, T, D_enc)])
enc_out_len: Encoder output sequences lengths. (B,)
current_states: Encoder hidden states. [N x (B, T, D_enc)]
"""
if prev_states is None:
prev_states = [None] * len(self.enc)
assert len(prev_states) == len(self.enc)
_enc_out = feats
_enc_out_len = feats_len
current_states = []
for rnn_module, prev_state in zip(self.enc, prev_states):
_enc_out, _enc_out_len, states = rnn_module(
_enc_out,
_enc_out_len,
prev_states=prev_state,
)
current_states.append(states)
if isinstance(_enc_out, tuple):
enc_out, aux_enc_out = _enc_out[0], _enc_out[1]
enc_out_len, aux_enc_out_len = _enc_out_len[0], _enc_out_len[1]
enc_out_mask = to_device(enc_out, make_pad_mask(enc_out_len).unsqueeze(-1))
enc_out = enc_out.masked_fill(enc_out_mask, 0.0)
for i in range(len(aux_enc_out)):
aux_mask = to_device(
aux_enc_out[i], make_pad_mask(aux_enc_out_len[i]).unsqueeze(-1)
)
aux_enc_out[i] = aux_enc_out[i].masked_fill(aux_mask, 0.0)
return (
(enc_out, aux_enc_out),
(enc_out_len, aux_enc_out_len),
current_states,
)
else:
enc_out_mask = to_device(
_enc_out, make_pad_mask(_enc_out_len).unsqueeze(-1)
)
return _enc_out.masked_fill(enc_out_mask, 0.0), _enc_out_len, current_states
def encoder_for(
args: Namespace,
idim: int,
subsample: np.ndarray,
aux_enc_output_layers: List = [],
) -> torch.nn.Module:
"""Instantiate a RNN encoder with specified arguments.
Args:
args: The model arguments.
idim: Input dimension.
subsample: Subsampling rate per layer.
aux_enc_output_layers: Layer IDs for auxiliary encoder output sequences.
Returns:
: Encoder module.
"""
return Encoder(
idim,
args.etype,
args.elayers,
args.eunits,
args.eprojs,
subsample,
dropout_rate=args.dropout_rate,
aux_enc_output_layers=aux_enc_output_layers,
)
"""Module implementing Transducer main and auxiliary tasks."""
from typing import Any, List, Optional, Tuple
import torch
from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.pytorch_backend.transducer.joint_network import JointNetwork
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( # noqa: H301
LabelSmoothingLoss,
)
class TransducerTasks(torch.nn.Module):
"""Transducer tasks module."""
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joint_dim: int,
output_dim: int,
joint_activation_type: str = "tanh",
transducer_loss_weight: float = 1.0,
ctc_loss: bool = False,
ctc_loss_weight: float = 0.5,
ctc_loss_dropout_rate: float = 0.0,
lm_loss: bool = False,
lm_loss_weight: float = 0.5,
lm_loss_smoothing_rate: float = 0.0,
aux_transducer_loss: bool = False,
aux_transducer_loss_weight: float = 0.2,
aux_transducer_loss_mlp_dim: int = 320,
aux_trans_loss_mlp_dropout_rate: float = 0.0,
symm_kl_div_loss: bool = False,
symm_kl_div_loss_weight: float = 0.2,
fastemit_lambda: float = 0.0,
blank_id: int = 0,
ignore_id: int = -1,
training: bool = False,
):
"""Initialize module for Transducer tasks.
Args:
encoder_dim: Encoder outputs dimension.
decoder_dim: Decoder outputs dimension.
joint_dim: Joint space dimension.
output_dim: Output dimension.
joint_activation_type: Type of activation for joint network.
transducer_loss_weight: Weight for main transducer loss.
ctc_loss: Compute CTC loss.
ctc_loss_weight: Weight of CTC loss.
ctc_loss_dropout_rate: Dropout rate for CTC loss inputs.
lm_loss: Compute LM loss.
lm_loss_weight: Weight of LM loss.
lm_loss_smoothing_rate: Smoothing rate for LM loss' label smoothing.
aux_transducer_loss: Compute auxiliary transducer loss.
aux_transducer_loss_weight: Weight of auxiliary transducer loss.
aux_transducer_loss_mlp_dim: Hidden dimension for aux. transducer MLP.
aux_trans_loss_mlp_dropout_rate: Dropout rate for aux. transducer MLP.
symm_kl_div_loss: Compute KL divergence loss.
symm_kl_div_loss_weight: Weight of KL divergence loss.
fastemit_lambda: Regularization parameter for FastEmit.
blank_id: Blank symbol ID.
ignore_id: Padding symbol ID.
            training: Whether the model was initialized in training or inference mode.
"""
super().__init__()
if not training:
ctc_loss, lm_loss, aux_transducer_loss, symm_kl_div_loss = (
False,
False,
False,
False,
)
self.joint_network = JointNetwork(
output_dim, encoder_dim, decoder_dim, joint_dim, joint_activation_type
)
if training:
from warprnnt_pytorch import RNNTLoss
self.transducer_loss = RNNTLoss(
blank=blank_id,
reduction="sum",
fastemit_lambda=fastemit_lambda,
)
if ctc_loss:
self.ctc_lin = torch.nn.Linear(encoder_dim, output_dim)
self.ctc_loss = torch.nn.CTCLoss(
blank=blank_id,
reduction="none",
zero_infinity=True,
)
if aux_transducer_loss:
self.mlp = torch.nn.Sequential(
torch.nn.Linear(encoder_dim, aux_transducer_loss_mlp_dim),
torch.nn.LayerNorm(aux_transducer_loss_mlp_dim),
torch.nn.Dropout(p=aux_trans_loss_mlp_dropout_rate),
torch.nn.ReLU(),
torch.nn.Linear(aux_transducer_loss_mlp_dim, joint_dim),
)
if symm_kl_div_loss:
self.kl_div = torch.nn.KLDivLoss(reduction="sum")
if lm_loss:
self.lm_lin = torch.nn.Linear(decoder_dim, output_dim)
self.label_smoothing_loss = LabelSmoothingLoss(
output_dim, ignore_id, lm_loss_smoothing_rate, normalize_length=False
)
self.output_dim = output_dim
self.transducer_loss_weight = transducer_loss_weight
self.use_ctc_loss = ctc_loss
self.ctc_loss_weight = ctc_loss_weight
self.ctc_dropout_rate = ctc_loss_dropout_rate
self.use_lm_loss = lm_loss
self.lm_loss_weight = lm_loss_weight
self.use_aux_transducer_loss = aux_transducer_loss
self.aux_transducer_loss_weight = aux_transducer_loss_weight
self.use_symm_kl_div_loss = symm_kl_div_loss
self.symm_kl_div_loss_weight = symm_kl_div_loss_weight
self.blank_id = blank_id
self.ignore_id = ignore_id
self.target = None
def compute_transducer_loss(
self,
enc_out: torch.Tensor,
        dec_out: torch.Tensor,
target: torch.Tensor,
t_len: torch.Tensor,
u_len: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute Transducer loss.
Args:
enc_out: Encoder output sequences. (B, T, D_enc)
dec_out: Decoder output sequences. (B, U, D_dec)
target: Target label ID sequences. (B, L)
t_len: Time lengths. (B,)
u_len: Label lengths. (B,)
Returns:
(joint_out, loss_trans):
Joint output sequences. (B, T, U, D_joint),
Transducer loss value.
"""
joint_out = self.joint_network(enc_out.unsqueeze(2), dec_out.unsqueeze(1))
loss_trans = self.transducer_loss(joint_out, target, t_len, u_len)
loss_trans /= joint_out.size(0)
return joint_out, loss_trans
def compute_ctc_loss(
self,
enc_out: torch.Tensor,
target: torch.Tensor,
t_len: torch.Tensor,
u_len: torch.Tensor,
):
"""Compute CTC loss.
Args:
enc_out: Encoder output sequences. (B, T, D_enc)
target: Target character ID sequences. (B, U)
t_len: Time lengths. (B,)
u_len: Label lengths. (B,)
Returns:
: CTC loss value.
"""
ctc_lin = self.ctc_lin(
torch.nn.functional.dropout(
enc_out.to(dtype=torch.float32), p=self.ctc_dropout_rate
)
)
ctc_logp = torch.log_softmax(ctc_lin.transpose(0, 1), dim=-1)
with torch.backends.cudnn.flags(deterministic=True):
loss_ctc = self.ctc_loss(ctc_logp, target, t_len, u_len)
return loss_ctc.mean()
def compute_aux_transducer_and_symm_kl_div_losses(
self,
aux_enc_out: torch.Tensor,
dec_out: torch.Tensor,
joint_out: torch.Tensor,
target: torch.Tensor,
aux_t_len: torch.Tensor,
u_len: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute auxiliary Transducer loss and Jensen-Shannon divergence loss.
Args:
aux_enc_out: Encoder auxiliary output sequences. [N x (B, T_aux, D_enc_aux)]
dec_out: Decoder output sequences. (B, U, D_dec)
joint_out: Joint output sequences. (B, T, U, D_joint)
target: Target character ID sequences. (B, L)
aux_t_len: Auxiliary time lengths. [N x (B,)]
u_len: True U lengths. (B,)
Returns:
: Auxiliary Transducer loss and KL divergence loss values.
"""
aux_trans_loss = 0
symm_kl_div_loss = 0
num_aux_layers = len(aux_enc_out)
B, T, U, D = joint_out.shape
for p in self.joint_network.parameters():
p.requires_grad = False
for i, aux_enc_out_i in enumerate(aux_enc_out):
aux_mlp = self.mlp(aux_enc_out_i)
aux_joint_out = self.joint_network(
aux_mlp.unsqueeze(2),
dec_out.unsqueeze(1),
is_aux=True,
)
if self.use_aux_transducer_loss:
aux_trans_loss += (
self.transducer_loss(
aux_joint_out,
target,
aux_t_len[i],
u_len,
)
/ B
)
if self.use_symm_kl_div_loss:
denom = B * T * U
kl_main_aux = (
self.kl_div(
torch.log_softmax(joint_out, dim=-1),
torch.softmax(aux_joint_out, dim=-1),
)
/ denom
)
kl_aux_main = (
self.kl_div(
torch.log_softmax(aux_joint_out, dim=-1),
torch.softmax(joint_out, dim=-1),
)
/ denom
)
symm_kl_div_loss += kl_main_aux + kl_aux_main
for p in self.joint_network.parameters():
p.requires_grad = True
aux_trans_loss /= num_aux_layers
if self.use_symm_kl_div_loss:
symm_kl_div_loss /= num_aux_layers
return aux_trans_loss, symm_kl_div_loss
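    # --- Note (illustrative, not part of the original file): the accumulated
    # term is the symmetric KL divergence
    #     KL(p_main || p_aux) + KL(p_aux || p_main),
    # normalized by B * T * U, with p_* the softmax over the joint output
    # dimension. The joint network parameters are frozen during this pass, so
    # only the auxiliary MLP and the encoder receive these gradients.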
def compute_lm_loss(
self,
dec_out: torch.Tensor,
target: torch.Tensor,
) -> torch.Tensor:
"""Forward LM loss.
Args:
dec_out: Decoder output sequences. (B, U, D_dec)
target: Target label ID sequences. (B, U)
Returns:
: LM loss value.
"""
lm_lin = self.lm_lin(dec_out)
lm_loss = self.label_smoothing_loss(lm_lin, target)
return lm_loss
def set_target(self, target: torch.Tensor):
"""Set target label ID sequences.
Args:
target: Target label ID sequences. (B, L)
"""
self.target = target
    def get_target(self):
        """Get target label ID sequences.
        Returns:
            target: Target label ID sequences. (B, L)
        """
        return self.target
def get_transducer_tasks_io(
self,
labels: torch.Tensor,
enc_out_len: torch.Tensor,
aux_enc_out_len: Optional[List],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Get Transducer tasks inputs and outputs.
Args:
labels: Label ID sequences. (B, U)
enc_out_len: Time lengths. (B,)
aux_enc_out_len: Auxiliary time lengths. [N X (B,)]
Returns:
target: Target label ID sequences. (B, L)
lm_loss_target: LM loss target label ID sequences. (B, U)
t_len: Time lengths. (B,)
aux_t_len: Auxiliary time lengths. [N x (B,)]
u_len: Label lengths. (B,)
"""
device = labels.device
labels_unpad = [label[label != self.ignore_id] for label in labels]
blank = labels[0].new([self.blank_id])
target = pad_list(labels_unpad, self.blank_id).type(torch.int32).to(device)
lm_loss_target = (
pad_list(
[torch.cat([y, blank], dim=0) for y in labels_unpad], self.ignore_id
)
.type(torch.int64)
.to(device)
)
self.set_target(target)
if enc_out_len.dim() > 1:
enc_mask_unpad = [m[m != 0] for m in enc_out_len]
enc_out_len = list(map(int, [m.size(0) for m in enc_mask_unpad]))
else:
enc_out_len = list(map(int, enc_out_len))
t_len = torch.IntTensor(enc_out_len).to(device)
u_len = torch.IntTensor([label.size(0) for label in labels_unpad]).to(device)
if aux_enc_out_len:
aux_t_len = []
for i in range(len(aux_enc_out_len)):
if aux_enc_out_len[i].dim() > 1:
aux_mask_unpad = [aux[aux != 0] for aux in aux_enc_out_len[i]]
aux_t_len.append(
torch.IntTensor(
list(map(int, [aux.size(0) for aux in aux_mask_unpad]))
).to(device)
)
else:
aux_t_len.append(
torch.IntTensor(list(map(int, aux_enc_out_len[i]))).to(device)
)
else:
aux_t_len = aux_enc_out_len
return target, lm_loss_target, t_len, aux_t_len, u_len
def forward(
self,
enc_out: torch.Tensor,
aux_enc_out: List[torch.Tensor],
dec_out: torch.Tensor,
labels: torch.Tensor,
enc_out_len: torch.Tensor,
aux_enc_out_len: torch.Tensor,
    ) -> Tuple[Any, Any, Any, Any, Any]:
        """Forward main and auxiliary tasks.
        Args:
            enc_out: Encoder output sequences. (B, T, D_enc)
            aux_enc_out: Encoder intermediate output sequences. [N x (B, T_aux, D_enc_aux)]
            dec_out: Decoder output sequences. (B, U, D_dec)
            labels: Label ID sequences. (B, U)
            enc_out_len: Time lengths. (B,)
            aux_enc_out_len: Auxiliary time lengths. [N x (B,)]
        Returns:
            : Weighted losses.
              (Transducer loss, CTC loss, aux. Transducer loss, KL div. loss, LM loss)
        """
if self.use_symm_kl_div_loss:
assert self.use_aux_transducer_loss
(trans_loss, ctc_loss, lm_loss, aux_trans_loss, symm_kl_div_loss) = (
0.0,
0.0,
0.0,
0.0,
0.0,
)
target, lm_loss_target, t_len, aux_t_len, u_len = self.get_transducer_tasks_io(
labels,
enc_out_len,
aux_enc_out_len,
)
joint_out, trans_loss = self.compute_transducer_loss(
enc_out, dec_out, target, t_len, u_len
)
if self.use_ctc_loss:
ctc_loss = self.compute_ctc_loss(enc_out, target, t_len, u_len)
if self.use_aux_transducer_loss:
(
aux_trans_loss,
symm_kl_div_loss,
) = self.compute_aux_transducer_and_symm_kl_div_losses(
aux_enc_out,
dec_out,
joint_out,
target,
aux_t_len,
u_len,
)
if self.use_lm_loss:
lm_loss = self.compute_lm_loss(dec_out, lm_loss_target)
return (
self.transducer_loss_weight * trans_loss,
self.ctc_loss_weight * ctc_loss,
self.aux_transducer_loss_weight * aux_trans_loss,
self.symm_kl_div_loss_weight * symm_kl_div_loss,
self.lm_loss_weight * lm_loss,
)
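# --- Note (illustrative, not part of the original file): the caller is
# expected to sum the returned tuple into a single training objective,
#     loss = w_trans * L_trans + w_ctc * L_ctc + w_aux * L_aux
#            + w_kl * L_kl + w_lm * L_lm,
# with the weights already applied in the return statement above.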
"""Transformer decoder layer definition for custom Transducer model."""
from typing import Optional
import torch
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
PositionwiseFeedForward,
)
class TransformerDecoderLayer(torch.nn.Module):
"""Transformer decoder layer module for custom Transducer model.
Args:
hdim: Hidden dimension.
self_attention: Self-attention module.
feed_forward: Feed forward module.
dropout_rate: Dropout rate.
"""
def __init__(
self,
hdim: int,
self_attention: MultiHeadedAttention,
feed_forward: PositionwiseFeedForward,
dropout_rate: float,
):
"""Construct an DecoderLayer object."""
super().__init__()
self.self_attention = self_attention
self.feed_forward = feed_forward
self.norm1 = LayerNorm(hdim)
self.norm2 = LayerNorm(hdim)
self.dropout = torch.nn.Dropout(dropout_rate)
self.hdim = hdim
def forward(
self,
sequence: torch.Tensor,
mask: torch.Tensor,
cache: Optional[torch.Tensor] = None,
):
"""Compute previous decoder output sequences.
Args:
sequence: Transformer input sequences. (B, U, D_dec)
mask: Transformer intput mask sequences. (B, U)
cache: Cached decoder output sequences. (B, (U - 1), D_dec)
Returns:
sequence: Transformer output sequences. (B, U, D_dec)
mask: Transformer output mask sequences. (B, U)
"""
residual = sequence
sequence = self.norm1(sequence)
if cache is None:
sequence_q = sequence
else:
batch = sequence.shape[0]
prev_len = sequence.shape[1] - 1
assert cache.shape == (
batch,
prev_len,
self.hdim,
), f"{cache.shape} == {(batch, prev_len, self.hdim)}"
sequence_q = sequence[:, -1:, :]
residual = residual[:, -1:, :]
if mask is not None:
mask = mask[:, -1:, :]
sequence = residual + self.dropout(
self.self_attention(sequence_q, sequence, sequence, mask)
)
residual = sequence
sequence = self.norm2(sequence)
sequence = residual + self.dropout(self.feed_forward(sequence))
if cache is not None:
sequence = torch.cat([cache, sequence], dim=1)
return sequence, mask
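# --- Illustrative sketch (not part of the original file): incremental decoding
# with the cache. Passing the first U - 1 output positions back as `cache`
# restricts the self-attention queries to the newest label position.
def _demo_cached_step(layer: TransformerDecoderLayer) -> torch.Tensor:
    seq = torch.randn(1, 4, layer.hdim)  # (B, U, D_dec)
    full, _ = layer(seq, None)  # no cache: all U positions computed
    step, _ = layer(seq, None, cache=full[:, :-1])  # only the last position
    return step  # (1, 4, D_dec), with positions 0..U-2 copied from the cache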