Commit 60a2c57a authored by sunzhq2's avatar sunzhq2 Committed by xuxo
Browse files

update conformer

parent 4a699441
"""DNN beamformer module."""
from typing import Tuple
import torch
from torch.nn import functional as F
from torch_complex.tensor import ComplexTensor
from espnet.nets.pytorch_backend.frontends.beamformer import ( # noqa: H301
apply_beamforming_vector,
get_mvdr_vector,
get_power_spectral_density_matrix,
)
from espnet.nets.pytorch_backend.frontends.mask_estimator import MaskEstimator
class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer.

    Estimates speech/noise time-frequency masks with a neural MaskEstimator,
    builds power spectral density (PSD) matrices from them, and applies an
    MVDR beamformer whose reference channel is either fixed or selected by
    an attention module.

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        https://arxiv.org/abs/1703.04783

    Args:
        bidim: input feature dimension (number of frequency bins)
        btype: RNN type of the mask estimator (e.g. "blstmp")
        blayers: number of RNN layers in the mask estimator
        bunits: number of RNN units per layer
        bprojs: number of projection units
        bnmask: number of estimated masks (2 = speech + noise;
            >2 = one mask per speaker plus one noise mask)
        dropout_rate: dropout rate of the mask estimator
        badim: attention dimension of the reference-channel estimator
        ref_channel: fixed reference microphone index; a negative value
            enables the attention-based reference selection instead
        beamformer_type: only "mvdr" is supported
    """

    def __init__(
        self,
        bidim,
        btype="blstmp",
        blayers=3,
        bunits=300,
        bprojs=320,
        bnmask=2,
        dropout_rate=0.0,
        badim=320,
        ref_channel: int = -1,
        beamformer_type="mvdr",
    ):
        super().__init__()
        self.mask = MaskEstimator(
            btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
        )
        # Attention-based reference selector; only used when ref_channel < 0.
        self.ref = AttentionReference(bidim, badim)
        self.ref_channel = ref_channel

        self.nmask = bnmask

        if beamformer_type != "mvdr":
            raise ValueError(
                "Not supporting beamformer_type={}".format(beamformer_type)
            )
        self.beamformer_type = beamformer_type

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F), or a list of such tensors
                (one per speaker) in the multi-speaker case
            ilens (torch.Tensor): (B,)
            mask_speech: estimated speech mask(s), (B, T, C, F) layout
        """

        def apply_beamforming(data, ilens, psd_speech, psd_noise):
            # u: (B, C) — weight over channels selecting the reference mic
            if self.ref_channel < 0:
                # Soft reference selection by attention over the speech PSD
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)), device=data.device
                )
                u[..., self.ref_channel].fill_(1)

            # ws: MVDR filter coefficients; enhanced: filtered signal
            ws = get_mvdr_vector(psd_speech, psd_noise, u)
            enhanced = apply_beamforming_vector(ws, data)

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks)

        if self.nmask == 2:  # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks

            psd_speech = get_power_spectral_density_matrix(data, mask_speech)
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
            mask_speech = mask_speech.transpose(-1, -3)
        else:  # multi-speaker case: (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]

            psd_speeches = [
                get_power_spectral_density_matrix(data, mask) for mask in mask_speech
            ]
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced = []
            ws = []
            for i in range(self.nmask - 1):
                # Temporarily remove speaker i's PSD so the remaining
                # speakers can be treated as interference below.
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, sum(psd_speeches) + psd_noise
                )
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                mask_speech[i] = mask_speech[i].transpose(-1, -3)

                enhanced.append(enh)
                ws.append(w)

        return enhanced, ilens, mask_speech
class AttentionReference(torch.nn.Module):
    """Select a reference microphone by attention over the speech PSD matrix."""

    def __init__(self, bidim, att_dim):
        super().__init__()
        # Projects the per-channel PSD magnitude to the attention space.
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        # Produces one scalar score per channel.
        self.gvec = torch.nn.Linear(att_dim, 1)

    def forward(
        self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Compute channel-attention weights from a PSD matrix.

        Args:
            psd_in (ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float): softmax temperature multiplier
        Returns:
            u (torch.Tensor): (B, C) attention weights over channels
            ilens (torch.Tensor): (B,)
        """
        n_channels = psd_in.size(-1)
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()
        # Zero the diagonal (auto-correlation) entries of each (C, C) matrix.
        diag_mask = torch.eye(
            n_channels, dtype=torch.bool, device=psd_in.device
        )
        off_diag = psd_in.masked_fill(diag_mask, 0)
        # Average the cross-correlation terms: (B, F, C, C) -> (B, C, F)
        averaged = (off_diag.sum(dim=-1) / (n_channels - 1)).transpose(-1, -2)
        # Amplitude of the complex average.
        magnitude = (averaged.real**2 + averaged.imag**2) ** 0.5
        # (B, C, F) -> (B, C, att_dim) -> (B, C, 1) -> (B, C)
        scores = self.gvec(torch.tanh(self.mlp_psd(magnitude))).squeeze(-1)
        attention = F.softmax(scaling * scores, dim=-1)
        return attention, ilens
from typing import Tuple
import torch
from pytorch_wpe import wpe_one_iteration
from torch_complex.tensor import ComplexTensor
from espnet.nets.pytorch_backend.frontends.mask_estimator import MaskEstimator
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
class DNN_WPE(torch.nn.Module):
    """Weighted Prediction Error (WPE) dereverberation frontend.

    Optionally uses a DNN mask estimator to weight the power spectrum
    used by each WPE iteration.

    Args:
        wtype: RNN type of the mask estimator (e.g. "blstmp")
        widim: input feature dimension (frequency bins)
        wlayers: number of RNN layers
        wunits: number of RNN units
        wprojs: number of projection units
        dropout_rate: dropout rate of the mask estimator
        taps: number of WPE filter taps
        delay: WPE prediction delay
        use_dnn_mask: estimate the power weighting with a DNN mask
        iterations: number of WPE iterations
        normalization: normalize the mask along the time axis
    """

    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        iterations: int = 1,
        normalization: bool = False,
    ):
        super().__init__()
        self.iterations = iterations
        self.taps = taps
        self.delay = delay

        self.normalization = normalization
        self.use_dnn_mask = use_dnn_mask

        self.inverse_power = True

        if self.use_dnn_mask:
            self.mask_est = MaskEstimator(
                wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
            )

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, T, C, F) — note: the permute below maps (B, T, C, F)
                to (B, F, C, T), so input is time-major, channel third
            ilens: (B,)
        Returns:
            data: (B, T, C, F) dereverberated signal
            ilens: (B,)
            mask: estimated mask in (B, T, C, F) layout, or None when
                use_dnn_mask is False
        """
        # (B, T, C, F) -> (B, F, C, T)
        enhanced = data = data.permute(0, 3, 2, 1)
        mask = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = enhanced.real**2 + enhanced.imag**2
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T); DNN mask is only estimated once,
                # on the first iteration
                (mask,), _ = self.mask_est(enhanced, ilens)
                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-1)[..., None]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = power * mask

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)

            # enhanced: (..., C, T) -> (..., C, T)
            # NOTE: WPE always filters the original observation `data`,
            # only the power estimate is refined across iterations.
            enhanced = wpe_one_iteration(
                data.contiguous(),
                power,
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )

        # Zero out the padded frames (in-place).
        enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        if mask is not None:
            mask = mask.transpose(-1, -3)
        return enhanced, ilens, mask
from typing import List, Tuple, Union
import librosa
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
class FeatureTransform(torch.nn.Module):
    """STFT -> log-mel feature pipeline with optional normalization.

    Steps: channel selection (for multi-channel input) -> power spectrum
    -> log-mel (LogMel) -> optional global MVN (when stats_file is given)
    -> optional utterance MVN.

    Args:
        fs: sampling rate passed to the mel filterbank
        n_fft: number of FFT points
        n_mels: number of mel bins
        fmin: lowest mel-filter frequency (Hz)
        fmax: highest mel-filter frequency (Hz); None means fs / 2
        stats_file: npy stats file for global MVN; None disables it
        apply_uttmvn: whether to apply utterance-level MVN
        uttmvn_norm_means: subtract utterance means
        uttmvn_norm_vars: normalize utterance variances
    """

    def __init__(
        self,
        # Mel options,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        # Normalization
        stats_file: str = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn

        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None

        # FIX: was `if self.apply_uttmvn is not None:`, which is always true
        # for a bool and instantiated UtteranceMVN even when it was disabled.
        if self.apply_uttmvn:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Convert a (multi-channel) STFT into normalized log-mel features.

        Args:
            x: ComplexTensor (B, T, F) or (B, T, C, F)
            ilens: valid lengths (B,)
        Returns:
            features (B, T, n_mels) and ilens
        """
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly (channel dropout as data augmentation)
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x

        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F) power spectrum
        h = h.real**2 + h.imag**2

        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens
class LogMel(torch.nn.Module):
    """Convert an STFT power spectrum to log mel-filterbank features.

    The arguments are the same as librosa.filters.mel:

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz);
            if `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar] — if 1, area-normalize the
            triangular mel weights; otherwise peak values are 1.0
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()
        options = dict(
            sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
        )
        self.mel_options = options
        # Note(kamo): The mel matrix of librosa is different from kaldi.
        filterbank = librosa.filters.mel(**options)
        # melmat: (D2, D1) -> (D1, D2), registered so it follows .to(device)
        self.register_buffer("melmat", torch.from_numpy(filterbank.T).float())

    def extra_repr(self):
        return ", ".join(f"{key}={value}" for key, value in self.mel_options.items())

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)
        # Small offset avoids log(0) on silent bins.
        logmel_feat = (mel_feat + 1e-20).log()
        # Zero out the padded frames.
        pad_mask = make_pad_mask(ilens, logmel_feat, 1)
        logmel_feat = logmel_feat.masked_fill(pad_mask, 0.0)
        return logmel_feat, ilens
class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization.

    Args:
        stats_file(str): npy file holding a 1-dim stats array.
            From the first element to the {(len(array) - 1) / 2}th element are
            treated as the sum of features, the rest excluding the last
            element are treated as the sum of the squared features, and the
            last element equals the number of accumulated samples.
        norm_means: subtract the global mean
        norm_vars: multiply by the inverse global standard deviation
        std_floor(float): `eps`, floor applied to the standard deviation
    """

    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.stats_file = stats_file

        stats = np.load(stats_file)
        stats = stats.astype(float)
        # Layout sanity check: [sum(x) | sum(x^2) | count]
        assert (len(stats) - 1) % 2 == 0, stats.shape

        count = stats.flatten()[-1]
        mean = stats[: (len(stats) - 1) // 2] / count
        var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
        std = np.maximum(np.sqrt(var), eps)

        # Buffers so they follow .to(device) and are saved in state_dict.
        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))

    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Normalize x of shape (B, T, D).

        NOTE: operates in place (`+=`, `*=`) and therefore mutates the
        caller's tensor.
        """
        # feat: (B, T, D)
        if self.norm_means:
            x += self.bias.type_as(x)
            # FIX: masked_fill is out-of-place and its result was discarded
            # (a no-op); use the in-place variant so padded frames are
            # actually re-zeroed after the bias was added.
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
        if self.norm_vars:
            x *= self.scale.type_as(x)
        return x, ilens
class UtteranceMVN(torch.nn.Module):
    """Per-utterance mean/variance normalization (module wrapper around utterance_mvn)."""

    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.eps = eps

    def extra_repr(self):
        return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # Delegate to the functional implementation with this module's options.
        normalized = utterance_mvn(
            x,
            ilens,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars,
            eps=self.eps,
        )
        return normalized
def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization.

    Args:
        x: (B, T, D), assumed zero padded
        ilens: (B,) valid lengths per utterance
        norm_means: subtract the per-utterance mean (in place on `x`)
        norm_vars: divide by the per-utterance standard deviation
        eps: variance floor
    Returns:
        normalized features (B, T, D) and ilens
    """
    ilens_ = ilens.type_as(x)
    # mean: (B, D) — the sum over T is valid because padding is zero
    mean = x.sum(dim=1) / ilens_[:, None]

    if norm_means:
        # NOTE: in-place, mutates the caller's tensor
        x -= mean[:, None, :]
        x_ = x
    else:
        x_ = x - mean[:, None, :]

    # Zero padding.
    # FIX: masked_fill is out-of-place and its result was discarded (a
    # no-op); the in-place variant actually re-zeroes the padded frames,
    # which the variance computation below relies on.
    x_.masked_fill_(make_pad_mask(ilens, x_, 1), 0.0)
    if norm_vars:
        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        x /= var.sqrt()[:, None, :]
        x_ = x
    return x_, ilens
def feature_transform_for(args, n_fft):
    """Build a FeatureTransform from parsed command-line arguments.

    Args:
        args: argparse.Namespace carrying fbank/MVN options
        n_fft: number of FFT points of the preceding STFT
    Returns:
        a configured FeatureTransform module
    """
    options = dict(
        # Mel options
        fs=args.fbank_fs,
        n_fft=n_fft,
        n_mels=args.n_mels,
        fmin=args.fbank_fmin,
        fmax=args.fbank_fmax,
        # Normalization
        stats_file=args.stats_file,
        apply_uttmvn=args.apply_uttmvn,
        uttmvn_norm_means=args.uttmvn_norm_means,
        uttmvn_norm_vars=args.uttmvn_norm_vars,
    )
    return FeatureTransform(**options)
from typing import List, Optional, Tuple, Union
import numpy
import torch
import torch.nn as nn
from torch_complex.tensor import ComplexTensor
from espnet.nets.pytorch_backend.frontends.dnn_beamformer import DNN_Beamformer
from espnet.nets.pytorch_backend.frontends.dnn_wpe import DNN_WPE
class Frontend(nn.Module):
    """Speech enhancement frontend combining DNN-WPE and a DNN beamformer.

    Args:
        idim: input feature dimension (frequency bins)
        use_wpe: enable the WPE dereverberation stage
        wtype/wlayers/wunits/wprojs/wdropout_rate: WPE mask-estimator config
        taps, delay: WPE filter parameters
        use_dnn_mask_for_wpe: estimate WPE power weights with a DNN
        use_beamformer: enable the MVDR beamformer stage
        btype/blayers/bunits/bprojs/bnmask/badim/bdropout_rate:
            beamformer mask-estimator config
        ref_channel: fixed reference microphone (negative = attention-based)
    """

    def __init__(
        self,
        idim: int,
        # WPE options
        use_wpe: bool = False,
        wtype: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = False,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        bnmask: int = 2,
        badim: int = 320,
        ref_channel: int = -1,
        bdropout_rate=0.0,
    ):
        super().__init__()

        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe
        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
        # use frontend for all the data,
        # e.g. in the case of multi-speaker speech separation
        self.use_frontend_for_all = bnmask > 2

        if self.use_wpe:
            if self.use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                # (Not observed significant gains)
                iterations = 1
            else:
                # Performing as conventional WPE, without DNN Estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wtype,
                widim=idim,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=btype,
                bidim=idim,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                bnmask=bnmask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
            )
        else:
            self.beamformer = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
        """Apply (optionally) WPE and beamforming to a multi-channel STFT.

        Args:
            x: ComplexTensor (B, T, F) or (B, T, C, F); single-channel
                input passes through untouched
            ilens: valid lengths (B,)
        Returns:
            enhanced signal, ilens, and the last estimated mask (or None)
        """
        assert len(x) == len(ilens), (len(x), len(ilens))
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)

        mask = None
        h = x
        if h.dim() == 4:
            if self.training:
                # Randomly enable none/one of the frontends per batch
                # during training; acts as regularization. The "do nothing"
                # choice is dropped when every utterance must be processed
                # (multi-speaker separation).
                choices = [(False, False)] if not self.use_frontend_for_all else []
                if self.use_wpe:
                    choices.append((True, False))
                if self.use_beamformer:
                    choices.append((False, True))

                use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]

            else:
                use_wpe = self.use_wpe
                use_beamformer = self.use_beamformer

            # 1. WPE
            if use_wpe:
                # h: (B, T, C, F) -> h: (B, T, C, F)
                h, ilens, mask = self.wpe(h, ilens)

            # 2. Beamformer
            if use_beamformer:
                # h: (B, T, C, F) -> h: (B, T, F)
                h, ilens, mask = self.beamformer(h, ilens)

        return h, ilens, mask
def frontend_for(args, idim):
    """Build a Frontend from parsed command-line arguments.

    Args:
        args: argparse.Namespace carrying WPE/beamformer options
        idim: input feature dimension (frequency bins)
    Returns:
        a configured Frontend module
    """
    options = dict(
        idim=idim,
        # WPE options
        use_wpe=args.use_wpe,
        wtype=args.wtype,
        wlayers=args.wlayers,
        wunits=args.wunits,
        wprojs=args.wprojs,
        wdropout_rate=args.wdropout_rate,
        taps=args.wpe_taps,
        delay=args.wpe_delay,
        use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
        # Beamformer options
        use_beamformer=args.use_beamformer,
        btype=args.btype,
        blayers=args.blayers,
        bunits=args.bunits,
        bprojs=args.bprojs,
        bnmask=args.bnmask,
        badim=args.badim,
        ref_channel=args.ref_channel,
        bdropout_rate=args.bdropout_rate,
    )
    return Frontend(**options)
from typing import Tuple
import numpy as np
import torch
from torch.nn import functional as F
from torch_complex.tensor import ComplexTensor
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.pytorch_backend.rnn.encoders import RNN, RNNP
class MaskEstimator(torch.nn.Module):
    """Estimate time-frequency masks per channel with a (projected) RNN.

    Args:
        type: RNN descriptor, e.g. "blstmp" / "vggblstmp"; an optional
            "vgg" prefix and a trailing "p" (projection) are stripped to
            obtain the core RNN type
        idim: input feature dimension (frequency bins)
        layers: number of RNN layers
        units: number of RNN units
        projs: number of projection units
        dropout: dropout rate
        nmask: number of masks to estimate
    """

    def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
        super().__init__()
        subsample = np.ones(layers + 1, dtype=np.int64)

        # Derive the core RNN type by removing the optional "vgg" prefix
        # and the trailing projection marker "p".
        # FIX: str.lstrip("vgg")/rstrip("p") remove *characters*, not a
        # prefix/suffix — e.g. "gru".lstrip("vgg") == "ru", which silently
        # produced an invalid RNN type.
        typ = type[len("vgg"):] if type.startswith("vgg") else type
        if typ.endswith("p"):
            typ = typ[:-1]
        if type[-1] == "p":
            self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
        else:
            self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)

        self.type = type
        self.nmask = nmask
        # One output head per mask.
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(projs, idim) for _ in range(nmask)]
        )

    def forward(
        self, xs: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
        """The forward function.

        Args:
            xs: (B, F, C, T)
            ilens: (B,)
        Returns:
            masks: a tuple of `nmask` mask tensors, each (B, F, C, T)
            ilens: (B,)
        """
        assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
        _, _, C, input_length = xs.size()
        # (B, F, C, T) -> (B, C, T, F)
        xs = xs.permute(0, 2, 3, 1)

        # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
        xs = (xs.real**2 + xs.imag**2) ** 0.5
        # xs: (B, C, T, F) -> xs: (B * C, T, F) — channels are processed
        # independently by the RNN
        xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
        # ilens: (B,) -> ilens_: (B * C)
        ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

        # xs: (B * C, T, F) -> xs: (B * C, T, D)
        xs, _, _ = self.brnn(xs, ilens_)
        # xs: (B * C, T, D) -> xs: (B, C, T, D)
        xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

        masks = []
        for linear in self.linears:
            # xs: (B, C, T, D) -> mask:(B, C, T, F)
            mask = linear(xs)

            mask = torch.sigmoid(mask)
            # Zero padding.
            # FIX: masked_fill is out-of-place and its result was discarded
            # (a no-op); the in-place variant actually zeroes padded frames.
            mask.masked_fill_(make_pad_mask(ilens, mask, length_dim=2), 0)

            # (B, C, T, F) -> (B, F, C, T)
            mask = mask.permute(0, 3, 1, 2)

            # Take cares of multi gpu cases: If input_length > max(ilens)
            if mask.size(-1) < input_length:
                mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
            masks.append(mask)

        return tuple(masks), ilens
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""GTN CTC implementation."""
import gtn
import torch
class GTNCTCLossFunction(torch.autograd.Function):
    """GTN CTC module.

    CTC loss implemented with GTN (differentiable WFST) graphs: an
    emission graph per utterance is intersected with the CTC label
    topology, and the negated forward score gives the loss.
    """

    # Copied from FB's GTN example implementation:
    # https://github.com/facebookresearch/gtn_applications/blob/master/utils.py#L251
    @staticmethod
    def create_ctc_graph(target, blank_idx):
        """Build gtn graph.

        :param list target: single target sequence
        :param int blank_idx: index of blank token
        :return: gtn graph of target sequence
        :rtype: gtn.Graph
        """
        g_criterion = gtn.Graph(False)
        L = len(target)
        # Standard CTC topology: blanks interleaved with labels -> 2L+1 states.
        S = 2 * L + 1
        for s in range(S):
            idx = (s - 1) // 2
            # Start state is s == 0; either of the last two states is final.
            g_criterion.add_node(s == 0, s == S - 1 or s == S - 2)
            # Even states emit blank, odd states emit target labels.
            label = target[idx] if s % 2 else blank_idx
            # Self-loop allows label/blank repetition.
            g_criterion.add_arc(s, s, label)
            if s > 0:
                g_criterion.add_arc(s - 1, s, label)
            # Skip over the blank, only between differing consecutive labels.
            if s % 2 and s > 1 and label != target[idx - 1]:
                g_criterion.add_arc(s - 2, s, label)
        g_criterion.arc_sort(False)
        return g_criterion

    @staticmethod
    def forward(ctx, log_probs, targets, ilens, blank_idx=0, reduction="none"):
        """Forward computation.

        :param torch.tensor log_probs: batched log softmax probabilities (B, Tmax, oDim)
        :param list targets: batched target sequences, list of lists
        :param torch.tensor ilens: valid input lengths (B,)
        :param int blank_idx: index of blank token
        :param str reduction: "none" or "mean" (per-target-length scaling)
        :return: ctc loss value (mean over the batch)
        :rtype: torch.Tensor
        """
        B, _, C = log_probs.shape
        losses = [None] * B
        scales = [None] * B
        emissions_graphs = [None] * B

        def process(b):
            # create emission graph
            T = ilens[b]
            g_emissions = gtn.linear_graph(T, C, log_probs.requires_grad)
            # GTN reads the weights directly from the CPU tensor's memory.
            cpu_data = log_probs[b][:T].cpu().contiguous()
            g_emissions.set_weights(cpu_data.data_ptr())

            # create criterion graph
            g_criterion = GTNCTCLossFunction.create_ctc_graph(targets[b], blank_idx)
            # compose the graphs: loss = -forward_score(emissions ∘ criterion)
            g_loss = gtn.negate(
                gtn.forward_score(gtn.intersect(g_emissions, g_criterion))
            )

            scale = 1.0
            if reduction == "mean":
                L = len(targets[b])
                scale = 1.0 / L if L > 0 else scale
            elif reduction != "none":
                raise ValueError("invalid value for reduction '" + str(reduction) + "'")

            # Save for backward:
            losses[b] = g_loss
            scales[b] = scale
            emissions_graphs[b] = g_emissions

        gtn.parallel_for(process, range(B))

        # Stash the graphs so backward() can run gtn.backward on them.
        ctx.auxiliary_data = (losses, scales, emissions_graphs, log_probs.shape, ilens)
        loss = torch.tensor([losses[b].item() * scales[b] for b in range(B)])
        return torch.mean(loss.cuda() if log_probs.is_cuda else loss)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward computation.

        :param torch.tensor grad_output: backward passed gradient value
        :return: cumulative gradient output
        :rtype: (torch.Tensor, None, None, None)
        """
        losses, scales, emissions_graphs, in_shape, ilens = ctx.auxiliary_data
        B, T, C = in_shape
        input_grad = torch.zeros((B, T, C))

        def process(b):
            T = ilens[b]
            gtn.backward(losses[b], False)
            emissions = emissions_graphs[b]
            # Pull the emission-arc gradients back out of the GTN graph.
            grad = emissions.grad().weights_to_numpy()
            input_grad[b][:T] = torch.from_numpy(grad).view(1, T, C) * scales[b]

        gtn.parallel_for(process, range(B))

        if grad_output.is_cuda:
            input_grad = input_grad.cuda()
        # Chain rule through the final torch.mean over the batch.
        input_grad *= grad_output / B

        return (
            input_grad,
            None,  # targets
            None,  # ilens
            None,  # blank_idx
            None,  # reduction
        )
#!/usr/bin/env python
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Initialization functions for RNN sequence-to-sequence models."""
import math
def lecun_normal_init_parameters(module):
    """Initialize all parameters of *module* in the LeCun manner.

    Biases (1-D) are zeroed; linear (2-D) and convolution (3/4-D) weights
    are drawn from N(0, 1/fan_in).
    """
    for param in module.parameters():
        tensor = param.data
        ndim = tensor.dim()
        if ndim == 1:
            # bias
            tensor.zero_()
        elif ndim in (2, 3, 4):
            # weight: fan-in is every dimension except the output one
            fan_in = 1
            for size in tensor.size()[1:]:
                fan_in *= size
            tensor.normal_(0, 1.0 / math.sqrt(fan_in))
        else:
            raise NotImplementedError
def uniform_init_parameters(module):
    """Initialize all parameters of *module* from U(-0.1, 0.1).

    Biases (1-D) and linear weights (2-D) are drawn uniformly;
    convolution weights (3/4-D) keep the PyTorch default init.
    """
    for param in module.parameters():
        tensor = param.data
        ndim = tensor.dim()
        if ndim in (1, 2):
            # bias or linear weight
            tensor.uniform_(-0.1, 0.1)
        elif ndim in (3, 4):
            # conv weight: use the pytorch default
            pass
        else:
            raise NotImplementedError
def set_forget_bias_to_one(bias):
    """Fill the forget-gate slice of an LSTM bias vector with ones.

    PyTorch packs LSTM gate biases as [input, forget, cell, output], so
    the forget gate occupies the second quarter of the vector.
    """
    size = bias.size(0)
    bias.data[size // 4 : size // 2].fill_(1.0)
"""Default Recurrent Neural Network Languge Model in `lm_train.py`."""
import logging
from typing import Any, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from espnet.nets.lm_interface import LMInterface
from espnet.nets.pytorch_backend.e2e_asr import to_device
from espnet.nets.scorer_interface import BatchScorerInterface
from espnet.utils.cli_utils import strtobool
class DefaultRNNLM(BatchScorerInterface, LMInterface, nn.Module):
    """Default RNNLM for `LMInterface` Implementation.

    Note:
        PyTorch seems to have memory leak when one GPU compute this after data parallel.
        If parallel GPUs compute this, it seems to be fine.
        See also https://github.com/espnet/espnet/issues/1075
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser."""
        parser.add_argument(
            "--type",
            type=str,
            default="lstm",
            nargs="?",
            choices=["lstm", "gru"],
            help="Which type of RNN to use",
        )
        parser.add_argument(
            "--layer", "-l", type=int, default=2, help="Number of hidden layers"
        )
        parser.add_argument(
            "--unit", "-u", type=int, default=650, help="Number of hidden units"
        )
        parser.add_argument(
            "--embed-unit",
            default=None,
            type=int,
            help="Number of hidden units in embedding layer, "
            "if it is not specified, it keeps the same number with hidden units.",
        )
        parser.add_argument(
            "--dropout-rate", type=float, default=0.5, help="dropout probability"
        )
        parser.add_argument(
            "--emb-dropout-rate",
            type=float,
            default=0.0,
            help="emb dropout probability",
        )
        parser.add_argument(
            "--tie-weights",
            type=strtobool,
            default=False,
            help="Tie input and output embeddings",
        )
        return parser

    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations. see py:method:`add_arguments`
        """
        nn.Module.__init__(self)
        # getattr defaults keep checkpoints from older ESPnet versions loadable:
        # NOTE: for a compatibility with less than 0.5.0 version models
        dropout_rate = getattr(args, "dropout_rate", 0.0)
        # NOTE: for a compatibility with less than 0.6.1 version models
        embed_unit = getattr(args, "embed_unit", None)
        # NOTE: for a compatibility with less than 0.9.7 version models
        emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0)
        # NOTE: for a compatibility with less than 0.9.7 version models
        tie_weights = getattr(args, "tie_weights", False)
        self.model = ClassifierWithState(
            RNNLM(
                n_vocab,
                args.layer,
                args.unit,
                embed_unit,
                args.type,
                dropout_rate,
                emb_dropout_rate,
                tie_weights,
            )
        )

    def state_dict(self):
        """Dump state dict."""
        return self.model.state_dict()

    def load_state_dict(self, d):
        """Load state dict."""
        self.model.load_state_dict(d)

    def forward(self, x, t):
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)
        """
        loss = 0
        logp = 0
        count = torch.tensor(0).long()
        state = None
        batch_size, sequence_length = x.shape
        for i in range(sequence_length):
            # Compute the loss at this time step and accumulate it.
            # Token id 0 is treated as padding (excluded from the counts).
            state, loss_batch = self.model(state, x[:, i], t[:, i])
            non_zeros = torch.sum(x[:, i] != 0, dtype=loss_batch.dtype)
            loss += loss_batch.mean() * non_zeros
            # NOTE(review): this multiplies every per-utterance loss by the
            # non-pad count instead of masking per element — confirm intended.
            logp += torch.sum(loss_batch * non_zeros)
            count += int(non_zeros)
        return loss / batch_size, loss, count.to(loss.device)

    def score(self, y, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys
        """
        # Only the last prefix token is fed; history lives in `state`.
        new_state, scores = self.model.predict(state, y[-1].unsqueeze(0))
        return scores.squeeze(0), new_state

    def final_score(self, state):
        """Score eos.

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score
        """
        return self.model.final(state)

    # batch beam search API (see BatchScorerInterface)
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = self.model.predictor.n_layers
        # LSTM states carry both cell and hidden; GRU only hidden.
        if self.model.predictor.typ == "lstm":
            keys = ("c", "h")
        else:
            keys = ("h",)

        if states[0] is None:
            states = None
        else:
            # transpose state of [batch, key, layer] into [key, layer, batch]
            states = {
                k: [
                    torch.stack([states[b][k][i] for b in range(n_batch)])
                    for i in range(n_layers)
                ]
                for k in keys
            }
        states, logp = self.model.predict(states, ys[:, -1])

        # transpose state of [key, layer, batch] into [batch, key, layer]
        return (
            logp,
            [
                {k: [states[k][i][b] for i in range(n_layers)] for k in keys}
                for b in range(n_batch)
            ],
        )
class ClassifierWithState(nn.Module):
    """A wrapper adding a loss function and state handling to a pytorch RNNLM."""

    def __init__(self, predictor, lossfun=None, label_key=-1):
        """Initialize class.

        :param torch.nn.Module predictor : The RNNLM
        :param function lossfun : The loss function to use; defaults to
            ``nn.CrossEntropyLoss(reduction="none")``
        :param int/str label_key : position (int) in ``*args`` or key (str)
            in ``**kwargs`` of the ground-truth labels
        """
        if not (isinstance(label_key, (int, str))):
            raise TypeError("label_key must be int or str, but is %s" % type(label_key))
        super(ClassifierWithState, self).__init__()
        # FIX: the default loss was previously a mutable default argument
        # (nn.CrossEntropyLoss instance created once at def time and shared
        # by every ClassifierWithState); build it per instance instead.
        if lossfun is None:
            lossfun = nn.CrossEntropyLoss(reduction="none")
        self.lossfun = lossfun
        self.y = None
        self.loss = None
        self.label_key = label_key
        self.predictor = predictor

    def forward(self, state, *args, **kwargs):
        """Compute the loss value for an input and label pair.

        Notes:
            It also computes accuracy and stores it to the attribute.
            When ``label_key`` is ``int``, the corresponding element in ``args``
            is treated as ground truth labels. And when it is ``str``, the
            element in ``kwargs`` is used.
            The all elements of ``args`` and ``kwargs`` except the groundtruth
            labels are features.
            It feeds features to the predictor and compare the result
            with ground truth labels.

        :param torch.Tensor state : the LM state
        :param list[torch.Tensor] args : Input minibatch
        :param dict[torch.Tensor] kwargs : Input minibatch
        :return loss value
        :rtype torch.Tensor
        """
        # Pull the labels out of args/kwargs, leaving only the features.
        if isinstance(self.label_key, int):
            if not (-len(args) <= self.label_key < len(args)):
                msg = "Label key %d is out of bounds" % self.label_key
                raise ValueError(msg)
            t = args[self.label_key]
            if self.label_key == -1:
                args = args[:-1]
            else:
                args = args[: self.label_key] + args[self.label_key + 1 :]
        elif isinstance(self.label_key, str):
            if self.label_key not in kwargs:
                msg = 'Label key "%s" is not found' % self.label_key
                raise ValueError(msg)
            t = kwargs[self.label_key]
            del kwargs[self.label_key]

        self.y = None
        self.loss = None
        state, self.y = self.predictor(state, *args, **kwargs)
        self.loss = self.lossfun(self.y, t)
        return state, self.loss

    def predict(self, state, x):
        """Predict log probabilities for given state and input x using the predictor.

        :param torch.Tensor state : The current state
        :param torch.Tensor x : The input
        :return a tuple (new state, log prob vector)
        :rtype (torch.Tensor, torch.Tensor)
        """
        if hasattr(self.predictor, "normalized") and self.predictor.normalized:
            # Predictor already returns log-probabilities.
            return self.predictor(state, x)
        else:
            state, z = self.predictor(state, x)
            return state, F.log_softmax(z, dim=1)

    def buff_predict(self, state, x, n):
        """Predict new tokens from buffered inputs, one state per hypothesis."""
        if self.predictor.__class__.__name__ == "RNNLM":
            return self.predict(state, x)

        new_state = []
        new_log_y = []
        for i in range(n):
            state_i = None if state is None else state[i]
            state_i, log_y = self.predict(state_i, x[i].unsqueeze(0))
            new_state.append(state_i)
            new_log_y.append(log_y)

        return new_state, torch.cat(new_log_y)

    def final(self, state, index=None):
        """Predict final log probabilities for given state using the predictor.

        :param state: The state
        :return The final log probabilities
        :rtype torch.Tensor
        """
        if hasattr(self.predictor, "final"):
            if index is not None:
                return self.predictor.final(state[index])
            else:
                return self.predictor.final(state)
        else:
            # Predictors without a final() contribute nothing at eos.
            return 0.0
# Definition of a recurrent net for language modeling
class RNNLM(nn.Module):
"""A pytorch RNNLM."""
def __init__(
self,
n_vocab,
n_layers,
n_units,
n_embed=None,
typ="lstm",
dropout_rate=0.5,
emb_dropout_rate=0.0,
tie_weights=False,
):
"""Initialize class.
:param int n_vocab: The size of the vocabulary
:param int n_layers: The number of layers to create
:param int n_units: The number of units per layer
:param str typ: The RNN type
"""
super(RNNLM, self).__init__()
if n_embed is None:
n_embed = n_units
self.embed = nn.Embedding(n_vocab, n_embed)
if emb_dropout_rate == 0.0:
self.embed_drop = None
else:
self.embed_drop = nn.Dropout(emb_dropout_rate)
if typ == "lstm":
self.rnn = nn.ModuleList(
[nn.LSTMCell(n_embed, n_units)]
+ [nn.LSTMCell(n_units, n_units) for _ in range(n_layers - 1)]
)
else:
self.rnn = nn.ModuleList(
[nn.GRUCell(n_embed, n_units)]
+ [nn.GRUCell(n_units, n_units) for _ in range(n_layers - 1)]
)
self.dropout = nn.ModuleList(
[nn.Dropout(dropout_rate) for _ in range(n_layers + 1)]
)
self.lo = nn.Linear(n_units, n_vocab)
self.n_layers = n_layers
self.n_units = n_units
self.typ = typ
logging.info("Tie weights set to {}".format(tie_weights))
logging.info("Dropout set to {}".format(dropout_rate))
logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
if tie_weights:
assert (
n_embed == n_units
), "Tie Weights: True need embedding and final dimensions to match"
self.lo.weight = self.embed.weight
# initialize parameters from uniform distribution
for param in self.parameters():
param.data.uniform_(-0.1, 0.1)
def zero_state(self, batchsize):
"""Initialize state."""
p = next(self.parameters())
return torch.zeros(batchsize, self.n_units).to(device=p.device, dtype=p.dtype)
def forward(self, state, x):
"""Forward neural networks."""
if state is None:
h = [to_device(x, self.zero_state(x.size(0))) for n in range(self.n_layers)]
state = {"h": h}
if self.typ == "lstm":
c = [
to_device(x, self.zero_state(x.size(0)))
for n in range(self.n_layers)
]
state = {"c": c, "h": h}
h = [None] * self.n_layers
if self.embed_drop is not None:
emb = self.embed_drop(self.embed(x))
else:
emb = self.embed(x)
if self.typ == "lstm":
c = [None] * self.n_layers
h[0], c[0] = self.rnn[0](
self.dropout[0](emb), (state["h"][0], state["c"][0])
)
for n in range(1, self.n_layers):
h[n], c[n] = self.rnn[n](
self.dropout[n](h[n - 1]), (state["h"][n], state["c"][n])
)
state = {"c": c, "h": h}
else:
h[0] = self.rnn[0](self.dropout[0](emb), state["h"][0])
for n in range(1, self.n_layers):
h[n] = self.rnn[n](self.dropout[n](h[n - 1]), state["h"][n])
state = {"h": h}
y = self.lo(self.dropout[-1](h[-1]))
return state, y
"""Sequential implementation of Recurrent Neural Network Language Model."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from espnet.nets.lm_interface import LMInterface
class SequentialRNNLM(LMInterface, torch.nn.Module):
    """Sequential RNNLM.

    Runs a whole-sequence ``torch.nn.LSTM``/``GRU``/``RNN`` (as opposed to a
    per-step cell model).

    See also:
        https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py
    """
    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser."""
        parser.add_argument(
            "--type",
            type=str,
            default="lstm",
            nargs="?",
            choices=["lstm", "gru"],
            help="Which type of RNN to use",
        )
        parser.add_argument(
            "--layer", "-l", type=int, default=2, help="Number of hidden layers"
        )
        parser.add_argument(
            "--unit", "-u", type=int, default=650, help="Number of hidden units"
        )
        parser.add_argument(
            "--dropout-rate", type=float, default=0.5, help="dropout probability"
        )
        return parser
    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations. see py:method:`add_arguments`
        """
        torch.nn.Module.__init__(self)
        self._setup(
            rnn_type=args.type.upper(),
            ntoken=n_vocab,
            ninp=args.unit,
            nhid=args.unit,
            nlayers=args.layer,
            dropout=args.dropout_rate,
        )
    def _setup(
        self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False
    ):
        """Build the embedding, RNN, and output-projection modules."""
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type in ["LSTM", "GRU"]:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
            except KeyError:
                raise ValueError(
                    "An invalid option for `--model` was supplied, "
                    "options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']"
                )
            self.rnn = nn.RNN(
                ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout
            )
        self.decoder = nn.Linear(nhid, ntoken)
        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers:
        # A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError(
                    "When using the tied flag, nhid must be equal to emsize"
                )
            self.decoder.weight = self.encoder.weight
        self._init_weights()
        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers
    def _init_weights(self):
        """Initialize all parameters from U(-0.1, 0.1)."""
        # NOTE: original init in pytorch/examples
        # initrange = 0.1
        # self.encoder.weight.data.uniform_(-initrange, initrange)
        # self.decoder.bias.data.zero_()
        # self.decoder.weight.data.uniform_(-initrange, initrange)
        # NOTE: our default.py:RNNLM init
        for param in self.parameters():
            param.data.uniform_(-0.1, 0.1)
    def forward(self, x, t):
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)
        """
        y = self._before_loss(x, None)[0]
        # assumes padding id is 0: padded positions of x neither contribute
        # to the loss nor to the token count — TODO confirm with data loader
        mask = (x != 0).to(y.dtype)
        loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        logp = loss * mask.view(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count
    def _before_loss(self, input, hidden):
        """Run embedding -> RNN -> projection and return (logits, hidden).

        NOTE(review): ``self.rnn`` is built without ``batch_first=True``, so
        the RNN treats dim 0 of ``input`` as the time axis — confirm callers
        pass data in the expected layout.
        """
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        # Flatten to (steps, nhid) for the linear layer, then restore shape.
        decoded = self.decoder(
            output.view(output.size(0) * output.size(1), output.size(2))
        )
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden
    def init_state(self, x):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        # beam search scores one hypothesis at a time here
        bsz = 1
        weight = next(self.parameters())
        if self.rnn_type == "LSTM":
            # LSTM carries (h, c); others carry a single hidden tensor.
            return (
                weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid),
            )
        else:
            return weight.new_zeros(self.nlayers, bsz, self.nhid)
    def score(self, y, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys
        """
        # Only the last token is fed; the RNN state encodes the prefix.
        y, new_state = self._before_loss(y[-1].view(1, 1), state)
        logp = y.log_softmax(dim=-1).view(-1)
        return logp, new_state
"""Transformer language model."""
import logging
from typing import Any, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from espnet.nets.lm_interface import LMInterface
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.scorer_interface import BatchScorerInterface
from espnet.utils.cli_utils import strtobool
class TransformerLM(nn.Module, LMInterface, BatchScorerInterface):
    """Transformer language model.

    A Transformer encoder over token embeddings with a causal (subsequent)
    mask, projected to vocabulary logits.
    """
    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser."""
        parser.add_argument(
            "--layer", type=int, default=4, help="Number of hidden layers"
        )
        parser.add_argument(
            "--unit",
            type=int,
            default=1024,
            help="Number of hidden units in feedforward layer",
        )
        parser.add_argument(
            "--att-unit",
            type=int,
            default=256,
            help="Number of hidden units in attention layer",
        )
        parser.add_argument(
            "--embed-unit",
            type=int,
            default=128,
            help="Number of hidden units in embedding layer",
        )
        parser.add_argument(
            "--head", type=int, default=2, help="Number of multi head attention"
        )
        parser.add_argument(
            "--dropout-rate", type=float, default=0.5, help="dropout probability"
        )
        parser.add_argument(
            "--att-dropout-rate",
            type=float,
            default=0.0,
            help="att dropout probability",
        )
        parser.add_argument(
            "--emb-dropout-rate",
            type=float,
            default=0.0,
            help="emb dropout probability",
        )
        parser.add_argument(
            "--tie-weights",
            type=strtobool,
            default=False,
            help="Tie input and output embeddings",
        )
        parser.add_argument(
            "--pos-enc",
            default="sinusoidal",
            choices=["sinusoidal", "none"],
            help="positional encoding",
        )
        return parser
    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations. see py:method:`add_arguments`
        """
        nn.Module.__init__(self)
        # NOTE: for a compatibility with less than 0.9.7 version models
        emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0)
        # NOTE: for a compatibility with less than 0.9.7 version models
        tie_weights = getattr(args, "tie_weights", False)
        # NOTE: for a compatibility with less than 0.9.7 version models
        att_dropout_rate = getattr(args, "att_dropout_rate", 0.0)
        if args.pos_enc == "sinusoidal":
            pos_enc_class = PositionalEncoding
        elif args.pos_enc == "none":
            # "none" disables positional encoding by substituting an
            # identity module.
            def pos_enc_class(*args, **kwargs):
                return nn.Sequential()  # identity
        else:
            raise ValueError(f"unknown pos-enc option: {args.pos_enc}")
        self.embed = nn.Embedding(n_vocab, args.embed_unit)
        if emb_dropout_rate == 0.0:
            self.embed_drop = None
        else:
            self.embed_drop = nn.Dropout(emb_dropout_rate)
        self.encoder = Encoder(
            idim=args.embed_unit,
            attention_dim=args.att_unit,
            attention_heads=args.head,
            linear_units=args.unit,
            num_blocks=args.layer,
            dropout_rate=args.dropout_rate,
            attention_dropout_rate=att_dropout_rate,
            input_layer="linear",
            pos_enc_class=pos_enc_class,
        )
        self.decoder = nn.Linear(args.att_unit, n_vocab)
        logging.info("Tie weights set to {}".format(tie_weights))
        logging.info("Dropout set to {}".format(args.dropout_rate))
        logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
        logging.info("Att Dropout set to {}".format(att_dropout_rate))
        if tie_weights:
            assert (
                args.att_unit == args.embed_unit
            ), "Tie Weights: True need embedding and final dimensions to match"
            # Aliased (not copied): both modules share one Parameter object.
            self.decoder.weight = self.embed.weight
    def _target_mask(self, ys_in_pad):
        """Combine a padding mask (assumes pad id 0 — TODO confirm) with a
        causal subsequent mask; shape (B, Lmax, Lmax)."""
        ys_mask = ys_in_pad != 0
        m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m
    def forward(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)
        """
        # assumes padding id is 0: padded positions are excluded from loss
        xm = x != 0
        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(x))
        else:
            emb = self.embed(x)
        h, _ = self.encoder(emb, self._target_mask(x))
        y = self.decoder(h)
        loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        mask = xm.to(dtype=loss.dtype)
        logp = loss * mask.view(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count
    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys
        """
        y = y.unsqueeze(0)
        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(y))
        else:
            emb = self.embed(y)
        # ``cache`` holds per-layer activations so only the new step is computed
        h, _, cache = self.encoder.forward_one_step(
            emb, self._target_mask(y), cache=state
        )
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache
    # batch beam search API (see BatchScorerInterface)
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]
        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(ys))
        else:
            emb = self.embed(ys)
        # batch decoding
        h, _, states = self.encoder.forward_one_step(
            emb, self._target_mask(ys), cache=batch_state
        )
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)
        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Waseda University (Yosuke Higuchi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Token masking module for Masked LM."""
import numpy
def mask_uniform(ys_pad, mask_token, eos, ignore_id):
    """Replace random tokens with <mask> label and add <eos> label.

    The number of <mask> is chosen from a uniform distribution
    between one and the target sequence's length.  Masked positions are
    revealed in the output targets; every other target position is set to
    ``ignore_id``.

    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param int mask_token: index of <mask>
    :param int eos: index of <eos>
    :param int ignore_id: index of padding
    :return: padded input tensor with masks (B, Lmax) and
        padded target tensor (B, Lmax)
    :rtype: tuple[torch.Tensor, torch.Tensor]
    """
    from espnet.nets.pytorch_backend.nets_utils import pad_list

    # strip the padding first, then derive per-sequence inputs/targets
    seqs = [y[y != ignore_id] for y in ys_pad]
    targets = [seq.new(seq.size()).fill_(ignore_id) for seq in seqs]
    inputs = [seq.clone() for seq in seqs]
    for seq, seq_in, seq_out in zip(seqs, inputs, targets):
        n_mask = numpy.random.randint(1, len(seq) + 1)
        # NOTE(review): sampled with replacement, so fewer than n_mask
        # distinct positions may actually be masked — confirm intended.
        positions = numpy.random.choice(len(seq), n_mask)
        seq_in[positions] = mask_token
        seq_out[positions] = seq[positions]
    return pad_list(inputs, eos), pad_list(targets, ignore_id)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Waseda University (Yosuke Higuchi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Attention masking module for Masked LM."""
def square_mask(ys_in_pad, ignore_id):
    """Create attention mask to avoid attending on padding tokens.

    Position (b, i, j) is True iff neither token i nor token j of batch
    element b is padding.

    :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax)
    :param int ignore_id: index of padding
    :rtype: torch.Tensor (B, Lmax, Lmax)
    """
    valid = ys_in_pad != ignore_id  # (B, Lmax)
    # outer product of the validity mask with itself via broadcasting
    return valid.unsqueeze(-2) & valid.unsqueeze(-1)
# -*- coding: utf-8 -*-
"""Network related utility tools."""
import logging
from typing import Dict
import numpy as np
import torch
def to_device(m, x):
    """Send tensor into the device of the module.

    Args:
        m (torch.nn.Module or torch.Tensor): Torch module or tensor whose
            device is used as the destination.  For a module, the device of
            its first parameter is used (raises StopIteration if it has no
            parameters).
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located in the same place as torch module.

    Raises:
        TypeError: If ``m`` is neither a module nor a tensor.
    """
    if isinstance(m, torch.nn.Module):
        device = next(m.parameters()).device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        # BUG FIX: error message previously read "bot got".
        raise TypeError(
            "Expected torch.nn.Module or torch.tensor, " f"but got: {type(m)}"
        )
    return x.to(device)
def pad_list(xs, pad_value):
    """Pad a list of variable-length tensors into one batched tensor.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    longest = max(x.size(0) for x in xs)
    # allocate once, pre-filled with the pad value, then copy each sequence
    padded = xs[0].new(len(xs), longest, *xs[0].size()[1:]).fill_(pad_value)
    for i, x in enumerate(xs):
        padded[i, : x.size(0)] = x
    return padded
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make mask tensor containing indices of padded part.

    True/1 entries mark padded positions; valid positions are False/0.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor
            (must not be 0, the batch dimension).
        maxlen (int, optional): Explicit time-axis length; only valid when
            ``xs`` is None and must be >= max(lengths).

    Returns:
        Tensor: Mask tensor containing indices of padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        >>> make_pad_mask([5, 3, 2])
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
        >>> xs = torch.zeros((3, 2, 4))
        >>> make_pad_mask([5, 3, 2], xs)   # broadcast along dim 1
        tensor([[[0, 0, 0, 0], [0, 0, 0, 0]],
                [[0, 0, 0, 1], [0, 0, 0, 1]],
                [[0, 0, 1, 1], [0, 0, 1, 1]]], dtype=torch.uint8)
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.long().tolist()
    batch = int(len(lengths))

    # Resolve the time-axis length.
    if maxlen is None:
        maxlen = int(max(lengths)) if xs is None else xs.size(length_dim)
    else:
        assert xs is None
        assert maxlen >= int(max(lengths))

    # A position is padded iff position_index >= sequence_length.
    positions = torch.arange(0, maxlen, dtype=torch.int64)
    positions = positions.unsqueeze(0).expand(batch, maxlen)
    limits = positions.new(lengths).unsqueeze(-1)
    mask = positions >= limits

    if xs is not None:
        assert xs.size(0) == batch, (xs.size(0), batch)
        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # Insert singleton axes everywhere except the batch and time axes,
        # then broadcast the (B, maxlen) mask to the reference shape.
        index = tuple(
            slice(None) if axis in (0, length_dim) else None
            for axis in range(xs.dim())
        )
        mask = mask[index].expand_as(xs).to(xs.device)
    return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    This is the elementwise complement of :func:`make_pad_mask`: True/1
    entries mark valid (non-padded) positions.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.

    Returns:
        ByteTensor: mask tensor containing indices of non-padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)

    Examples:
        >>> make_non_pad_mask([5, 3, 2])
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
        >>> xs = torch.zeros((3, 2, 4))
        >>> make_non_pad_mask([5, 3, 2], xs)   # broadcast along dim 1
        tensor([[[1, 1, 1, 1], [1, 1, 1, 1]],
                [[1, 1, 1, 0], [1, 1, 1, 0]],
                [[1, 1, 0, 0], [1, 1, 0, 0]]], dtype=torch.uint8)
    """
    pad_mask = make_pad_mask(lengths, xs, length_dim)
    return ~pad_mask
def mask_by_length(xs, lengths, fill=0):
    """Mask tensor according to length.

    Args:
        xs (Tensor): Batch of input tensor (B, `*`).
        lengths (LongTensor or List): Batch of lengths (B,).
        fill (int or float): Value to fill masked part.

    Returns:
        Tensor: Batch of masked input tensor (B, `*`); ``xs`` is untouched.

    Examples:
        >>> x = torch.arange(5).repeat(3, 1) + 1
        >>> mask_by_length(x, [5, 3, 2])
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 0, 0],
                [1, 2, 0, 0, 0]])
    """
    assert xs.size(0) == len(lengths)
    # Start from an all-``fill`` tensor and copy back each valid prefix.
    masked = xs.data.new(*xs.size()).fill_(fill)
    for i, length in enumerate(lengths):
        masked[i, :length] = xs[i, :length]
    return masked
def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax, D).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).
    """
    batch, length = pad_targets.size(0), pad_targets.size(1)
    # Reshape flat logits back to (B, Lmax, D) and take the argmax class.
    predictions = pad_outputs.view(batch, length, pad_outputs.size(1)).argmax(2)
    valid = pad_targets != ignore_label
    correct = torch.sum(
        predictions.masked_select(valid) == pad_targets.masked_select(valid)
    )
    total = torch.sum(valid)
    return float(correct) / float(total)
def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor,
            and dict.

    Returns:
        Tensor or ComplexTensor: Type converted inputs.

    Examples:
        >>> to_torch_tensor(np.ones(3, dtype=np.float32))
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
        >>> to_torch_tensor({'real': xs, 'imag': xs})
        ComplexTensor(...)
    """
    # torch.Tensor inputs pass through unchanged.
    if isinstance(x, torch.Tensor):
        return x

    # numpy: complex dtypes become ComplexTensor, everything else a Tensor.
    if isinstance(x, np.ndarray):
        if x.dtype.kind == "c":
            # Dynamically importing because torch_complex requires python3
            from torch_complex.tensor import ComplexTensor

            return ComplexTensor(x)
        return torch.from_numpy(x)

    # {'real': ..., 'imag': ...} dicts become ComplexTensor.
    if isinstance(x, dict):
        # Dynamically importing because torch_complex requires python3
        from torch_complex.tensor import ComplexTensor

        if "real" not in x or "imag" not in x:
            raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
        # Relative importing because of using python3 syntax
        return ComplexTensor(x["real"], x["imag"])

    error = (
        "x must be numpy.ndarray, torch.Tensor or a dict like "
        "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
        "but got {}".format(type(x))
    )
    try:
        from torch_complex.tensor import ComplexTensor
    except Exception:
        # If PY2
        raise ValueError(error)
    else:
        # If PY3: accept an already-built ComplexTensor as-is.
        if isinstance(x, ComplexTensor):
            return x
        raise ValueError(error)
def get_subsample(train_args, mode, arch):
    """Parse the subsampling factors from the args for the specified `mode` and `arch`.

    Args:
        train_args: argument Namespace containing options.
        mode: one of ('asr', 'mt', 'st')
        arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')

    Returns:
        np.ndarray / List[np.ndarray]: subsampling factors
            (a list of arrays only for 'rnn_mulenc').

    Raises:
        ValueError: if the (mode, arch) combination is not supported.
    """
    if arch == "transformer":
        # Transformer models do not use layer-wise frame subsampling here.
        return np.array([1])
    elif mode == "mt" and arch == "rnn":
        # +1 means input (+1) and layers outputs (train_args.elayer)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
        logging.warning("Subsampling is not performed for machine translation.")
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample
    elif (mode == "asr" and arch in ("rnn", "rnn-t")) or (
        mode == "st" and arch == "rnn"
    ):
        # NOTE: (mode == "mt", arch == "rnn") was listed here too, but it is
        # unreachable — the branch above already returns for that case.
        subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            # e.g. "2_2_1" -> [2, 2, 1]; missing trailing factors stay 1
            ss = train_args.subsample.split("_")
            for j in range(min(train_args.elayers + 1, len(ss))):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                "Subsampling is not performed for vgg*. "
                "It is performed in max pooling layers at CNN."
            )
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample
    elif mode == "asr" and arch == "rnn_mix":
        # speaker-differentiating layers (elayers_sd) + shared layers + input
        subsample = np.ones(
            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int64
        )
        if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
            ss = train_args.subsample.split("_")
            for j in range(
                min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
            ):
                subsample[j] = int(ss[j])
        else:
            logging.warning(
                "Subsampling is not performed for vgg*. "
                "It is performed in max pooling layers at CNN."
            )
        logging.info("subsample: " + " ".join([str(x) for x in subsample]))
        return subsample
    elif mode == "asr" and arch == "rnn_mulenc":
        # one subsampling spec per encoder
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int64)
            if train_args.etype[idx].endswith("p") and not train_args.etype[
                idx
            ].startswith("vgg"):
                ss = train_args.subsample[idx].split("_")
                for j in range(min(train_args.elayers[idx] + 1, len(ss))):
                    subsample[j] = int(ss[j])
            else:
                logging.warning(
                    "Encoder %d: Subsampling is not performed for vgg*. "
                    "It is performed in max pooling layers at CNN.",
                    idx + 1,
                )
            logging.info("subsample: " + " ".join([str(x) for x in subsample]))
            subsample_list.append(subsample)
        return subsample_list
    else:
        raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
def rename_state_dict(
    old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
):
    """Replace keys of old prefix with new prefix in state dict.

    The mapping is mutated in place.  Only the *leading* occurrence of
    ``old_prefix`` is replaced; later occurrences of the same substring
    inside a key are preserved.

    Args:
        old_prefix: Prefix to strip from matching keys.
        new_prefix: Prefix substituted in its place.
        state_dict: State dict to rename (modified in place).
    """
    # need this list not to break the dict iterator
    old_keys = [k for k in state_dict if k.startswith(old_prefix)]
    if len(old_keys) > 0:
        logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
    for k in old_keys:
        v = state_dict.pop(k)
        # BUG FIX: ``k.replace(old_prefix, new_prefix)`` rewrote *every*
        # occurrence (e.g. "enc.enc.w" -> "dec.dec.w"); slice off the
        # leading prefix instead.
        new_k = new_prefix + k[len(old_prefix):]
        state_dict[new_k] = v
def get_activation(act):
    """Return activation function.

    :param str act: one of "hardtanh", "tanh", "relu", "selu", "swish"
    :return: a freshly constructed activation module instance
    :raises KeyError: if ``act`` is not a known activation name
    """
    # Lazy load to avoid unused import
    from espnet.nets.pytorch_backend.conformer.swish import Swish

    table = {
        "hardtanh": torch.nn.Hardtanh,
        "tanh": torch.nn.Tanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": Swish,
    }
    return table[act]()
# Copyright 2020 Hirofumi Inaguma
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Conformer common arguments."""
def add_arguments_rnn_encoder_common(group):
    """Define common arguments for RNN encoder.

    :param group: argument group (or parser) to extend
    :return: the same group with the encoder options registered
    """
    add = group.add_argument
    add(
        "--etype",
        type=str,
        default="blstmp",
        choices=[
            "lstm",
            "blstm",
            "lstmp",
            "blstmp",
            "vgglstmp",
            "vggblstmp",
            "vgglstm",
            "vggblstm",
            "gru",
            "bgru",
            "grup",
            "bgrup",
            "vgggrup",
            "vggbgrup",
            "vgggru",
            "vggbgru",
        ],
        help="Type of encoder network architecture",
    )
    add("--elayers", type=int, default=4, help="Number of encoder layers")
    add(
        "--eunits",
        "-u",
        type=int,
        default=300,
        help="Number of encoder hidden units",
    )
    add("--eprojs", type=int, default=320, help="Number of encoder projection units")
    add(
        "--subsample",
        type=str,
        default="1",
        help="Subsample input frames x_y_z means "
        "subsample every x frame at 1st layer, "
        "every y frame at 2nd layer etc.",
    )
    return group
def add_arguments_rnn_decoder_common(group):
    """Define common arguments for RNN decoder.

    :param group: argument group (or parser) to extend
    :return: the same group with the decoder options registered
    """
    add = group.add_argument
    add(
        "--dtype",
        type=str,
        default="lstm",
        choices=["lstm", "gru"],
        help="Type of decoder network architecture",
    )
    add("--dlayers", type=int, default=1, help="Number of decoder layers")
    add("--dunits", type=int, default=320, help="Number of decoder hidden units")
    add(
        "--dropout-rate-decoder",
        type=float,
        default=0.0,
        help="Dropout rate for the decoder",
    )
    add(
        "--sampling-probability",
        type=float,
        default=0.0,
        help="Ratio of predicted labels fed back to decoder",
    )
    add(
        "--lsm-type",
        type=str,
        nargs="?",
        const="",
        default="",
        choices=["", "unigram"],
        help="Apply label smoothing with a specified distribution type",
    )
    return group
def add_arguments_rnn_attention_common(group):
    """Define common arguments for RNN attention.

    Registers attention architecture/dimension options on ``group`` and
    returns the same group.
    """
    group.add_argument(
        "--atype",
        default="dot",
        type=str,
        choices=[
            "noatt",
            "dot",
            "add",
            "location",
            "coverage",
            "coverage_location",
            "location2d",
            "location_recurrent",
            "multi_head_dot",
            "multi_head_add",
            "multi_head_loc",
            "multi_head_multi_res_loc",
        ],
        help="Type of attention architecture",
    )
    group.add_argument(
        "--adim",
        default=320,
        type=int,
        help="Number of attention transformation dimensions",
    )
    group.add_argument(
        "--awin", default=5, type=int, help="Window size for location2d attention"
    )
    group.add_argument(
        "--aheads",
        default=4,
        type=int,
        help="Number of heads for multi head attention",
    )
    group.add_argument(
        "--aconv-chans",
        default=-1,
        type=int,
        help="Number of attention convolution channels \
                    (negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--aconv-filts",
        default=100,
        type=int,
        help="Number of attention convolution filters \
                    (negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--dropout-rate",
        default=0.0,
        type=float,
        help="Dropout rate for the encoder",
    )
    return group
"""Attention modules for RNN."""
import math
import torch
import torch.nn.functional as F
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask, to_device
def _apply_attention_constraint(
e, last_attended_idx, backward_window=1, forward_window=3
):
"""Apply monotonic attention constraint.
This function apply the monotonic attention constraint
introduced in `Deep Voice 3: Scaling
Text-to-Speech with Convolutional Sequence Learning`_.
Args:
e (Tensor): Attention energy before applying softmax (1, T).
last_attended_idx (int): The index of the inputs of the last attended [0, T].
backward_window (int, optional): Backward window size in attention constraint.
forward_window (int, optional): Forward window size in attetion constraint.
Returns:
Tensor: Monotonic constrained attention energy (1, T).
.. _`Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning`:
https://arxiv.org/abs/1710.07654
"""
if e.size(0) != 1:
raise NotImplementedError("Batch attention constraining is not yet supported.")
backward_idx = last_attended_idx - backward_window
forward_idx = last_attended_idx + forward_window
if backward_idx > 0:
e[:, :backward_idx] = -float("inf")
if forward_idx < e.size(1):
e[:, forward_idx:] = -float("inf")
return e
class NoAtt(torch.nn.Module):
    """No attention

    Returns a fixed uniform weighting over the non-padded encoder frames;
    the resulting context vector is computed once on the first call and
    cached in ``self.c``.
    """
    def __init__(self):
        super(NoAtt, self).__init__()
        # caches filled lazily on the first forward call
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.c = None
    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.c = None
    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """NoAtt forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, T_max, D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: dummy (does not use)
        :param torch.Tensor att_prev: dummy (does not use)
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
        # initialize attention weight with uniform dist.
        if att_prev is None:
            # if no bias, 0 0-pad goes 0
            mask = 1.0 - make_pad_mask(enc_hs_len).float()
            # divide by each utterance's true length so weights sum to one
            att_prev = mask / mask.new(enc_hs_len).unsqueeze(-1)
            att_prev = att_prev.to(self.enc_h)
            self.c = torch.sum(
                self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1
            )
        return self.c, att_prev
class AttDot(torch.nn.Module):
"""Dot product attention
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param bool han_mode: flag to swith on mode of hierarchical attention
and not store pre_compute_enc_h
"""
def __init__(self, eprojs, dunits, att_dim, han_mode=False):
super(AttDot, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
"""AttDot forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: dummy (does not use)
:param torch.Tensor att_prev: dummy (does not use)
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weight (B x T_max)
:rtype: torch.Tensor
"""
batch = enc_hs_pad.size(0)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = torch.tanh(self.mlp_enc(self.enc_h))
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
e = torch.sum(
self.pre_compute_enc_h
* torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim),
dim=2,
) # utt x frame
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float("inf"))
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, w
class AttAdd(torch.nn.Module):
"""Additive attention
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param bool han_mode: flag to swith on mode of hierarchical attention
and not store pre_compute_enc_h
"""
def __init__(self, eprojs, dunits, att_dim, han_mode=False):
super(AttAdd, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
"""AttAdd forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: dummy (does not use)
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x T_max)
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled)).squeeze(2)
# NOTE consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float("inf"))
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, w
class AttLoc(torch.nn.Module):
"""location-aware attention module.
Reference: Attention-Based Models for Speech Recognition
(https://arxiv.org/pdf/1506.07503.pdf)
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param bool han_mode: flag to swith on mode of hierarchical attention
and not store pre_compute_enc_h
"""
def __init__(
self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
):
super(AttLoc, self).__init__()
self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
self.loc_conv = torch.nn.Conv2d(
1,
aconv_chans,
(1, 2 * aconv_filts + 1),
padding=(0, aconv_filts),
bias=False,
)
self.gvec = torch.nn.Linear(att_dim, 1)
self.dunits = dunits
self.eprojs = eprojs
self.att_dim = att_dim
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
self.han_mode = han_mode
def reset(self):
"""reset states"""
self.h_length = None
self.enc_h = None
self.pre_compute_enc_h = None
self.mask = None
def forward(
self,
enc_hs_pad,
enc_hs_len,
dec_z,
att_prev,
scaling=2.0,
last_attended_idx=None,
backward_window=1,
forward_window=3,
):
"""Calculate AttLoc forward propagation.
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param torch.Tensor att_prev: previous attention weight (B x T_max)
:param float scaling: scaling parameter before applying softmax
:param torch.Tensor forward_window:
forward window size when constraining attention
:param int last_attended_idx: index of the inputs of the last attended
:param int backward_window: backward window size in attention constraint
:param int forward_window: forward window size in attetion constraint
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights (B x T_max)
:rtype: torch.Tensor
"""
batch = len(enc_hs_pad)
# pre-compute all h outside the decoder loop
if self.pre_compute_enc_h is None or self.han_mode:
self.enc_h = enc_hs_pad # utt x frame x hdim
self.h_length = self.enc_h.size(1)
# utt x frame x att_dim
self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
if dec_z is None:
dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
else:
dec_z = dec_z.view(batch, self.dunits)
# initialize attention weight with uniform dist.
if att_prev is None:
# if no bias, 0 0-pad goes 0
att_prev = 1.0 - make_pad_mask(enc_hs_len).to(
device=dec_z.device, dtype=dec_z.dtype
)
att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
# att_prev: utt x frame -> utt x 1 x 1 x frame
# -> utt x att_conv_chans x 1 x frame
att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
# att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
att_conv = att_conv.squeeze(2).transpose(1, 2)
# att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
att_conv = self.mlp_att(att_conv)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e = self.gvec(
torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
).squeeze(2)
# NOTE: consider zero padding when compute w.
if self.mask is None:
self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
e.masked_fill_(self.mask, -float("inf"))
# apply monotonic attention constraint (mainly for TTS)
if last_attended_idx is not None:
e = _apply_attention_constraint(
e, last_attended_idx, backward_window, forward_window
)
w = F.softmax(scaling * e, dim=1)
# weighted sum over flames
# utt x hdim
c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
return c, w
class AttCov(torch.nn.Module):
    """Coverage mechanism attention.

    Reference: Get To The Point: Summarization with Pointer-Generator Network
    (https://arxiv.org/abs/1704.04368)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
    """

    def __init__(self, eprojs, dunits, att_dim, han_mode=False):
        super(AttCov, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        # projects the scalar per-frame coverage value into att_dim
        self.wvec = torch.nn.Linear(1, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        # per-utterance caches, filled lazily in forward()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset cached encoder outputs, projections and padding mask."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
        """AttCov forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param list att_prev_list: list of previous attention weight
            (NOTE: extended in place — the caller's list is mutated)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weights
        :rtype: list
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        # initialize attention weight with uniform dist. over valid frames
        if att_prev_list is None:
            # zero-padded frames get zero initial weight
            att_prev_list = to_device(
                enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())
            )
            att_prev_list = [
                att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1)
            ]
        # coverage = running sum of past attention: L' * [B x T] => B x T
        cov_vec = sum(att_prev_list)
        # cov_vec: B x T => B x T x 1 => B x T x att_dim
        cov_vec = self.wvec(cov_vec.unsqueeze(-1))
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # dot with gvec: utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)
        # NOTE consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        # in-place extend: the new weights feed the next step's coverage
        att_prev_list += [w]
        # weighted sum over frames -> utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, att_prev_list
class AttLoc2D(torch.nn.Module):
    """2D location-aware attention.

    This attention is an extended version of location-aware attention.
    It takes not only the attention weights of one previous frame
    but also those of earlier frames into account.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param int att_win: attention window size (default=5)
    :param bool han_mode:
        flag to switch on mode of hierarchical attention and not store pre_compute_enc_h
    """

    def __init__(
        self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False
    ):
        super(AttLoc2D, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        # kernel height == att_win collapses the window of past weights
        # into a single row (no vertical padding)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (att_win, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        # per-utterance caches, filled lazily in forward()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.aconv_chans = aconv_chans
        self.att_win = att_win
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset cached encoder outputs, projections and padding mask."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttLoc2D forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: previous attention weight (B x att_win x T_max)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x att_win x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        # initialize the whole window with a uniform dist. over valid frames
        if att_prev is None:
            # B * [Li x att_win]; zero-padded frames get zero weight
            att_prev = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()))
            att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
            att_prev = att_prev.unsqueeze(1).expand(-1, self.att_win, -1)
        # att_prev: B x att_win x Tmax -> B x 1 x att_win x Tmax -> B x C x 1 x Tmax
        att_conv = self.loc_conv(att_prev.unsqueeze(1))
        # att_conv: B x C x 1 x Tmax -> B x Tmax x C
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # dot with gvec: utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)
        # NOTE consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        # weighted sum over frames -> utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        # slide the window: append the new weights, drop the oldest row
        # B x att_win x Tmax -> B x att_win+1 x Tmax -> B x att_win x Tmax
        att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1)
        att_prev = att_prev[:, 1:]
        return c, att_prev
class AttLocRec(torch.nn.Module):
    """Location-aware recurrent attention.

    This attention is an extended version of location-aware attention.
    With the use of an RNN,
    it takes the effect of the history of attention weights into account.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode:
        flag to switch on mode of hierarchical attention and not store pre_compute_enc_h
    """

    def __init__(
        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
    ):
        super(AttLocRec, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        # summarizes the convolved attention history across decoding steps
        self.att_lstm = torch.nn.LSTMCell(aconv_chans, att_dim, bias=False)
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        # per-utterance caches, filled lazily in forward()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset cached encoder outputs, projections and padding mask."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0):
        """AttLocRec forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param tuple att_prev_states: previous attention weight and lstm states
            ((B, T_max), ((B, att_dim), (B, att_dim)))
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights and lstm states (w, (hx, cx))
            ((B, T_max), ((B, att_dim), (B, att_dim)))
        :rtype: tuple
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        if att_prev_states is None:
            # initialize attention weight with uniform dist. over valid frames
            # (zero-padded frames get zero weight)
            att_prev = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()))
            att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)
            # initialize lstm states
            att_h = enc_hs_pad.new_zeros(batch, self.att_dim)
            att_c = enc_hs_pad.new_zeros(batch, self.att_dim)
            att_states = (att_h, att_c)
        else:
            att_prev = att_prev_states[0]
            att_states = att_prev_states[1]
        # B x 1 x 1 x T -> B x C x 1 x T
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # apply non-linear
        att_conv = F.relu(att_conv)
        # global max-pool over time: B x C x 1 x T -> B x C x 1 x 1 -> B x C
        att_conv = F.max_pool2d(att_conv, (1, att_conv.size(3))).view(batch, -1)
        # advance the recurrent attention state by one step
        att_h, att_c = self.att_lstm(att_conv, att_states)
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # dot with gvec: utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_h.unsqueeze(1) + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)
        # NOTE consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        # weighted sum over frames -> utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, (w, (att_h, att_c))
class AttCovLoc(torch.nn.Module):
    """Coverage mechanism location-aware attention.

    This attention is a combination of coverage and location-aware attentions.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode:
        flag to switch on mode of hierarchical attention and not store pre_compute_enc_h
    """

    def __init__(
        self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False
    ):
        super(AttCovLoc, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        # per-utterance caches, filled lazily in forward()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.aconv_chans = aconv_chans
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset cached encoder outputs, projections and padding mask."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
        """AttCovLoc forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param list att_prev_list: list of previous attention weight
            (NOTE: extended in place — the caller's list is mutated)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weights
        :rtype: list
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        # initialize attention weight with uniform dist. over valid frames
        if att_prev_list is None:
            # zero-padded frames get zero initial weight
            mask = 1.0 - make_pad_mask(enc_hs_len).float()
            att_prev_list = [
                to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
            ]
        # coverage = running sum of past attention: L' * [B x T] => B x T
        cov_vec = sum(att_prev_list)
        # convolve the coverage: B x T -> B x 1 x 1 x T -> B x C x 1 x T
        att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # dot with gvec: utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)
        # NOTE consider zero padding when computing w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        # in-place extend: the new weights feed the next step's coverage
        att_prev_list += [w]
        # weighted sum over frames -> utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        return c, att_prev_list
class AttMultiHeadDot(torch.nn.Module):
    """Multi-head dot-product attention.

    Reference: Attention is all you need
    (https://arxiv.org/abs/1706.03762)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_k and pre_compute_v
    """

    def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False):
        super(AttMultiHeadDot, self).__init__()
        # one independent query/key/value projection per head
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        for _ in range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
        # output projection applied to the concatenated head contexts
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        # scaled dot-product factor 1/sqrt(d_k)
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        # per-utterance caches, filled lazily in forward()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset cached keys, values and padding mask."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """AttMultiHeadDot forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: dummy (does not use)
        :return: attention weighted encoder state (B x D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weight (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim per head
            self.pre_compute_k = [
                torch.tanh(self.mlp_k[h](self.enc_h)) for h in range(self.aheads)
            ]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim per head
            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in range(self.aheads)]
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        c = []
        w = []
        for h in range(self.aheads):
            # per-head energies: utt x frame
            e = torch.sum(
                self.pre_compute_k[h]
                * torch.tanh(self.mlp_q[h](dec_z)).view(batch, 1, self.att_dim_k),
                dim=2,
            )
            # NOTE consider zero padding when computing w.
            if self.mask is None:
                self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(self.scaling * e, dim=1)]
            # weighted sum over frames -> utt x hdim
            # NOTE use bmm instead of sum(*)
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]
        # concat all head contexts and project back to eprojs
        c = self.mlp_o(torch.cat(c, dim=1))
        return c, w
class AttMultiHeadAdd(torch.nn.Module):
    """Multi-head additive attention.

    Reference: Attention is all you need
    (https://arxiv.org/abs/1706.03762)

    This attention is multi head attention using additive attention for each head.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_k and pre_compute_v
    """

    def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False):
        super(AttMultiHeadAdd, self).__init__()
        # one independent query/key/value projection + scoring vector per head
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        self.gvec = torch.nn.ModuleList()
        for _ in range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
            self.gvec += [torch.nn.Linear(att_dim_k, 1)]
        # output projection applied to the concatenated head contexts
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        # scaling factor 1/sqrt(d_k), applied to the additive energies
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        # per-utterance caches, filled lazily in forward()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset cached keys, values and padding mask."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """AttMultiHeadAdd forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: dummy (does not use)
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weight (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim per head
            self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in range(self.aheads)]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim per head
            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in range(self.aheads)]
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        c = []
        w = []
        for h in range(self.aheads):
            # per-head additive energies: utt x frame
            e = self.gvec[h](
                torch.tanh(
                    self.pre_compute_k[h]
                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                )
            ).squeeze(2)
            # NOTE consider zero padding when computing w.
            if self.mask is None:
                self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(self.scaling * e, dim=1)]
            # weighted sum over frames -> utt x hdim
            # NOTE use bmm instead of sum(*)
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]
        # concat all head contexts and project back to eprojs
        c = self.mlp_o(torch.cat(c, dim=1))
        return c, w
class AttMultiHeadLoc(torch.nn.Module):
    """Multi head location based attention.

    Reference: Attention is all you need
        (https://arxiv.org/abs/1706.03762)

    This attention is multi head attention using location-aware attention
    for each head: every head owns a 1-D convolution over its previous
    attention weights in addition to the additive (tanh) scoring.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_k and pre_compute_v
    """

    def __init__(
        self,
        eprojs,
        dunits,
        aheads,
        att_dim_k,
        att_dim_v,
        aconv_chans,
        aconv_filts,
        han_mode=False,
    ):
        super(AttMultiHeadLoc, self).__init__()
        # one projection/convolution stack per head
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        self.gvec = torch.nn.ModuleList()
        self.loc_conv = torch.nn.ModuleList()
        self.mlp_att = torch.nn.ModuleList()
        for _ in range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
            self.gvec += [torch.nn.Linear(att_dim_k, 1)]
            self.loc_conv += [
                torch.nn.Conv2d(
                    1,
                    aconv_chans,
                    (1, 2 * aconv_filts + 1),
                    padding=(0, aconv_filts),
                    bias=False,
                )
            ]
            self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)]
        # output projection merging the per-head context vectors
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        # NOTE(review): stored scaled-dot factor, but forward() uses its
        # `scaling` argument (default 2.0) instead -- confirm this is intended
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        # cached encoder-side quantities; reset() clears them per utterance
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset cached encoder projections and padding mask (call per utterance)."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttMultiHeadLoc forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev:
            list of previous attention weight (B x T_max) * aheads
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B x D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weight (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        # (skipped caching in HAN mode, where encoder states change per call)
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in range(self.aheads)]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in range(self.aheads)]
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        if att_prev is None:
            # initialize each head with a uniform distribution over valid frames
            att_prev = []
            for _ in range(self.aheads):
                # if no bias, 0 0-pad goes 0
                mask = 1.0 - make_pad_mask(enc_hs_len).float()
                att_prev += [
                    to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
                ]
        c = []
        w = []
        for h in range(self.aheads):
            # convolve previous attention: B x 1 x 1 x T -> B x T x att_dim
            att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
            att_conv = att_conv.squeeze(2).transpose(1, 2)
            att_conv = self.mlp_att[h](att_conv)
            # additive score per frame: B x T
            e = self.gvec[h](
                torch.tanh(
                    self.pre_compute_k[h]
                    + att_conv
                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                )
            ).squeeze(2)
            # NOTE consider zero padding when compute w.
            if self.mask is None:
                self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(scaling * e, dim=1)]
            # weighted sum over frames
            # utt x hdim
            # NOTE use bmm instead of sum(*)
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]
        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))
        return c, w
class AttMultiHeadMultiResLoc(torch.nn.Module):
    """Multi head multi resolution location based attention.

    Reference: Attention is all you need
        (https://arxiv.org/abs/1706.03762)

    This attention is multi head attention using location-aware attention
    for each head. Furthermore, each head uses a different convolution
    filter width, so the heads see the previous attention weights at
    different resolutions.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param int aconv_chans: # channels of attention convolution (shared by all heads)
    :param int aconv_filts: maximum filter half-width of attention convolution;
        head h uses aconv_filts * (h + 1) // aheads
        e.g. aheads=4, aconv_filts=100 => half-widths = 25, 50, 75, 100
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_k and pre_compute_v
    """

    def __init__(
        self,
        eprojs,
        dunits,
        aheads,
        att_dim_k,
        att_dim_v,
        aconv_chans,
        aconv_filts,
        han_mode=False,
    ):
        super(AttMultiHeadMultiResLoc, self).__init__()
        # one projection/convolution stack per head
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        self.gvec = torch.nn.ModuleList()
        self.loc_conv = torch.nn.ModuleList()
        self.mlp_att = torch.nn.ModuleList()
        for h in range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
            self.gvec += [torch.nn.Linear(att_dim_k, 1)]
            # filter half-width grows with the head index -> multi-resolution
            afilts = aconv_filts * (h + 1) // aheads
            self.loc_conv += [
                torch.nn.Conv2d(
                    1, aconv_chans, (1, 2 * afilts + 1), padding=(0, afilts), bias=False
                )
            ]
            self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)]
        # output projection merging the per-head context vectors
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        # scaled-dot factor applied before the softmax in forward()
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        # cached encoder-side quantities; reset() clears them per utterance
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """Reset cached encoder projections and padding mask (call per utterance)."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """AttMultiHeadMultiResLoc forward.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: list of previous attention weight
            (B x T_max) * aheads
        :return: attention weighted encoder state (B x D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weight (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        # (skipped caching in HAN mode, where encoder states change per call)
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in range(self.aheads)]
        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in range(self.aheads)]
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        if att_prev is None:
            # initialize each head with a uniform distribution over valid frames
            att_prev = []
            for _ in range(self.aheads):
                # if no bias, 0 0-pad goes 0
                mask = 1.0 - make_pad_mask(enc_hs_len).float()
                att_prev += [
                    to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
                ]
        c = []
        w = []
        for h in range(self.aheads):
            # convolve previous attention: B x 1 x 1 x T -> B x T x att_dim
            att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
            att_conv = att_conv.squeeze(2).transpose(1, 2)
            att_conv = self.mlp_att[h](att_conv)
            # additive score per frame: B x T
            e = self.gvec[h](
                torch.tanh(
                    self.pre_compute_k[h]
                    + att_conv
                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                )
            ).squeeze(2)
            # NOTE consider zero padding when compute w.
            if self.mask is None:
                self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(self.scaling * e, dim=1)]
            # weighted sum over frames
            # utt x hdim
            # NOTE use bmm instead of sum(*)
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]
        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))
        return c, w
class AttForward(torch.nn.Module):
    """Forward attention module.

    Reference:
        Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
        (https://arxiv.org/pdf/1807.06736.pdf)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    """

    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
        super(AttForward, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        # location-aware convolution over the previous attention weights
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        # cached encoder-side quantities; reset() clears them per utterance
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def reset(self):
        """Reset cached encoder projection and padding mask (call per utterance)."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(
        self,
        enc_hs_pad,
        enc_hs_len,
        dec_z,
        att_prev,
        scaling=1.0,
        last_attended_idx=None,
        backward_window=1,
        forward_window=3,
    ):
        """Calculate AttForward forward propagation.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: attention weights of previous step
        :param float scaling: scaling parameter before applying softmax
        :param int last_attended_idx: index of the inputs of the last attended
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        if att_prev is None:
            # initial attention will be [1, 0, 0, ...]
            att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2])
            att_prev[:, 0] = 1.0
        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).unsqueeze(1)
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv)
        ).squeeze(2)
        # NOTE: consider zero padding when compute w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(
                e, last_attended_idx, backward_window, forward_window
            )
        w = F.softmax(scaling * e, dim=1)
        # forward attention recursion: the previous weights plus their
        # one-step-right shift gate the new weights before renormalization
        att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
        w = (att_prev + att_prev_shift) * w
        # NOTE: clamp is needed to avoid nan gradient
        w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.unsqueeze(-1), dim=1)
        return c, w
class AttForwardTA(torch.nn.Module):
    """Forward attention with transition agent module.

    Reference:
        Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
        (https://arxiv.org/pdf/1807.06736.pdf)

    :param int eunits: # units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param int odim: output dimension
    """

    def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim):
        super(AttForwardTA, self).__init__()
        self.mlp_enc = torch.nn.Linear(eunits, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        # transition agent: predicts the stay/move probability from
        # (context vector, previous output, decoder state)
        self.mlp_ta = torch.nn.Linear(eunits + dunits + odim, 1)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        # location-aware convolution over the previous attention weights
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eunits = eunits
        self.att_dim = att_dim
        # cached encoder-side quantities; reset() clears them per utterance
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        # transition agent probability, updated at the end of every forward()
        self.trans_agent_prob = 0.5

    def reset(self):
        """Reset cached states and the transition agent probability."""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.trans_agent_prob = 0.5

    def forward(
        self,
        enc_hs_pad,
        enc_hs_len,
        dec_z,
        att_prev,
        out_prev,
        scaling=1.0,
        last_attended_idx=None,
        backward_window=1,
        forward_window=3,
    ):
        """Calculate AttForwardTA forward propagation.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, Tmax, eunits)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B, dunits)
        :param torch.Tensor att_prev: attention weights of previous step
        :param torch.Tensor out_prev: decoder outputs of previous step (B, odim)
        :param float scaling: scaling parameter before applying softmax
        :param int last_attended_idx: index of the inputs of the last attended
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
        :return: attention weighted encoder state (B, eunits)
        :rtype: torch.Tensor
        :return: previous attention weights (B, Tmax)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)
        if att_prev is None:
            # initial attention will be [1, 0, 0, ...]
            att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2])
            att_prev[:, 0] = 1.0
        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)
        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)
        # NOTE consider zero padding when compute w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))
        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(
                e, last_attended_idx, backward_window, forward_window
            )
        w = F.softmax(scaling * e, dim=1)
        # forward attention: mix the previous weights with their
        # one-step-right shift using the transition agent probability
        att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
        w = (
            self.trans_agent_prob * att_prev
            + (1 - self.trans_agent_prob) * att_prev_shift
        ) * w
        # NOTE: clamp is needed to avoid nan gradient
        w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)
        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)
        # update transition agent prob for the next decoding step
        self.trans_agent_prob = torch.sigmoid(
            self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1))
        )
        return c, w
def att_for(args, num_att=1, han_mode=False):
    """Instantiate attention module(s) from the program arguments.

    :param Namespace args: The arguments
    :param int num_att: number of attention modules
        (in multi-speaker case, it can be 2 or more)
    :param bool han_mode: switch on/off mode of hierarchical attention network (HAN)
    :rtype torch.nn.Module
    :return: The attention module (a ModuleList, or a single module in HAN mode)
    """
    # use getattr to keep compatibility with older argument namespaces
    n_encoders = getattr(args, "num_encs", 1)
    heads = getattr(args, "aheads", None)
    win = getattr(args, "awin", None)
    conv_chans = getattr(args, "aconv_chans", None)
    conv_filts = getattr(args, "aconv_filts", None)

    if n_encoders < 1:
        raise ValueError(
            "Number of encoders needs to be more than one. {}".format(n_encoders)
        )

    if n_encoders == 1:
        # one (or more, for multi-speaker) attention over the single encoder
        return torch.nn.ModuleList(
            [
                initial_att(
                    args.atype,
                    args.eprojs,
                    args.dunits,
                    heads,
                    args.adim,
                    win,
                    conv_chans,
                    conv_filts,
                )
                for _ in range(num_att)
            ]
        )

    if han_mode:
        # single hierarchical attention over the per-encoder contexts
        return initial_att(
            args.han_type,
            args.eprojs,
            args.dunits,
            args.han_heads,
            args.han_dim,
            args.han_win,
            args.han_conv_chans,
            args.han_conv_filts,
            han_mode=True,
        )

    # one attention module per encoder (per-encoder hyperparameters)
    return torch.nn.ModuleList(
        [
            initial_att(
                args.atype[idx],
                args.eprojs,
                args.dunits,
                heads[idx],
                args.adim[idx],
                win[idx],
                conv_chans[idx],
                conv_filts[idx],
            )
            for idx in range(n_encoders)
        ]
    )
def initial_att(
    atype, eprojs, dunits, aheads, adim, awin, aconv_chans, aconv_filts, han_mode=False
):
    """Instantiate a single attention module.

    :param str atype: attention type
    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int adim: attention dimension
    :param int awin: attention window size
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on mode of hierarchical attention
    :return: The attention module
    :raises ValueError: if ``atype`` is not a known attention type
    """
    if atype == "noatt":
        att = NoAtt()
    elif atype == "dot":
        att = AttDot(eprojs, dunits, adim, han_mode)
    elif atype == "add":
        att = AttAdd(eprojs, dunits, adim, han_mode)
    elif atype == "location":
        att = AttLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "location2d":
        att = AttLoc2D(eprojs, dunits, adim, awin, aconv_chans, aconv_filts, han_mode)
    elif atype == "location_recurrent":
        att = AttLocRec(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "coverage":
        att = AttCov(eprojs, dunits, adim, han_mode)
    elif atype == "coverage_location":
        att = AttCovLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "multi_head_dot":
        att = AttMultiHeadDot(eprojs, dunits, aheads, adim, adim, han_mode)
    elif atype == "multi_head_add":
        att = AttMultiHeadAdd(eprojs, dunits, aheads, adim, adim, han_mode)
    elif atype == "multi_head_loc":
        att = AttMultiHeadLoc(
            eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode
        )
    elif atype == "multi_head_multi_res_loc":
        att = AttMultiHeadMultiResLoc(
            eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode
        )
    else:
        # previously fell through and crashed with UnboundLocalError at
        # `return att`; fail fast with a diagnostic instead
        raise ValueError("Unknown attention type: {}".format(atype))
    return att
def att_to_numpy(att_ws, att):
    """Convert collected attention weights to a numpy array.

    :param list att_ws: The attention weights collected per decoding step
    :param torch.nn.Module att: The attention module that produced them
    :rtype: np.ndarray
    :return: The numpy array of the attention weights, shape (B, Lmax, Tmax)
        (with an extra head axis for the multi-head variants)
    """
    if isinstance(att, AttLoc2D):
        # each step stores a concatenation of previous attentions;
        # keep only the most recent one
        stacked = torch.stack([weights[:, -1] for weights in att_ws], dim=1)
    elif isinstance(att, (AttCov, AttCovLoc)):
        # each step stores the list of all attentions so far;
        # keep the one produced at that step
        stacked = torch.stack(
            [weights[step] for step, weights in enumerate(att_ws)], dim=1
        )
    elif isinstance(att, AttLocRec):
        # each step stores (attention, hidden states); keep the attention
        stacked = torch.stack([weights[0] for weights in att_ws], dim=1)
    elif isinstance(
        att,
        (AttMultiHeadDot, AttMultiHeadAdd, AttMultiHeadLoc, AttMultiHeadMultiResLoc),
    ):
        # each step stores one attention per head; group by head first
        n_heads = len(att_ws[0])
        per_head = [
            torch.stack([weights[h] for weights in att_ws], dim=1)
            for h in range(n_heads)
        ]
        stacked = torch.stack(per_head, dim=1)
    else:
        # plain single-head attention: a list of (B, Tmax) tensors
        stacked = torch.stack(att_ws, dim=1)
    return stacked.cpu().numpy()
"""RNN decoder module."""
import logging
import math
import random
from argparse import Namespace
import numpy as np
import torch
import torch.nn.functional as F
from espnet.nets.ctc_prefix_score import CTCPrefixScore, CTCPrefixScoreTH
from espnet.nets.e2e_asr_common import end_detect
from espnet.nets.pytorch_backend.nets_utils import (
mask_by_length,
pad_list,
th_accuracy,
to_device,
)
from espnet.nets.pytorch_backend.rnn.attentions import att_to_numpy
from espnet.nets.scorer_interface import ScorerInterface
# Maximum number of hypotheses dumped to the log for debugging
# (see the verbose branch of Decoder.forward)
MAX_DECODER_OUTPUT = 5
# NOTE(review): presumably bounds the number of candidates kept during CTC
# prefix scoring in beam search -- confirm at the use site
CTC_SCORING_RATIO = 1.5
class Decoder(torch.nn.Module, ScorerInterface):
"""Decoder module
:param int eprojs: encoder projection units
:param int odim: dimension of outputs
:param str dtype: gru or lstm
:param int dlayers: decoder layers
:param int dunits: decoder units
:param int sos: start of sequence symbol id
:param int eos: end of sequence symbol id
:param torch.nn.Module att: attention module
:param int verbose: verbose level
:param list char_list: list of character strings
:param ndarray labeldist: distribution of label smoothing
:param float lsm_weight: label smoothing weight
:param float sampling_probability: scheduled sampling probability
:param float dropout: dropout rate
:param float context_residual: if True, use context vector for token generation
:param float replace_sos: use for multilingual (speech/text) translation
"""
def __init__(
self,
eprojs,
odim,
dtype,
dlayers,
dunits,
sos,
eos,
att,
verbose=0,
char_list=None,
labeldist=None,
lsm_weight=0.0,
sampling_probability=0.0,
dropout=0.0,
context_residual=False,
replace_sos=False,
num_encs=1,
):
torch.nn.Module.__init__(self)
self.dtype = dtype
self.dunits = dunits
self.dlayers = dlayers
self.context_residual = context_residual
self.embed = torch.nn.Embedding(odim, dunits)
self.dropout_emb = torch.nn.Dropout(p=dropout)
self.decoder = torch.nn.ModuleList()
self.dropout_dec = torch.nn.ModuleList()
self.decoder += [
torch.nn.LSTMCell(dunits + eprojs, dunits)
if self.dtype == "lstm"
else torch.nn.GRUCell(dunits + eprojs, dunits)
]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
for _ in range(1, self.dlayers):
self.decoder += [
torch.nn.LSTMCell(dunits, dunits)
if self.dtype == "lstm"
else torch.nn.GRUCell(dunits, dunits)
]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
# NOTE: dropout is applied only for the vertical connections
# see https://arxiv.org/pdf/1409.2329.pdf
self.ignore_id = -1
if context_residual:
self.output = torch.nn.Linear(dunits + eprojs, odim)
else:
self.output = torch.nn.Linear(dunits, odim)
self.loss = None
self.att = att
self.dunits = dunits
self.sos = sos
self.eos = eos
self.odim = odim
self.verbose = verbose
self.char_list = char_list
# for label smoothing
self.labeldist = labeldist
self.vlabeldist = None
self.lsm_weight = lsm_weight
self.sampling_probability = sampling_probability
self.dropout = dropout
self.num_encs = num_encs
# for multilingual E2E-ST
self.replace_sos = replace_sos
self.logzero = -10000000000.0
def zero_state(self, hs_pad):
return hs_pad.new_zeros(hs_pad.size(0), self.dunits)
def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev):
if self.dtype == "lstm":
z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0]))
for i in range(1, self.dlayers):
z_list[i], c_list[i] = self.decoder[i](
self.dropout_dec[i - 1](z_list[i - 1]), (z_prev[i], c_prev[i])
)
else:
z_list[0] = self.decoder[0](ey, z_prev[0])
for i in range(1, self.dlayers):
z_list[i] = self.decoder[i](
self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
)
return z_list, c_list
    def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
        """Decoder forward.

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            [in multi-encoder case,
            list of torch.Tensor,
            [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
            [in multi-encoder case, list of torch.Tensor,
            [(B), (B), ..., ]
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor
            (B, Lmax)
        :param int strm_idx: stream index indicates the index of decoding stream.
        :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy
        :rtype: float
        :return: perplexity
        :rtype: float
        """
        # to support mutiple encoder asr mode, in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]
        # TODO(kan-bayashi): need to make more smart way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
        # attention index for the attention module
        # in SPA (speaker parallel attention),
        # att_idx is used to select attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att) - 1)
        # hlens should be list of list of integer
        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]
        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])
        if self.replace_sos:
            # multilingual: replace <sos> with the target language id
            ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
        else:
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]
        # padding for ys with -1
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos)
        ys_out_pad = pad_list(ys_out, self.ignore_id)
        # get dim, length info
        batch = ys_out_pad.size(0)
        olength = ys_out_pad.size(1)
        for idx in range(self.num_encs):
            logging.info(
                self.__class__.__name__
                + "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                    self.num_encs, idx + 1, hlens[idx]
                )
            )
        logging.info(
            self.__class__.__name__
            + " output lengths: "
            + str([y.size(0) for y in ys_out])
        )
        # initialization of the per-layer decoder states
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        z_all = []
        if self.num_encs == 1:
            att_w = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han
        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim
        # loop for an output sequence
        for i in range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
                )
            else:
                # per-encoder attentions, then hierarchical attention (HAN)
                # over the resulting per-encoder context vectors
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        hs_pad[idx],
                        hlens[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlens_han = [self.num_encs] * len(ys_in)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    hs_pad_han,
                    hlens_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
            if i > 0 and random.random() < self.sampling_probability:
                # scheduled sampling: feed back the model's own previous
                # prediction instead of the ground-truth token
                logging.info(" scheduled sampling ")
                z_out = self.output(z_all[-1])
                z_out = np.argmax(z_out.detach().cpu(), axis=1)
                z_out = self.dropout_emb(self.embed(to_device(hs_pad[0], z_out)))
                ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            if self.context_residual:
                z_all.append(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )  # utt x (zdim + hdim)
            else:
                z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)
        z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)
        # compute loss
        y_all = self.output(z_all)
        self.loss = F.cross_entropy(
            y_all,
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="mean",
        )
        # compute perplexity
        ppl = math.exp(self.loss.item())
        # rescale the mean loss by the average target length
        # -1: eos, which is removed in the loss computation
        self.loss *= np.mean([len(x) for x in ys_in]) - 1
        acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
        logging.info("att loss:" + "".join(str(self.loss.item()).split("\n")))
        # show predicted character sequence for debug
        if self.verbose > 0 and self.char_list is not None:
            ys_hat = y_all.view(batch, olength, -1)
            ys_true = ys_out_pad
            for (i, y_hat), y_true in zip(
                enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy()
            ):
                if i == MAX_DECODER_OUTPUT:
                    break
                idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
                idx_true = y_true[y_true != self.ignore_id]
                seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
                seq_true = [self.char_list[int(idx)] for idx in idx_true]
                seq_hat = "".join(seq_hat)
                seq_true = "".join(seq_true)
                logging.info("groundtruth[%d]: " % i + seq_true)
                logging.info("prediction [%d]: " % i + seq_hat)
        if self.labeldist is not None:
            # label smoothing regularization toward the given label distribution
            if self.vlabeldist is None:
                self.vlabeldist = to_device(hs_pad[0], torch.from_numpy(self.labeldist))
            loss_reg = -torch.sum(
                (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0
            ) / len(ys_in)
            self.loss = (1.0 - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg
        return self.loss, acc, ppl
def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
    """beam search implementation

    :param torch.Tensor h: encoder hidden state (T, eprojs)
        [in multi-encoder case, list of torch.Tensor,
        [(T1, eprojs), (T2, eprojs), ...] ]
    :param torch.Tensor lpz: ctc log softmax output (T, odim)
        [in multi-encoder case, list of torch.Tensor,
        [(T1, odim), (T2, odim), ...] ]
    :param Namespace recog_args: argument Namespace containing options
    :param char_list: list of character strings
    :param torch.nn.Module rnnlm: language module
    :param int strm_idx:
        stream index for speaker parallel attention in multi-speaker case
    :return: N-best decoding results
    :rtype: list of dicts
    """
    # to support multiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        h = [h]
        lpz = [lpz]
    if self.num_encs > 1 and lpz is None:
        lpz = [lpz] * self.num_encs

    for idx in range(self.num_encs):
        # NOTE(review): this logs h[0]'s length for every encoder;
        # h[idx].size(0) looks intended -- confirm before changing.
        logging.info(
            "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                self.num_encs, idx + 1, h[0].size(0)
            )
        )
    att_idx = min(strm_idx, len(self.att) - 1)
    # initialization: one zeroed (c, z) pair per decoder layer
    c_list = [self.zero_state(h[0].unsqueeze(0))]
    z_list = [self.zero_state(h[0].unsqueeze(0))]
    for _ in range(1, self.dlayers):
        c_list.append(self.zero_state(h[0].unsqueeze(0)))
        z_list.append(self.zero_state(h[0].unsqueeze(0)))
    if self.num_encs == 1:
        a = None
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        a = [None] * (self.num_encs + 1)  # atts + han
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * (self.num_encs)  # atts
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han

    # search parms
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = getattr(recog_args, "ctc_weight", False)  # for NMT

    if lpz[0] is not None and self.num_encs > 1:
        # weights-ctc,
        # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
        weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
            recog_args.weights_ctc_dec
        )  # normalize
        logging.info(
            "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
        )
    else:
        weights_ctc_dec = [1.0]

    # prepare sos (optionally replaced by a target-language token)
    if self.replace_sos and recog_args.tgt_lang:
        y = char_list.index(recog_args.tgt_lang)
    else:
        y = self.sos
    logging.info("<sos> index: " + str(y))
    logging.info("<sos> mark: " + char_list[y])
    vy = h[0].new_zeros(1).long()

    # the output can never be longer than the shortest encoder output
    maxlen = np.amin([h[idx].size(0) for idx in range(self.num_encs)])
    if recog_args.maxlenratio != 0:
        # maxlen >= 1
        maxlen = max(1, int(recog_args.maxlenratio * maxlen))
    minlen = int(recog_args.minlenratio * maxlen)
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialize hypothesis
    if rnnlm:
        hyp = {
            "score": 0.0,
            "yseq": [y],
            "c_prev": c_list,
            "z_prev": z_list,
            "a_prev": a,
            "rnnlm_prev": None,
        }
    else:
        hyp = {
            "score": 0.0,
            "yseq": [y],
            "c_prev": c_list,
            "z_prev": z_list,
            "a_prev": a,
        }
    if lpz[0] is not None:
        # one CTC prefix scorer per encoder stream
        ctc_prefix_score = [
            CTCPrefixScore(lpz[idx].detach().numpy(), 0, self.eos, np)
            for idx in range(self.num_encs)
        ]
        hyp["ctc_state_prev"] = [
            ctc_prefix_score[idx].initial_state() for idx in range(self.num_encs)
        ]
        hyp["ctc_score_prev"] = [0.0] * self.num_encs
        if ctc_weight != 1.0:
            # pre-pruning based on attention scores
            ctc_beam = min(lpz[0].shape[-1], int(beam * CTC_SCORING_RATIO))
        else:
            ctc_beam = lpz[0].shape[-1]
    hyps = [hyp]
    ended_hyps = []

    for i in range(maxlen):
        logging.debug("position " + str(i))

        hyps_best_kept = []
        for hyp in hyps:
            # feed the last emitted token of this hypothesis
            vy[0] = hyp["yseq"][i]
            ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    h[0].unsqueeze(0),
                    [h[0].size(0)],
                    self.dropout_dec[0](hyp["z_prev"][0]),
                    hyp["a_prev"],
                )
            else:
                # per-encoder attention followed by hierarchical attention (han)
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        h[idx].unsqueeze(0),
                        [h[idx].size(0)],
                        self.dropout_dec[0](hyp["z_prev"][0]),
                        hyp["a_prev"][idx],
                    )
                h_han = torch.stack(att_c_list, dim=1)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    h_han,
                    [self.num_encs],
                    self.dropout_dec[0](hyp["z_prev"][0]),
                    hyp["a_prev"][self.num_encs],
                )
            ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
            z_list, c_list = self.rnn_forward(
                ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"]
            )

            # get nbest local scores and their ids
            if self.context_residual:
                logits = self.output(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )
            else:
                logits = self.output(self.dropout_dec[-1](z_list[-1]))
            local_att_scores = F.log_softmax(logits, dim=1)
            if rnnlm:
                rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                local_scores = (
                    local_att_scores + recog_args.lm_weight * local_lm_scores
                )
            else:
                local_scores = local_att_scores

            if lpz[0] is not None:
                # rescore the ctc_beam best attention candidates with the
                # incremental CTC prefix scores
                local_best_scores, local_best_ids = torch.topk(
                    local_att_scores, ctc_beam, dim=1
                )
                ctc_scores, ctc_states = (
                    [None] * self.num_encs,
                    [None] * self.num_encs,
                )
                for idx in range(self.num_encs):
                    ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
                        hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx]
                    )
                local_scores = (1.0 - ctc_weight) * local_att_scores[
                    :, local_best_ids[0]
                ]
                if self.num_encs == 1:
                    local_scores += ctc_weight * torch.from_numpy(
                        ctc_scores[0] - hyp["ctc_score_prev"][0]
                    )
                else:
                    for idx in range(self.num_encs):
                        local_scores += (
                            ctc_weight
                            * weights_ctc_dec[idx]
                            * torch.from_numpy(
                                ctc_scores[idx] - hyp["ctc_score_prev"][idx]
                            )
                        )
                if rnnlm:
                    local_scores += (
                        recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                    )
                local_best_scores, joint_best_ids = torch.topk(
                    local_scores, beam, dim=1
                )
                local_best_ids = local_best_ids[:, joint_best_ids[0]]
            else:
                local_best_scores, local_best_ids = torch.topk(
                    local_scores, beam, dim=1
                )

            for j in range(beam):
                new_hyp = {}
                # [:] is needed!
                new_hyp["z_prev"] = z_list[:]
                new_hyp["c_prev"] = c_list[:]
                if self.num_encs == 1:
                    new_hyp["a_prev"] = att_w[:]
                else:
                    new_hyp["a_prev"] = [
                        att_w_list[idx][:] for idx in range(self.num_encs + 1)
                    ]
                new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                if rnnlm:
                    new_hyp["rnnlm_prev"] = rnnlm_state
                if lpz[0] is not None:
                    new_hyp["ctc_state_prev"] = [
                        ctc_states[idx][joint_best_ids[0, j]]
                        for idx in range(self.num_encs)
                    ]
                    new_hyp["ctc_score_prev"] = [
                        ctc_scores[idx][joint_best_ids[0, j]]
                        for idx in range(self.num_encs)
                    ]
                # will be (2 x beam) hyps at most
                hyps_best_kept.append(new_hyp)

            hyps_best_kept = sorted(
                hyps_best_kept, key=lambda x: x["score"], reverse=True
            )[:beam]

        # sort and get nbest
        hyps = hyps_best_kept
        logging.debug("number of pruned hypotheses: " + str(len(hyps)))
        logging.debug(
            "best hypo: "
            + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
        )

        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            for hyp in hyps:
                hyp["yseq"].append(self.eos)

        # add ended hypotheses to a final list,
        # and removed them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == self.eos:
                # only store the sequence that has more than minlen outputs
                # also add penalty
                if len(hyp["yseq"]) > minlen:
                    hyp["score"] += (i + 1) * penalty
                    if rnnlm:  # Word LM needs to add final <eos> score
                        hyp["score"] += recog_args.lm_weight * rnnlm.final(
                            hyp["rnnlm_prev"]
                        )
                    ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)

        # end detection
        if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
            logging.info("end detected at %d", i)
            break

        hyps = remained_hyps
        if len(hyps) > 0:
            logging.debug("remaining hypotheses: " + str(len(hyps)))
        else:
            logging.info("no hypothesis. Finish decoding.")
            break

        for hyp in hyps:
            logging.debug(
                "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
            )

        logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

    nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
        : min(len(ended_hyps), recog_args.nbest)
    ]

    # check number of hypotheses
    if len(nbest_hyps) == 0:
        logging.warning(
            "there is no N-best results, "
            "perform recognition again with smaller minlenratio."
        )
        # should copy because Namespace will be overwritten globally
        recog_args = Namespace(**vars(recog_args))
        recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
        if self.num_encs == 1:
            return self.recognize_beam(h[0], lpz[0], recog_args, char_list, rnnlm)
        else:
            return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

    logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
    logging.info(
        "normalized log probability: "
        + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
    )

    # remove sos
    return nbest_hyps
def recognize_beam_batch(
    self,
    h,
    hlens,
    lpz,
    recog_args,
    char_list,
    rnnlm=None,
    normalize_score=True,
    strm_idx=0,
    lang_ids=None,
):
    """Batch beam search implementation.

    :param torch.Tensor h: batch of encoder hidden states
        [in multi-encoder case, list of torch.Tensor]
    :param torch.Tensor hlens: batch of encoder output lengths
        [in multi-encoder case, list of torch.Tensor]
    :param torch.Tensor lpz: batch of ctc log softmax outputs, or None
        [in multi-encoder case, list of torch.Tensor]
    :param Namespace recog_args: argument Namespace containing options
    :param char_list: list of character strings
    :param torch.nn.Module rnnlm: language module
    :param bool normalize_score: divide final scores by hypothesis length
    :param int strm_idx:
        stream index for speaker parallel attention in multi-speaker case
    :param lang_ids: target language ids used in place of <sos> (training eval)
    :return: N-best decoding results per utterance
    :rtype: list of list of dicts
    """
    # to support multiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        h = [h]
        hlens = [hlens]
        lpz = [lpz]
    if self.num_encs > 1 and lpz is None:
        lpz = [lpz] * self.num_encs

    att_idx = min(strm_idx, len(self.att) - 1)
    for idx in range(self.num_encs):
        logging.info(
            "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                self.num_encs, idx + 1, h[idx].size(1)
            )
        )
        # zero out padded frames so attention cannot look at them
        h[idx] = mask_by_length(h[idx], hlens[idx], 0.0)

    # search params
    batch = len(hlens[0])
    beam = recog_args.beam_size
    penalty = recog_args.penalty
    ctc_weight = getattr(recog_args, "ctc_weight", 0)  # for NMT
    att_weight = 1.0 - ctc_weight
    ctc_margin = getattr(
        recog_args, "ctc_window_margin", 0
    )  # use getattr to keep compatibility
    # weights-ctc,
    # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
    if lpz[0] is not None and self.num_encs > 1:
        weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
            recog_args.weights_ctc_dec
        )  # normalize
        logging.info(
            "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
        )
    else:
        weights_ctc_dec = [1.0]

    # the batch and beam dimensions are flattened into n_bb; pad_b maps a
    # per-sample beam index back into that flat indexing
    n_bb = batch * beam
    pad_b = to_device(h[0], torch.arange(batch) * beam).view(-1, 1)

    max_hlen = np.amin([max(hlens[idx]) for idx in range(self.num_encs)])
    if recog_args.maxlenratio == 0:
        maxlen = max_hlen
    else:
        maxlen = max(1, int(recog_args.maxlenratio * max_hlen))
    minlen = int(recog_args.minlenratio * max_hlen)
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # initialization
    c_prev = [
        to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    z_prev = [
        to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    c_list = [
        to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    z_list = [
        to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
    ]
    vscores = to_device(h[0], torch.zeros(batch, beam))

    rnnlm_state = None
    if self.num_encs == 1:
        a_prev = [None]
        att_w_list, ctc_scorer, ctc_state = [None], [None], [None]
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        a_prev = [None] * (self.num_encs + 1)  # atts + han
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * (self.num_encs)  # atts
        ctc_scorer, ctc_state = [None] * (self.num_encs), [None] * (self.num_encs)
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han

    if self.replace_sos and recog_args.tgt_lang:
        logging.info("<sos> index: " + str(char_list.index(recog_args.tgt_lang)))
        logging.info("<sos> mark: " + recog_args.tgt_lang)
        yseq = [[char_list.index(recog_args.tgt_lang)] for _ in range(n_bb)]
    elif lang_ids is not None:
        # NOTE: used for evaluation during training
        yseq = [[lang_ids[b // recog_args.beam_size]] for b in range(n_bb)]
    else:
        logging.info("<sos> index: " + str(self.sos))
        logging.info("<sos> mark: " + char_list[self.sos])
        yseq = [[self.sos] for _ in range(n_bb)]

    accum_odim_ids = [self.sos for _ in range(n_bb)]
    stop_search = [False for _ in range(batch)]
    nbest_hyps = [[] for _ in range(batch)]
    ended_hyps = [[] for _ in range(batch)]

    # expand encoder outputs/lengths beam-wise to the flat (n_bb) layout
    exp_hlens = [
        hlens[idx].repeat(beam).view(beam, batch).transpose(0, 1).contiguous()
        for idx in range(self.num_encs)
    ]
    exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)]
    exp_h = [
        h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
        for idx in range(self.num_encs)
    ]
    exp_h = [
        exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2])
        for idx in range(self.num_encs)
    ]

    if lpz[0] is not None:
        scoring_num = min(
            int(beam * CTC_SCORING_RATIO)
            if att_weight > 0.0 and not lpz[0].is_cuda
            else 0,
            lpz[0].size(-1),
        )
        ctc_scorer = [
            CTCPrefixScoreTH(
                lpz[idx],
                hlens[idx],
                0,
                self.eos,
                margin=ctc_margin,
            )
            for idx in range(self.num_encs)
        ]

    for i in range(maxlen):
        logging.debug("position " + str(i))

        vy = to_device(h[0], torch.LongTensor(self._get_last_yseq(yseq)))
        ey = self.dropout_emb(self.embed(vy))
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                exp_h[0], exp_hlens[0], self.dropout_dec[0](z_prev[0]), a_prev[0]
            )
            att_w_list = [att_w]
        else:
            # per-encoder attention then hierarchical attention (han)
            for idx in range(self.num_encs):
                att_c_list[idx], att_w_list[idx] = self.att[idx](
                    exp_h[idx],
                    exp_hlens[idx],
                    self.dropout_dec[0](z_prev[0]),
                    a_prev[idx],
                )
            exp_h_han = torch.stack(att_c_list, dim=1)
            att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                exp_h_han,
                [self.num_encs] * n_bb,
                self.dropout_dec[0](z_prev[0]),
                a_prev[self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)

        # attention decoder
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        local_scores = att_weight * F.log_softmax(logits, dim=1)

        # rnnlm
        if rnnlm:
            rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_state, vy, n_bb)
            local_scores = local_scores + recog_args.lm_weight * local_lm_scores

        # ctc
        if ctc_scorer[0]:
            local_scores[:, 0] = self.logzero  # avoid choosing blank
            part_ids = (
                torch.topk(local_scores, scoring_num, dim=-1)[1]
                if scoring_num > 0
                else None
            )
            for idx in range(self.num_encs):
                att_w = att_w_list[idx]
                att_w_ = att_w if isinstance(att_w, torch.Tensor) else att_w[0]
                local_ctc_scores, ctc_state[idx] = ctc_scorer[idx](
                    yseq, ctc_state[idx], part_ids, att_w_
                )
                local_scores = (
                    local_scores
                    + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
                )

        local_scores = local_scores.view(batch, beam, self.odim)
        if i == 0:
            # at the first step all beams hold the same prefix: keep only beam 0
            local_scores[:, 1:, :] = self.logzero

        # accumulate scores
        eos_vscores = local_scores[:, :, self.eos] + vscores
        vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim)
        vscores[:, :, self.eos] = self.logzero
        vscores = (vscores + local_scores).view(batch, -1)

        # global pruning
        accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
        accum_odim_ids = (
            torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
        )
        accum_padded_beam_ids = (
            (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist()
        )

        # NOTE(review): [:][:] is two successive shallow copies, not a deep
        # copy -- the inner token lists are still shared with yseq.
        y_prev = yseq[:][:]
        yseq = self._index_select_list(yseq, accum_padded_beam_ids)
        yseq = self._append_ids(yseq, accum_odim_ids)
        vscores = accum_best_scores
        vidx = to_device(h[0], torch.LongTensor(accum_padded_beam_ids))

        # reorder attention/decoder states to follow the surviving beams
        a_prev = []
        num_atts = self.num_encs if self.num_encs == 1 else self.num_encs + 1
        for idx in range(num_atts):
            if isinstance(att_w_list[idx], torch.Tensor):
                _a_prev = torch.index_select(
                    att_w_list[idx].view(n_bb, *att_w_list[idx].shape[1:]), 0, vidx
                )
            elif isinstance(att_w_list[idx], list):
                # handle the case of multi-head attention
                _a_prev = [
                    torch.index_select(att_w_one.view(n_bb, -1), 0, vidx)
                    for att_w_one in att_w_list[idx]
                ]
            else:
                # handle the case of location_recurrent when return is a tuple
                _a_prev_ = torch.index_select(
                    att_w_list[idx][0].view(n_bb, -1), 0, vidx
                )
                _h_prev_ = torch.index_select(
                    att_w_list[idx][1][0].view(n_bb, -1), 0, vidx
                )
                _c_prev_ = torch.index_select(
                    att_w_list[idx][1][1].view(n_bb, -1), 0, vidx
                )
                _a_prev = (_a_prev_, (_h_prev_, _c_prev_))
            a_prev.append(_a_prev)
        z_prev = [
            torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
            for li in range(self.dlayers)
        ]
        c_prev = [
            torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
            for li in range(self.dlayers)
        ]

        # pick ended hyps
        if i >= minlen:
            k = 0
            penalty_i = (i + 1) * penalty
            thr = accum_best_scores[:, -1]
            for samp_i in range(batch):
                if stop_search[samp_i]:
                    k = k + beam
                    continue
                for beam_j in range(beam):
                    _vscore = None
                    if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                        yk = y_prev[k][:]
                        if len(yk) <= min(
                            hlens[idx][samp_i] for idx in range(self.num_encs)
                        ):
                            _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                    elif i == maxlen - 1:
                        yk = yseq[k][:]
                        _vscore = vscores[samp_i][beam_j] + penalty_i
                    # NOTE(review): tensor truthiness -- a hypothesis whose
                    # score is exactly 0.0 would be skipped here; confirm
                    # this is acceptable.
                    if _vscore:
                        yk.append(self.eos)
                        if rnnlm:
                            _vscore += recog_args.lm_weight * rnnlm.final(
                                rnnlm_state, index=k
                            )
                        _score = _vscore.data.cpu().numpy()
                        ended_hyps[samp_i].append(
                            {"yseq": yk, "vscore": _vscore, "score": _score}
                        )
                    k = k + 1

        # end detection
        stop_search = [
            stop_search[samp_i] or end_detect(ended_hyps[samp_i], i)
            for samp_i in range(batch)
        ]
        stop_search_summary = list(set(stop_search))
        if len(stop_search_summary) == 1 and stop_search_summary[0]:
            break

        if rnnlm:
            rnnlm_state = self._index_select_lm_state(rnnlm_state, 0, vidx)
        if ctc_scorer[0]:
            for idx in range(self.num_encs):
                ctc_state[idx] = ctc_scorer[idx].index_select_state(
                    ctc_state[idx], accum_best_ids
                )

    torch.cuda.empty_cache()

    dummy_hyps = [
        {"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}
    ]
    ended_hyps = [
        ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
        for samp_i in range(batch)
    ]
    if normalize_score:
        for samp_i in range(batch):
            for x in ended_hyps[samp_i]:
                x["score"] /= len(x["yseq"])

    nbest_hyps = [
        sorted(ended_hyps[samp_i], key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps[samp_i]), recog_args.nbest)
        ]
        for samp_i in range(batch)
    ]

    return nbest_hyps
def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, lang_ids=None):
    """Calculate all of attentions

    :param torch.Tensor hs_pad: batch of padded hidden state sequences
        (B, Tmax, D)
        in multi-encoder case, list of torch.Tensor,
        [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
    :param torch.Tensor hlen: batch of lengths of hidden state sequences (B)
        [in multi-encoder case, list of torch.Tensor,
        [(B), (B), ..., ]
    :param torch.Tensor ys_pad:
        batch of padded character id sequence tensor (B, Lmax)
    :param int strm_idx:
        stream index for parallel speaker attention in multi-speaker case
    :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
    :return: attention weights with the following shape,
        1) multi-head case => attention weights (B, H, Lmax, Tmax),
        2) multi-encoder case =>
        [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)]
        3) other case => attention weights (B, Lmax, Tmax).
    :rtype: float ndarray
    """
    # to support multiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        hs_pad = [hs_pad]
        hlen = [hlen]

    # TODO(kan-bayashi): need to make more smart way
    ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys
    att_idx = min(strm_idx, len(self.att) - 1)

    # hlen should be list of list of integer
    hlen = [list(map(int, hlen[idx])) for idx in range(self.num_encs)]

    self.loss = None
    # prepare input and output word sequences with sos/eos IDs
    eos = ys[0].new([self.eos])
    sos = ys[0].new([self.sos])
    if self.replace_sos:
        # prepend the per-utterance language id instead of <sos>
        ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
    else:
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]

    # padding for ys with -1
    # pys: utt x olen
    ys_in_pad = pad_list(ys_in, self.eos)
    ys_out_pad = pad_list(ys_out, self.ignore_id)

    # get length info
    olength = ys_out_pad.size(1)

    # initialization: one zeroed (c, z) pair per decoder layer
    c_list = [self.zero_state(hs_pad[0])]
    z_list = [self.zero_state(hs_pad[0])]
    for _ in range(1, self.dlayers):
        c_list.append(self.zero_state(hs_pad[0]))
        z_list.append(self.zero_state(hs_pad[0]))
    att_ws = []
    if self.num_encs == 1:
        att_w = None
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * (self.num_encs)  # atts
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han

    # pre-computation of embedding
    eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

    # loop for an output sequence (teacher forcing; only attention weights
    # are collected, the output layer is never evaluated)
    for i in range(olength):
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                hs_pad[0], hlen[0], self.dropout_dec[0](z_list[0]), att_w
            )
            att_ws.append(att_w)
        else:
            for idx in range(self.num_encs):
                att_c_list[idx], att_w_list[idx] = self.att[idx](
                    hs_pad[idx],
                    hlen[idx],
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[idx],
                )
            hs_pad_han = torch.stack(att_c_list, dim=1)
            hlen_han = [self.num_encs] * len(ys_in)
            att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                hs_pad_han,
                hlen_han,
                self.dropout_dec[0](z_list[0]),
                att_w_list[self.num_encs],
            )
            # .copy() is needed: att_w_list is mutated on the next iteration
            att_ws.append(att_w_list.copy())
        ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)

    if self.num_encs == 1:
        # convert to numpy array with the shape (B, Lmax, Tmax)
        att_ws = att_to_numpy(att_ws, self.att[att_idx])
    else:
        _att_ws = []
        for idx, ws in enumerate(zip(*att_ws)):
            ws = att_to_numpy(ws, self.att[idx])
            _att_ws.append(ws)
        att_ws = _att_ws
    return att_ws
@staticmethod
def _get_last_yseq(exp_yseq):
    """Return the most recently emitted token id of every hypothesis."""
    return [hyp[-1] for hyp in exp_yseq]
@staticmethod
def _append_ids(yseq, ids):
    """Append token id(s) to the hypotheses in place and return them.

    ``ids`` is either a list (one id per hypothesis) or a single id that
    is broadcast to every hypothesis.
    """
    if not isinstance(ids, list):
        # broadcast a single id to all hypotheses
        for hyp in yseq:
            hyp.append(ids)
    else:
        # one id per hypothesis, matched by position
        for pos, token in enumerate(ids):
            yseq[pos].append(token)
    return yseq
@staticmethod
def _index_select_list(yseq, lst):
    """Gather hypotheses by index, shallow-copying each selected list."""
    return [yseq[i][:] for i in lst]
@staticmethod
def _index_select_lm_state(rnnlm_state, dim, vidx):
    """Reorder the RNNLM state along the beam dimension.

    :param rnnlm_state: dict of lists of tensors, or list of states
    :param int dim: dimension along which tensors are index-selected
    :param torch.LongTensor vidx: surviving-beam indices

    NOTE(review): if ``rnnlm_state`` is neither a dict nor a list,
    ``new_state`` is never bound and this raises UnboundLocalError.
    """
    if isinstance(rnnlm_state, dict):
        new_state = {}
        for k, v in rnnlm_state.items():
            new_state[k] = [torch.index_select(vi, dim, vidx) for vi in v]
    elif isinstance(rnnlm_state, list):
        new_state = []
        for i in vidx:
            # shallow copy of the selected per-beam state
            new_state.append(rnnlm_state[int(i)][:])
    return new_state
# scorer interface methods
def init_state(self, x):
    """Initialize the decoder state for the scorer interface.

    :param x: encoder output (T, eprojs)
        [in multi-encoder case, list of torch.Tensor]
    :return: initial state dict (c_prev, z_prev, a_prev, workspace)
    """
    # to support multiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        x = [x]
    # one zeroed (c, z) pair per decoder layer
    enc0 = x[0].unsqueeze(0)
    c_list = [self.zero_state(enc0) for _ in range(self.dlayers)]
    z_list = [self.zero_state(enc0) for _ in range(self.dlayers)]
    # TODO(karita): support strm_index for `asr_mix`
    strm_index = 0
    att_idx = min(strm_index, len(self.att) - 1)
    if self.num_encs == 1:
        a = None
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        # per-encoder attentions plus one hierarchical attention (han)
        a = [None] * (self.num_encs + 1)
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han
    return dict(
        c_prev=list(c_list),
        z_prev=list(z_list),
        a_prev=a,
        workspace=(att_idx, z_list, c_list),
    )
def score(self, yseq, state, x):
    """Score a partial hypothesis (scorer interface).

    :param torch.Tensor yseq: prefix token ids; only the last one is consumed
    :param dict state: state produced by ``init_state`` or a previous ``score``
    :param x: encoder output (T, eprojs)
        [in multi-encoder case, list of torch.Tensor]
    :return: tuple of next-token log-probabilities and the next state dict
    """
    # to support multiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        x = [x]

    att_idx, z_list, c_list = state["workspace"]
    vy = yseq[-1].unsqueeze(0)
    ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
    if self.num_encs == 1:
        att_c, att_w = self.att[att_idx](
            x[0].unsqueeze(0),
            [x[0].size(0)],
            self.dropout_dec[0](state["z_prev"][0]),
            state["a_prev"],
        )
    else:
        att_w = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * (self.num_encs)  # atts
        # per-encoder attention followed by hierarchical attention (han)
        for idx in range(self.num_encs):
            att_c_list[idx], att_w[idx] = self.att[idx](
                x[idx].unsqueeze(0),
                [x[idx].size(0)],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"][idx],
            )
        h_han = torch.stack(att_c_list, dim=1)
        att_c, att_w[self.num_encs] = self.att[self.num_encs](
            h_han,
            [self.num_encs],
            self.dropout_dec[0](state["z_prev"][0]),
            state["a_prev"][self.num_encs],
        )
    ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
    z_list, c_list = self.rnn_forward(
        ey, z_list, c_list, state["z_prev"], state["c_prev"]
    )
    if self.context_residual:
        logits = self.output(
            torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
        )
    else:
        logits = self.output(self.dropout_dec[-1](z_list[-1]))
    logp = F.log_softmax(logits, dim=1).squeeze(0)
    return (
        logp,
        dict(
            c_prev=c_list[:],
            z_prev=z_list[:],
            a_prev=att_w,
            workspace=(att_idx, z_list, c_list),
        ),
    )
def decoder_for(args, odim, sos, eos, att, labeldist):
    """Construct a :class:`Decoder` from command-line style arguments.

    :param Namespace args: training arguments (eprojs, dtype, dlayers, ...)
    :param int odim: output (vocabulary) dimension
    :param int sos: start-of-sequence token id
    :param int eos: end-of-sequence token id
    :param att: attention module (or list of attention modules)
    :param labeldist: label distribution for label smoothing, or None
    :return: decoder instance
    """
    return Decoder(
        args.eprojs,
        odim,
        args.dtype,
        args.dlayers,
        args.dunits,
        sos,
        eos,
        att,
        args.verbose,
        args.char_list,
        labeldist,
        args.lsm_weight,
        args.sampling_probability,
        args.dropout_rate_decoder,
        getattr(args, "context_residual", False),  # use getattr to keep compatibility
        getattr(args, "replace_sos", False),  # use getattr to keep compatibility
        getattr(args, "num_encs", 1),
    )  # use getattr to keep compatibility
import logging
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from espnet.nets.e2e_asr_common import get_vgg2l_odim
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask, to_device
class RNNP(torch.nn.Module):
    """RNN with projection layer module

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
    :param int hdim: number of projection units
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"):
        super(RNNP, self).__init__()
        # a leading "b" in the type string selects a bidirectional RNN
        bidir = typ[0] == "b"
        for i in range(elayers):
            if i == 0:
                inputdim = idim
            else:
                # layers after the first consume the projected (hdim) output
                inputdim = hdim
            RNN = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU
            rnn = RNN(
                inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True
            )
            # one single-layer RNN per encoder layer, registered by name, so
            # subsampling and projection can be applied between layers
            setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
            # bottleneck layer to merge
            if bidir:
                setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim))
            else:
                setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim))

        self.elayers = elayers
        self.cdim = cdim
        self.subsample = subsample
        self.typ = typ
        self.bidir = bidir
        self.dropout = dropout

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNNP forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, hdim)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))
        elayer_states = []
        for layer in range(self.elayers):
            if not isinstance(ilens, torch.Tensor):
                ilens = torch.tensor(ilens)
            # NOTE(review): pack_padded_sequence defaults to
            # enforce_sorted=True, so ilens must be in decreasing order here.
            xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
            rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
            if self.training:
                rnn.flatten_parameters()
            if prev_state is not None and rnn.bidirectional:
                # streaming case: backward states from a previous chunk would
                # flow in the wrong direction, so zero them out
                prev_state = reset_backward_rnn_state(prev_state)
            ys, states = rnn(
                xs_pack, hx=None if prev_state is None else prev_state[layer]
            )
            elayer_states.append(states)
            # ys: utt list of frame x cdim x 2 (2: means bidirectional)
            ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
            sub = self.subsample[layer + 1]
            if sub > 1:
                # subsample in time by keeping every sub-th frame
                ys_pad = ys_pad[:, ::sub]
                ilens = torch.tensor([int(i + 1) // sub for i in ilens])
            # (sum _utt frame_utt) x dim
            projection_layer = getattr(self, "bt%d" % layer)
            projected = projection_layer(ys_pad.contiguous().view(-1, ys_pad.size(2)))
            xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
            if layer < self.elayers - 1:
                # NOTE(review): F.dropout is called with its default
                # training=True, i.e. dropout is applied even in eval mode --
                # confirm this is intended.
                xs_pad = torch.tanh(F.dropout(xs_pad, p=self.dropout))

        return xs_pad, ilens, elayer_states  # x: utt list of frame x dim
class RNN(torch.nn.Module):
    """RNN module

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
    :param int hdim: number of final projection units
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
        super(RNN, self).__init__()
        # a leading "b" in the type string selects a bidirectional RNN
        bidir = typ[0] == "b"
        # single multi-layer recurrent stack (LSTM or GRU)
        self.nbrnn = (
            torch.nn.LSTM(
                idim,
                cdim,
                elayers,
                batch_first=True,
                dropout=dropout,
                bidirectional=bidir,
            )
            if "lstm" in typ
            else torch.nn.GRU(
                idim,
                cdim,
                elayers,
                batch_first=True,
                dropout=dropout,
                bidirectional=bidir,
            )
        )
        # final projection merges both directions when bidirectional
        if bidir:
            self.l_last = torch.nn.Linear(cdim * 2, hdim)
        else:
            self.l_last = torch.nn.Linear(cdim, hdim)
        self.typ = typ

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNN forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))
        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)
        # NOTE(review): pack_padded_sequence defaults to enforce_sorted=True,
        # so ilens must be in decreasing order here.
        xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
        if self.training:
            self.nbrnn.flatten_parameters()
        if prev_state is not None and self.nbrnn.bidirectional:
            # We assume that when previous state is passed,
            # it means that we're streaming the input
            # and therefore cannot propagate backward BRNN state
            # (otherwise it goes in the wrong direction)
            prev_state = reset_backward_rnn_state(prev_state)
        ys, states = self.nbrnn(xs_pack, hx=prev_state)
        # ys: utt list of frame x cdim x 2 (2: means bidirectional)
        ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
        # (sum _utt frame_utt) x dim
        projected = torch.tanh(
            self.l_last(ys_pad.contiguous().view(-1, ys_pad.size(2)))
        )
        xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
        return xs_pad, ilens, states  # x: utt list of frame x dim
def reset_backward_rnn_state(states):
    """Zero out the backward-direction states of a bidirectional RNN.

    Useful when processing sliding windows over the inputs: a backward
    state carried over from a previous window would flow in the wrong
    direction.
    """
    # Odd rows of an RNN state tensor hold the backward direction.
    if not isinstance(states, (list, tuple)):
        states[1::2] = 0.0
    else:
        for layer_state in states:
            layer_state[1::2] = 0.0
    return states
class VGG2L(torch.nn.Module):
    """VGG-like two-block CNN front-end.

    :param int in_channel: number of input channels
    """

    def __init__(self, in_channel=1):
        super(VGG2L, self).__init__()
        # Two VGG-style convolution blocks; each block is followed by a
        # 2x2 max pooling in forward().
        self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
        self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.in_channel = in_channel

    def forward(self, xs_pad, ilens, **kwargs):
        """VGG2L forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))

        # (B, T, D) -> (B, in_channel, T, D // in_channel)
        n_batch, n_frames = xs_pad.size(0), xs_pad.size(1)
        xs_pad = xs_pad.view(
            n_batch, n_frames, self.in_channel, xs_pad.size(2) // self.in_channel
        ).transpose(1, 2)

        # NOTE: max_pool1d ?
        # First conv block, then 2x2 pooling (ceil mode keeps partial frames).
        xs_pad = F.relu(self.conv1_2(F.relu(self.conv1_1(xs_pad))))
        xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True)
        # Second conv block and pooling.
        xs_pad = F.relu(self.conv2_2(F.relu(self.conv2_1(xs_pad))))
        xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True)

        # Lengths shrink by ceil(len / 2) once per pooling layer.
        if torch.is_tensor(ilens):
            ilens = ilens.cpu().numpy()
        else:
            ilens = np.array(ilens, dtype=np.float32)
        ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
        ilens = np.array(
            np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64
        ).tolist()

        # (B, C, T', D') -> (B, T', C * D'): flatten channels into features
        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = xs_pad.contiguous().view(
            xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3)
        )
        return xs_pad, ilens, None  # no state in this layer
class Encoder(torch.nn.Module):
    """RNN-based encoder, optionally preceded by a VGG2L front-end.

    :param str etype: type of encoder network (e.g. "blstmp", "vggblstm", "gru")
    :param int idim: number of dimensions of encoder network
    :param int elayers: number of layers of encoder network
    :param int eunits: number of lstm units of encoder network
    :param int eprojs: number of projection units of encoder network
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param int in_channel: number of input channels
    """

    def __init__(
        self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1
    ):
        super(Encoder, self).__init__()
        # Derive the bare RNN type from the spec, e.g. "vggblstmp" -> "blstm".
        # NOTE: the previous etype.lstrip("vgg") stripped a *character set*
        # ('v'/'g'), mangling "gru" into "ru"; strip the literal prefix/suffix.
        typ = etype[len("vgg"):] if etype.startswith("vgg") else etype
        if typ.endswith("p"):
            typ = typ[:-1]
        if typ not in ["lstm", "gru", "blstm", "bgru"]:
            logging.error("Error: need to specify an appropriate encoder architecture")
        if etype.startswith("vgg"):
            if etype[-1] == "p":
                self.enc = torch.nn.ModuleList(
                    [
                        VGG2L(in_channel),
                        RNNP(
                            get_vgg2l_odim(idim, in_channel=in_channel),
                            elayers,
                            eunits,
                            eprojs,
                            subsample,
                            dropout,
                            typ=typ,
                        ),
                    ]
                )
                logging.info("Use CNN-VGG + " + typ.upper() + "P for encoder")
            else:
                self.enc = torch.nn.ModuleList(
                    [
                        VGG2L(in_channel),
                        RNN(
                            get_vgg2l_odim(idim, in_channel=in_channel),
                            elayers,
                            eunits,
                            eprojs,
                            dropout,
                            typ=typ,
                        ),
                    ]
                )
                logging.info("Use CNN-VGG + " + typ.upper() + " for encoder")
            # VGG2L halves the time axis twice -> total conv subsampling of 4.
            self.conv_subsampling_factor = 4
        else:
            if etype[-1] == "p":
                self.enc = torch.nn.ModuleList(
                    [RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)]
                )
                logging.info(typ.upper() + " with every-layer projection for encoder")
            else:
                self.enc = torch.nn.ModuleList(
                    [RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)]
                )
                logging.info(typ.upper() + " without projection for encoder")
            self.conv_subsampling_factor = 1

    def forward(self, xs_pad, ilens, prev_states=None):
        """Encoder forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param list prev_states: batch of previous encoder hidden states
            (one entry per sub-module, or None to start fresh)
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        if prev_states is None:
            prev_states = [None] * len(self.enc)
        assert len(prev_states) == len(self.enc)
        # Chain the sub-modules (optional VGG2L, then RNN/RNNP), threading
        # lengths and per-module states through.
        current_states = []
        for module, prev_state in zip(self.enc, prev_states):
            xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
            current_states.append(states)
        # make mask to remove bias value in padded part
        mask = to_device(xs_pad, make_pad_mask(ilens).unsqueeze(-1))
        return xs_pad.masked_fill(mask, 0.0), ilens, current_states
def encoder_for(args, idim, subsample):
    """Instantiate an encoder module given the program arguments.

    :param Namespace args: The arguments
    :param int or List of integer idim: dimension of input, e.g. 83, or
                                        List of dimensions of inputs, e.g. [83,83]
    :param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or
                                           List of subsample factors of each encoder.
                                           e.g. [[1,2,2,1,1], [1,2,2,1,1]]
    :raises ValueError: if args.num_encs is zero or negative
    :rtype: torch.nn.Module
    :return: The encoder module
    """
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    if num_encs == 1:
        # compatible with single encoder asr mode
        return Encoder(
            args.etype,
            idim,
            args.elayers,
            args.eunits,
            args.eprojs,
            subsample,
            args.dropout_rate,
        )
    elif num_encs > 1:
        # multi-encoder mode: per-encoder hyperparameters are indexed lists
        # (was `>= 1`, but the == 1 case is already handled above)
        enc_list = torch.nn.ModuleList()
        for idx in range(num_encs):
            enc = Encoder(
                args.etype[idx],
                idim[idx],
                args.elayers[idx],
                args.eunits[idx],
                args.eprojs,
                subsample[idx],
                args.dropout_rate[idx],
            )
            enc_list.append(enc)
        return enc_list
    else:
        # num_encs <= 0 is invalid; the old message said "more than one",
        # which contradicted the accepted num_encs == 1 case.
        raise ValueError(
            "Number of encoders needs to be more than zero. {}".format(num_encs)
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment