Commit 60a2c57a authored by sunzhq2's avatar sunzhq2 Committed by xuxo
Browse files

update conformer

parent 4a699441
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2022 Roshan Sharma (Carnegie Mellon University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Longformer based Local Attention Definition."""
from longformer.longformer import LongformerConfig, LongformerSelfAttention
from torch import nn
class LongformerAttention(nn.Module):
    """Longformer-based local self-attention module."""

    def __init__(self, config: LongformerConfig, layer_id: int):
        """Build the local attention layer for one encoder layer.

        Args:
            config: Longformer attention configuration.
            layer_id: Index of this layer inside the encoder stack.
        """
        super().__init__()
        self.attention_window = config.attention_window[layer_id]
        self.attention_layer = LongformerSelfAttention(config, layer_id=layer_id)
        # Last computed attention weights (populated on every forward call).
        self.attention = None

    def forward(self, query, key, value, mask):
        """Compute Longformer self-attention with masking.

        Expects ``len(hidden_states)`` to be a multiple of
        ``attention_window``; padding is performed once in the encoder,
        not here.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size);
                unused — Longformer attends over ``query`` only.
            value (torch.Tensor): Value tensor (#batch, time2, size); unused.
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        # Longformer's convention: 0 marks kept positions, -1 masked ones.
        longformer_mask = mask.int()
        longformer_mask[mask == 0] = -1
        longformer_mask[mask == 1] = 0
        out, self.attention = self.attention_layer(
            hidden_states=query,
            attention_mask=longformer_mask.unsqueeze(1),
            head_mask=None,
            output_attentions=True,
        )
        return out
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Mask module."""
import torch
def subsequent_mask(size, device="cpu", dtype=torch.bool):
    """Create a causal mask for subsequent steps (size, size).

    :param int size: size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor
    >>> subsequent_mask(3)
    [[1, 0, 0],
     [1, 1, 0],
     [1, 1, 1]]
    """
    # Lower-triangular matrix: position i may attend to positions <= i.
    ones = torch.ones(size, size, device=device, dtype=dtype)
    return torch.tril(ones, out=ones)
def target_mask(ys_in_pad, ignore_id):
    """Create the mask for decoder self-attention.

    :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax)
    :param int ignore_id: index of padding
    :rtype: torch.Tensor (B, Lmax, Lmax)
    """
    nonpad = ys_in_pad != ignore_id
    length = nonpad.size(-1)
    # Causal (lower-triangular) mask shared across the batch.
    causal = torch.tril(
        torch.ones(length, length, device=nonpad.device, dtype=torch.bool)
    ).unsqueeze(0)
    # Combine padding mask and causal mask.
    return nonpad.unsqueeze(-2) & causal
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""
import torch
class MultiLayeredConv1d(torch.nn.Module):
    """Multi-layered conv1d for Transformer block.

    This is a module of multi-layered conv1d designed
    to replace the position-wise feed-forward network
    in a Transformer block, which is introduced in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf
    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """Initialize MultiLayeredConv1d module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d (odd size keeps
                the time dimension unchanged).
            dropout_rate (float): Dropout rate.
        """
        super(MultiLayeredConv1d, self).__init__()
        # "same" padding so T is preserved for odd kernel_size.
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Conv1d(
            hidden_chans,
            in_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).

        Returns:
            torch.Tensor: Batch of output tensors (B, T, in_chans).
        """
        # Conv1d expects (B, C, T); transpose around each convolution.
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
class Conv1dLinear(torch.nn.Module):
    """Conv1D + Linear for Transformer block.

    A variant of MultiLayeredConv1d, which replaces the second
    conv layer with a linear layer.
    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """Initialize Conv1dLinear module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d (odd size keeps
                the time dimension unchanged).
            dropout_rate (float): Dropout rate.
        """
        super(Conv1dLinear, self).__init__()
        # "same" padding so T is preserved for odd kernel_size.
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).

        Returns:
            torch.Tensor: Batch of output tensors (B, T, in_chans).
        """
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x))
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Optimizer module."""
import torch
class NoamOpt(object):
    """Optimizer wrapper implementing the Noam learning-rate schedule."""

    def __init__(self, model_size, factor, warmup, optimizer):
        """Construct an NoamOpt object.

        Args:
            model_size: Attention dimension used to scale the rate.
            factor: Multiplicative scaling factor.
            warmup: Number of warmup steps.
            optimizer: Wrapped torch optimizer.
        """
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Expose the wrapped optimizer's param_groups."""
        return self.optimizer.param_groups

    def step(self):
        """Update the learning rate, then take one optimizer step."""
        self._step += 1
        self._rate = self.rate()
        for group in self.optimizer.param_groups:
            group["lr"] = self._rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the Noam rate for ``step`` (defaults to current step)."""
        step = self._step if step is None else step
        warm_term = step * self.warmup ** (-1.5)
        return self.factor * self.model_size ** (-0.5) * min(step ** (-0.5), warm_term)

    def zero_grad(self):
        """Clear gradients of the wrapped optimizer."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """Return scheduler state plus the wrapped optimizer's state."""
        state = {
            "_step": self._step,
            "warmup": self.warmup,
            "factor": self.factor,
            "model_size": self.model_size,
            "_rate": self._rate,
        }
        state["optimizer"] = self.optimizer.state_dict()
        return state

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`."""
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(value)
            else:
                setattr(self, key, value)
def get_std_opt(model_params, d_model, warmup, factor):
    """Build the standard Adam-based NoamOpt used for Transformer training."""
    adam = torch.optim.Adam(model_params, lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(d_model, factor, warmup, adam)
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import logging
import os
import numpy
from espnet.asr import asr_utils
def _plot_and_save_attention(att_w, filename, xtokens=None, ytokens=None):
    """Render one attention-weight figure (one subplot per head).

    Args:
        att_w: Iterable of 2D attention-weight matrices, one per head.
        filename: Output path; only its directory is created here —
            actual saving is left to the caller.
        xtokens: Optional input-token labels for the x axis.
        ytokens: Optional output-token labels for the y axis.

    Returns:
        The rendered (not yet saved) matplotlib figure.
    """
    import matplotlib

    # Non-interactive backend: safe on display-less training machines.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    d = os.path.dirname(filename)
    if not os.path.exists(d):
        os.makedirs(d)
    w, h = plt.figaspect(1.0 / len(att_w))
    fig = plt.Figure(figsize=(w * 2, h * 2))
    axes = fig.subplots(1, len(att_w))
    if len(att_w) == 1:
        # subplots() returns a bare Axes when there is one head; wrap it.
        axes = [axes]
    for ax, aw in zip(axes, att_w):
        # plt.subplot(1, len(att_w), h)
        ax.imshow(aw.astype(numpy.float32), aspect="auto")
        ax.set_xlabel("Input")
        ax.set_ylabel("Output")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # Labels for major ticks
        if xtokens is not None:
            ax.set_xticks(numpy.linspace(0, len(xtokens), len(xtokens) + 1))
            ax.set_xticks(numpy.linspace(0, len(xtokens), 1), minor=True)
            ax.set_xticklabels(xtokens + [""], rotation=40)
        if ytokens is not None:
            ax.set_yticks(numpy.linspace(0, len(ytokens), len(ytokens) + 1))
            ax.set_yticks(numpy.linspace(0, len(ytokens), 1), minor=True)
            ax.set_yticklabels(ytokens + [""])
    fig.tight_layout()
    return fig
def savefig(plot, filename):
    """Save a figure to ``filename`` and clear the current pyplot state.

    Args:
        plot: A matplotlib figure (anything exposing ``savefig``).
        filename: Destination path for the image.
    """
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    plot.savefig(filename)
    # Clear pyplot's current figure to avoid leaking state across calls.
    plt.clf()
def plot_multi_head_attention(
    data,
    uttid_list,
    attn_dict,
    outdir,
    suffix="png",
    savefn=savefig,
    ikey="input",
    iaxis=0,
    okey="output",
    oaxis=0,
    subsampling_factor=4,
):
    """Plot multi head attentions.

    :param dict data: utts info from json file
    :param List uttid_list: utterance IDs
    :param dict[str, torch.Tensor] attn_dict: multi head attention dict.
        values should be torch.Tensor (head, input_length, output_length)
    :param str outdir: dir to save fig
    :param str suffix: filename suffix including image type (e.g., png)
    :param savefn: function to save
    :param str ikey: key to access input
    :param int iaxis: dimension to access input
    :param str okey: key to access output
    :param int oaxis: dimension to access output
    :param subsampling_factor: subsampling factor in encoder
    """
    for name, att_ws in attn_dict.items():
        for idx, att_w in enumerate(att_ws):
            data_i = data[uttid_list[idx]]
            filename = "%s/%s.%s.%s" % (outdir, uttid_list[idx], name, suffix)
            dec_len = int(data_i[okey][oaxis]["shape"][0]) + 1  # +1 for <eos>
            enc_len = int(data_i[ikey][iaxis]["shape"][0])
            # MT inputs carry a "token" field; ASR/ST inputs do not.
            is_mt = "token" in data_i[ikey][iaxis].keys()
            # for ASR/ST
            if not is_mt:
                # Encoder shortens the input; trim attention to the true length.
                enc_len //= subsampling_factor
            xtokens, ytokens = None, None
            if "encoder" in name:
                # Encoder self-attention: square over the encoded input.
                att_w = att_w[:, :enc_len, :enc_len]
                # for MT
                if is_mt:
                    xtokens = data_i[ikey][iaxis]["token"].split()
                    ytokens = xtokens[:]
            elif "decoder" in name:
                if "self" in name:
                    # self-attention
                    att_w = att_w[:, :dec_len, :dec_len]
                    if "token" in data_i[okey][oaxis].keys():
                        ytokens = data_i[okey][oaxis]["token"].split() + ["<eos>"]
                        xtokens = ["<sos>"] + data_i[okey][oaxis]["token"].split()
                else:
                    # cross-attention
                    att_w = att_w[:, :dec_len, :enc_len]
                    if "token" in data_i[okey][oaxis].keys():
                        ytokens = data_i[okey][oaxis]["token"].split() + ["<eos>"]
                    # for MT
                    if is_mt:
                        xtokens = data_i[ikey][iaxis]["token"].split()
            else:
                # Unknown attention name: plot it untrimmed.
                logging.warning("unknown name for shaping attention")
            fig = _plot_and_save_attention(att_w, filename, xtokens, ytokens)
            savefn(fig, filename)
class PlotAttentionReport(asr_utils.PlotAttentionReport):
    """Attention-plot reporter specialized for multi-head attention.

    Inherits data access helpers (``self.data_dict``, ``self.converter``,
    ``self.transform``, ``self.att_vis_fn``, ``self.device``, ``self.outdir``
    and the i/o keys) from the espnet base class — attribute semantics are
    defined there, not in this file.
    """

    def plotfn(self, *args, **kwargs):
        """Forward to :func:`plot_multi_head_attention` with this reporter's keys."""
        kwargs["ikey"] = self.ikey
        kwargs["iaxis"] = self.iaxis
        kwargs["okey"] = self.okey
        kwargs["oaxis"] = self.oaxis
        # NOTE(review): self.factor presumably is the encoder subsampling
        # factor set by the base class — confirm against asr_utils.
        kwargs["subsampling_factor"] = self.factor
        plot_multi_head_attention(*args, **kwargs)

    def __call__(self, trainer):
        """Plot attentions for the current epoch (chainer-style extension hook)."""
        attn_dict, uttid_list = self.get_attention_weights()
        suffix = "ep.{.updater.epoch}.png".format(trainer)
        self.plotfn(self.data_dict, uttid_list, attn_dict, self.outdir, suffix, savefig)

    def get_attention_weights(self):
        """Run the attention visualization function on one converted batch.

        Returns:
            tuple: (attention dict from ``att_vis_fn``, utterance-ID list).
        """
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        # NOTE(review): if batch is neither tuple nor dict, att_ws is
        # unbound and the return raises — assumed unreachable in practice.
        if isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        elif isinstance(batch, dict):
            att_ws = self.att_vis_fn(**batch)
        return att_ws, uttid_list

    def log_attentions(self, logger, step):
        """Log attention figures to a tensorboard-style logger at ``step``."""

        def log_fig(plot, filename):
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            logger.add_figure(os.path.basename(filename), plot, step)
            plt.clf()

        attn_dict, uttid_list = self.get_attention_weights()
        self.plotfn(self.data_dict, uttid_list, attn_dict, self.outdir, "", log_fig)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Positionwise feed forward layer definition."""
import torch
class PositionwiseFeedForward(torch.nn.Module):
    """Position-wise feed-forward layer.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Nonlinearity between the two layers.
    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        """Construct an PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        """Project to hidden_units, apply activation and dropout, project back."""
        hidden = self.activation(self.w_1(x))
        return self.w_2(self.dropout(hidden))
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Repeat the same layer definition."""
import torch
class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential."""

    def __init__(self, *args, layer_drop_rate=0.0):
        """Initialize MultiSequential with LayerDrop.

        Args:
            layer_drop_rate (float): Probability of dropping out each fn (layer).
        """
        super(MultiSequential, self).__init__(*args)
        self.layer_drop_rate = layer_drop_rate

    def forward(self, *args):
        """Apply each layer in turn, randomly skipping layers during training."""
        draws = torch.empty(len(self)).uniform_()
        for i, layer in enumerate(self):
            # In eval mode every layer runs; in training a layer is skipped
            # when its uniform draw falls below layer_drop_rate.
            if not self.training or draws[i] >= self.layer_drop_rate:
                args = layer(*args)
        return args
def repeat(N, fn, layer_drop_rate=0.0):
    """Repeat module N times.

    Args:
        N (int): Number of repeat time.
        fn (Callable): Function to generate module.
        layer_drop_rate (float): Probability of dropping out each fn (layer).

    Returns:
        MultiSequential: Repeated model instance.
    """
    modules = (fn(i) for i in range(N))
    return MultiSequential(*modules, layer_drop_rate=layer_drop_rate)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Subsampling layer definition."""
import torch
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
class TooShortUttError(Exception):
    """Raised when an utterance is too short for the subsampling module.

    Args:
        message (str): Human-readable error description.
        actual_size (int): Length that failed the subsampling check.
        limit (int): Minimum length required by the subsampling module.
    """

    def __init__(self, message, actual_size, limit):
        """Store the offending size and the limit for error handlers."""
        super().__init__(message)
        self.actual_size = actual_size
        self.limit = limit
def check_short_utt(ins, size):
    """Check whether ``size`` frames suffice for subsampling module ``ins``.

    Returns:
        tuple: ``(True, minimum_required_length)`` when too short,
        otherwise ``(False, -1)``.
    """
    # Each subsampling module needs a minimum number of input frames.
    limits = (
        (Conv2dSubsampling1, 5),
        (Conv2dSubsampling2, 7),
        (Conv2dSubsampling, 7),
        (Conv2dSubsampling6, 11),
        (Conv2dSubsampling8, 15),
    )
    for cls, limit in limits:
        if isinstance(ins, cls) and size < limit:
            return True, limit
    return False, -1
class Conv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling object."""
        super(Conv2dSubsampling, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # Two stride-2 convs shrink the feature axis: idim -> ((idim-1)//2-1)//2.
        flat_dim = odim * (((idim - 1) // 2 - 1) // 2)
        self.out = torch.nn.Sequential(
            torch.nn.Linear(flat_dim, odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x to roughly a quarter of its length.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'), or None.
        """
        feats = self.conv(x.unsqueeze(1))  # (batch, odim, time', feat')
        batch, chans, frames, freq = feats.size()
        flat = feats.transpose(1, 2).contiguous().view(batch, frames, chans * freq)
        projected = self.out(flat)
        if x_mask is None:
            return projected, None
        # Mirror the two stride-2 convolutions on the time mask.
        return projected, x_mask[:, :, :-2:2][:, :, :-2:2]

    def __getitem__(self, key):
        """Return ``self.out[-1]`` (the positional-encoding layer).

        Only ``-1`` is supported; used by ``reset_parameters`` when
        scaled positional encoding is enabled.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
class Conv2dSubsampling1(torch.nn.Module):
    """Similar to Conv2dSubsampling module, but without any subsampling performed.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling1 object."""
        super(Conv2dSubsampling1, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        )
        # Two stride-1, kernel-3 convs remove 4 frames from the feature axis.
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (idim - 4), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Pass x through 2 Conv2d layers without subsampling.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time - 4, odim).
            torch.Tensor: Output mask (#batch, 1, time - 4), or None.
        """
        feats = self.conv(x.unsqueeze(1))  # (batch, odim, time', feat')
        batch, chans, frames, freq = feats.size()
        projected = self.out(
            feats.transpose(1, 2).contiguous().view(batch, frames, chans * freq)
        )
        if x_mask is None:
            return projected, None
        # Both convs are stride 1, so the mask just loses the last 4 frames.
        return projected, x_mask[:, :, :-4]

    def __getitem__(self, key):
        """Return ``self.out[-1]`` (the positional-encoding layer).

        Only ``-1`` is supported; used by ``reset_parameters`` when
        scaled positional encoding is enabled.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
class Conv2dSubsampling2(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling2 object."""
        super(Conv2dSubsampling2, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        )
        # Stride-2 then stride-1 conv: feature axis idim -> (idim-1)//2 - 2.
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * ((idim - 1) // 2 - 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x to roughly half its length.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'), or None.
        """
        feats = self.conv(x.unsqueeze(1))  # (batch, odim, time', feat')
        batch, chans, frames, freq = feats.size()
        projected = self.out(
            feats.transpose(1, 2).contiguous().view(batch, frames, chans * freq)
        )
        if x_mask is None:
            return projected, None
        # Mirror one stride-2 conv then one stride-1 conv on the time mask.
        return projected, x_mask[:, :, :-2:2][:, :, :-2:1]

    def __getitem__(self, key):
        """Return ``self.out[-1]`` (the positional-encoding layer).

        Only ``-1`` is supported; used by ``reset_parameters`` when
        scaled positional encoding is enabled.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
class Conv2dSubsampling6(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/6 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling6 object."""
        super(Conv2dSubsampling6, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        # Stride-2 (k=3) then stride-3 (k=5) conv on the feature axis.
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x to roughly a sixth of its length.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 6.
            torch.Tensor: Subsampled mask (#batch, 1, time'), or None.
        """
        feats = self.conv(x.unsqueeze(1))  # (batch, odim, time', feat')
        batch, chans, frames, freq = feats.size()
        projected = self.out(
            feats.transpose(1, 2).contiguous().view(batch, frames, chans * freq)
        )
        if x_mask is None:
            return projected, None
        # Mirror the stride-2 and stride-3 convolutions on the time mask.
        return projected, x_mask[:, :, :-2:2][:, :, :-4:3]
class Conv2dSubsampling8(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/8 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling8 object."""
        super(Conv2dSubsampling8, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        # Three stride-2 convs on the feature axis.
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x to roughly an eighth of its length.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 8.
            torch.Tensor: Subsampled mask (#batch, 1, time'), or None.
        """
        feats = self.conv(x.unsqueeze(1))  # (batch, odim, time', feat')
        batch, chans, frames, freq = feats.size()
        projected = self.out(
            feats.transpose(1, 2).contiguous().view(batch, frames, chans * freq)
        )
        if x_mask is None:
            return projected, None
        # Mirror the three stride-2 convolutions on the time mask.
        return projected, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
# Copyright 2020 Emiru Tsunoo
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Subsampling layer definition."""
import math
import torch
class Conv2dSubsamplingWOPosEnc(torch.nn.Module):
    """Convolutional 2D subsampling without positional encoding.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate (kept for API compatibility;
            not used by this module).
        kernels (list): Kernel size per conv layer.
        strides (list): Stride per conv layer.
    """

    def __init__(self, idim, odim, dropout_rate, kernels, strides):
        """Construct an Conv2dSubsamplingWOPosEnc object."""
        assert len(kernels) == len(strides)
        super().__init__()
        layers = []
        olen = idim
        for i, (k, s) in enumerate(zip(kernels, strides)):
            layers.append(torch.nn.Conv2d(1 if i == 0 else odim, odim, k, s))
            layers.append(torch.nn.ReLU())
            # Track the post-conv feature width for the output projection.
            olen = math.floor((olen - k) / s + 1)
        self.conv = torch.nn.Sequential(*layers)
        self.out = torch.nn.Linear(odim * olen, odim)
        self.strides = strides
        self.kernels = kernels

    def forward(self, x, x_mask):
        """Subsample x and its mask by the configured kernels and strides.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim).
            torch.Tensor: Subsampled mask (#batch, 1, time'), or None.
        """
        feats = self.conv(x.unsqueeze(1))  # (batch, odim, time', feat')
        batch, chans, frames, freq = feats.size()
        projected = self.out(
            feats.transpose(1, 2).contiguous().view(batch, frames, chans * freq)
        )
        if x_mask is None:
            return projected, None
        # Apply each conv's kernel/stride reduction to the time mask.
        for k, s in zip(self.kernels, self.strides):
            x_mask = x_mask[:, :, : -k + 1 : s]
        return projected, x_mask
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi (Nagoya University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""This code is based on https://github.com/kan-bayashi/PytorchWaveNetVocoder."""
import logging
import sys
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
def encode_mu_law(x, mu=256):
    """Perform mu-law encoding.

    Args:
        x (ndarray): Audio signal with the range from -1 to 1.
        mu (int): Quantized level.

    Returns:
        ndarray: Quantized audio signal (int64) with the range from 0 to mu - 1.
    """
    levels = mu - 1
    # Logarithmic companding, then uniform quantization of [-1, 1].
    compressed = np.sign(x) * np.log(1 + levels * np.abs(x)) / np.log(1 + levels)
    return np.floor((compressed + 1) / 2 * levels + 0.5).astype(np.int64)
def decode_mu_law(y, mu=256):
    """Perform mu-law decoding.

    Args:
        y (ndarray): Quantized audio signal with the range from 0 to mu - 1.
        mu (int): Quantized level.

    Returns:
        ndarray: Audio signal with the range from -1 to 1.
    """
    levels = mu - 1
    # Map bin centers back to [-1, 1], then invert the companding curve.
    centered = (y - 0.5) / levels * 2 - 1
    return np.sign(centered) / levels * ((1 + levels) ** np.abs(centered) - 1)
def initialize(m):
    """Initialize conv layers: Xavier for Conv1d, constants for ConvTranspose2d.

    Args:
        m (torch.nn.Module): Torch module (other module types are untouched).
    """
    # ConvTranspose2d (the upsampling layer) starts as a constant filter.
    if isinstance(m, nn.ConvTranspose2d):
        nn.init.constant_(m.weight, 1.0)
        nn.init.constant_(m.bias, 0.0)
    # Conv1d layers get Xavier-uniform weights and zero bias.
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0.0)
class OneHot(nn.Module):
    """Convert to one-hot vector.

    Args:
        depth (int): Dimension of one-hot vector.
    """

    def __init__(self, depth):
        """Construct a OneHot layer."""
        super(OneHot, self).__init__()
        self.depth = depth

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (LongTensor): long tensor variable with the shape (B, T)

        Returns:
            Tensor: float tensor variable with the shape (B, T, depth)
        """
        # Wrap out-of-range indices so scatter_ never fails.
        x = x % self.depth
        x = torch.unsqueeze(x, 2)
        x_onehot = x.new_zeros(x.size(0), x.size(1), self.depth).float()
        return x_onehot.scatter_(2, x, 1)
class CausalConv1d(nn.Module):
    """1D dilated causal convolution."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, bias=True):
        """Construct a CausalConv1d layer.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Convolution kernel size.
            dilation (int): Dilation factor.
            bias (bool): Whether the convolution has a bias term.
        """
        super(CausalConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        # Symmetric padding is applied by Conv1d; the right-hand side is
        # trimmed in forward() so outputs never see future inputs.
        self.padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=self.padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        """Apply the causal convolution.

        Args:
            x (Tensor): Input tensor with the shape (B, in_channels, T).

        Returns:
            Tensor: Tensor with the shape (B, out_channels, T).
        """
        y = self.conv(x)
        return y[:, :, : -self.padding] if self.padding != 0 else y
class UpSampling(nn.Module):
    """Upsampling layer with deconvolution.

    Args:
        upsampling_factor (int): Upsampling factor along the time axis.
        bias (bool): Whether the deconvolution has a bias term.
    """

    def __init__(self, upsampling_factor, bias=True):
        """Construct an UpSampling layer."""
        super(UpSampling, self).__init__()
        self.upsampling_factor = upsampling_factor
        self.bias = bias
        self.conv = nn.ConvTranspose2d(
            1,
            1,
            kernel_size=(1, self.upsampling_factor),
            stride=(1, self.upsampling_factor),
            bias=self.bias,
        )

    def forward(self, x):
        """Stretch the time axis by ``upsampling_factor``.

        Args:
            x (Tensor): Input tensor with the shape (B, C, T).

        Returns:
            Tensor: Tensor with the shape (B, C, T * upsampling_factor).
        """
        # Treat (C, T) as a one-channel image so the deconvolution
        # only expands the time axis.
        expanded = self.conv(x.unsqueeze(1))
        return expanded.squeeze(1)
class WaveNet(nn.Module):
    """Conditional wavenet.

    Args:
        n_quantize (int): Number of quantization.
        n_aux (int): Number of aux feature dimension.
        n_resch (int): Number of filter channels for residual block.
        n_skipch (int): Number of filter channels for skip connection.
        dilation_depth (int): Number of dilation depth
            (e.g. if set 10, max dilation = 2^(10-1)).
        dilation_repeat (int): Number of dilation repeat.
        kernel_size (int): Filter size of dilated causal convolution.
        upsampling_factor (int): Upsampling factor (0 disables upsampling).
    """

    def __init__(
        self,
        n_quantize=256,
        n_aux=28,
        n_resch=512,
        n_skipch=256,
        dilation_depth=10,
        dilation_repeat=3,
        kernel_size=2,
        upsampling_factor=0,
    ):
        """Construct a WaveNet (see class docstring for arguments)."""
        super(WaveNet, self).__init__()
        self.n_aux = n_aux
        self.n_quantize = n_quantize
        self.n_resch = n_resch
        self.n_skipch = n_skipch
        self.kernel_size = kernel_size
        self.dilation_depth = dilation_depth
        self.dilation_repeat = dilation_repeat
        self.upsampling_factor = upsampling_factor
        # Dilations cycle 1, 2, 4, ..., 2^(depth-1), repeated `repeat` times.
        self.dilations = [
            2**i for i in range(self.dilation_depth)
        ] * self.dilation_repeat
        self.receptive_field = (self.kernel_size - 1) * sum(self.dilations) + 1
        # for preprocessing
        self.onehot = OneHot(self.n_quantize)
        self.causal = CausalConv1d(self.n_quantize, self.n_resch, self.kernel_size)
        if self.upsampling_factor > 0:
            self.upsampling = UpSampling(self.upsampling_factor)
        # for residual blocks
        self.dil_sigmoid = nn.ModuleList()
        self.dil_tanh = nn.ModuleList()
        self.aux_1x1_sigmoid = nn.ModuleList()
        self.aux_1x1_tanh = nn.ModuleList()
        self.skip_1x1 = nn.ModuleList()
        self.res_1x1 = nn.ModuleList()
        for d in self.dilations:
            self.dil_sigmoid += [
                CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)
            ]
            self.dil_tanh += [
                CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)
            ]
            self.aux_1x1_sigmoid += [nn.Conv1d(self.n_aux, self.n_resch, 1)]
            self.aux_1x1_tanh += [nn.Conv1d(self.n_aux, self.n_resch, 1)]
            self.skip_1x1 += [nn.Conv1d(self.n_resch, self.n_skipch, 1)]
            self.res_1x1 += [nn.Conv1d(self.n_resch, self.n_resch, 1)]
        # for postprocessing
        self.conv_post_1 = nn.Conv1d(self.n_skipch, self.n_skipch, 1)
        self.conv_post_2 = nn.Conv1d(self.n_skipch, self.n_quantize, 1)

    def forward(self, x, h):
        """Calculate forward propagation.

        Args:
            x (LongTensor): Quantized input waveform tensor with the shape (B, T).
            h (Tensor): Auxiliary feature tensor with the shape (B, n_aux, T).

        Returns:
            Tensor: Logits with the shape (B, T, n_quantize).
        """
        # preprocess
        output = self._preprocess(x)
        if self.upsampling_factor > 0:
            h = self.upsampling(h)
        # residual block
        skip_connections = []
        for i in range(len(self.dilations)):
            output, skip = self._residual_forward(
                output,
                h,
                self.dil_sigmoid[i],
                self.dil_tanh[i],
                self.aux_1x1_sigmoid[i],
                self.aux_1x1_tanh[i],
                self.skip_1x1[i],
                self.res_1x1[i],
            )
            skip_connections.append(skip)
        # skip-connection part
        output = sum(skip_connections)
        output = self._postprocess(output)
        return output

    def generate(self, x, h, n_samples, interval=None, mode="sampling"):
        """Generate a waveform with fast genration algorithm.

        This generation based on `Fast WaveNet Generation Algorithm`_.

        Args:
            x (LongTensor): Initial waveform tensor with the shape (T,).
            h (Tensor): Auxiliary feature tensor with the shape (n_samples + T, n_aux).
            n_samples (int): Number of samples to be generated.
            interval (int, optional): Log interval.
            mode (str, optional): "sampling" or "argmax".

        Return:
            ndarray: Generated quantized waveform (n_samples).

        .. _`Fast WaveNet Generation Algorithm`: https://arxiv.org/abs/1611.09482
        """
        # reshape inputs
        assert len(x.shape) == 1
        assert len(h.shape) == 2 and h.shape[1] == self.n_aux
        x = x.unsqueeze(0)
        h = h.transpose(0, 1).unsqueeze(0)
        # perform upsampling
        if self.upsampling_factor > 0:
            h = self.upsampling(h)
        # padding for shortage
        if n_samples > h.shape[2]:
            h = F.pad(h, (0, n_samples - h.shape[2]), "replicate")
        # padding if the length less than
        # (left-pad with the mid quantization level up to the receptive field)
        n_pad = self.receptive_field - x.size(1)
        if n_pad > 0:
            x = F.pad(x, (n_pad, 0), "constant", self.n_quantize // 2)
            h = F.pad(h, (n_pad, 0), "replicate")
        # prepare buffer
        # Run the full network once on the seed waveform, caching per-layer
        # outputs so each generated sample only needs one incremental step.
        output = self._preprocess(x)
        h_ = h[:, :, : x.size(1)]
        output_buffer = []
        buffer_size = []
        for i, d in enumerate(self.dilations):
            output, _ = self._residual_forward(
                output,
                h_,
                self.dil_sigmoid[i],
                self.dil_tanh[i],
                self.aux_1x1_sigmoid[i],
                self.aux_1x1_tanh[i],
                self.skip_1x1[i],
                self.res_1x1[i],
            )
            # Cache just enough trailing frames for the next layer's dilation.
            if d == 2 ** (self.dilation_depth - 1):
                buffer_size.append(self.kernel_size - 1)
            else:
                buffer_size.append(d * 2 * (self.kernel_size - 1))
            output_buffer.append(output[:, :, -buffer_size[i] - 1 : -1])
        # generate
        samples = x[0]
        start_time = time.time()
        for i in range(n_samples):
            # Only the most recent kernel_size*2-1 samples feed one new step.
            output = samples[-self.kernel_size * 2 + 1 :].unsqueeze(0)
            output = self._preprocess(output)
            h_ = h[:, :, samples.size(0) - 1].contiguous().view(1, self.n_aux, 1)
            output_buffer_next = []
            skip_connections = []
            for j, d in enumerate(self.dilations):
                output, skip = self._generate_residual_forward(
                    output,
                    h_,
                    self.dil_sigmoid[j],
                    self.dil_tanh[j],
                    self.aux_1x1_sigmoid[j],
                    self.aux_1x1_tanh[j],
                    self.skip_1x1[j],
                    self.res_1x1[j],
                )
                output = torch.cat([output_buffer[j], output], dim=2)
                output_buffer_next.append(output[:, :, -buffer_size[j] :])
                skip_connections.append(skip)
            # update buffer
            output_buffer = output_buffer_next
            # get predicted sample
            output = sum(skip_connections)
            output = self._postprocess(output)[0]
            if mode == "sampling":
                posterior = F.softmax(output[-1], dim=0)
                dist = torch.distributions.Categorical(posterior)
                sample = dist.sample().unsqueeze(0)
            elif mode == "argmax":
                sample = output.argmax(-1)
            else:
                logging.error("mode should be sampling or argmax")
                sys.exit(1)
            samples = torch.cat([samples, sample], dim=0)
            # show progress
            if interval is not None and (i + 1) % interval == 0:
                elapsed_time_per_sample = (time.time() - start_time) / interval
                logging.info(
                    "%d/%d estimated time = %.3f sec (%.3f sec / sample)"
                    % (
                        i + 1,
                        n_samples,
                        (n_samples - i - 1) * elapsed_time_per_sample,
                        elapsed_time_per_sample,
                    )
                )
                start_time = time.time()
        return samples[-n_samples:].cpu().numpy()

    def _preprocess(self, x):
        """One-hot encode x (B, T) and apply the input causal conv -> (B, n_resch, T)."""
        x = self.onehot(x).transpose(1, 2)
        output = self.causal(x)
        return output

    def _postprocess(self, x):
        """Map summed skip connections to logits (B, T, n_quantize)."""
        output = F.relu(x)
        output = self.conv_post_1(output)
        output = F.relu(output)  # B x C x T
        output = self.conv_post_2(output).transpose(1, 2)  # B x T x C
        return output

    def _residual_forward(
        self,
        x,
        h,
        dil_sigmoid,
        dil_tanh,
        aux_1x1_sigmoid,
        aux_1x1_tanh,
        skip_1x1,
        res_1x1,
    ):
        """One gated residual block over a full sequence.

        Returns:
            tuple: (residual output x + res_1x1(gated), skip contribution).
        """
        output_sigmoid = dil_sigmoid(x)
        output_tanh = dil_tanh(x)
        aux_output_sigmoid = aux_1x1_sigmoid(h)
        aux_output_tanh = aux_1x1_tanh(h)
        # Gated activation unit: sigmoid gate times tanh filter.
        output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * torch.tanh(
            output_tanh + aux_output_tanh
        )
        skip = skip_1x1(output)
        output = res_1x1(output)
        output = output + x
        return output, skip

    def _generate_residual_forward(
        self,
        x,
        h,
        dil_sigmoid,
        dil_tanh,
        aux_1x1_sigmoid,
        aux_1x1_tanh,
        skip_1x1,
        res_1x1,
    ):
        """Single-step variant of :meth:`_residual_forward` used in generation.

        Only the last time frame of the dilated convolutions is kept, so
        outputs have time length 1.
        """
        output_sigmoid = dil_sigmoid(x)[:, :, -1:]
        output_tanh = dil_tanh(x)[:, :, -1:]
        aux_output_sigmoid = aux_1x1_sigmoid(h)
        aux_output_tanh = aux_1x1_tanh(h)
        output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * torch.tanh(
            output_tanh + aux_output_tanh
        )
        skip = skip_1x1(output)
        output = res_1x1(output)
        output = output + x[:, :, -1:]  # B x C x 1
        return output, skip
"""Scorer interface module."""
import warnings
from typing import Any, List, Tuple
import torch
class ScorerInterface:
    """Scorer interface for beam search.

    A scorer assigns a score to every token in the vocabulary given a
    decoded prefix.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
    """

    def init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        # Stateless scorers need no decoding state.
        return None

    def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
        """Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label index to select a state if necessary

        Returns:
            state: pruned state
        """
        if state is None:
            return None
        return state[i]

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token that has a shape of `(n_vocab)`
                and next state for ys
        """
        raise NotImplementedError

    def final_score(self, state: Any) -> float:
        """Score eos (optional).

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score
        """
        # No end-of-sequence bonus by default.
        return 0.0
class BatchScorerInterface(ScorerInterface):
    """Batch scorer interface."""

    def batch_init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for batched decoding (optional).

        Falls back to the single-hypothesis :meth:`init_state`.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        return self.init_state(x)

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        Default implementation: score each hypothesis separately through
        :meth:`score` and stack the results (slow fallback).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        warnings.warn(
            "{} batch score is implemented through for loop not parallelized".format(
                self.__class__.__name__
            )
        )
        scores = []
        next_states = []
        for y, state, x in zip(ys, states, xs):
            score, next_state = self.score(y, state, x)
            scores.append(score)
            next_states.append(next_state)
        batched = torch.cat(scores, 0).view(ys.shape[0], -1)
        return batched, next_states
class PartialScorerInterface(ScorerInterface):
    """Partial scorer interface for beam search.

    A partial scorer runs after the non-partial scorers and receives a
    pre-pruned candidate token set, because scoring the whole vocabulary
    would be too expensive for it.

    Examples:
        * Prefix search for connectionist-temporal-classification models
            * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`
    """

    def score_partial(
        self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            y (torch.Tensor): 1D prefix token
            next_tokens (torch.Tensor): torch.int64 next token to score
            state: decoder state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
                and next state for ys
        """
        raise NotImplementedError
class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
    """Batch partial scorer interface for beam search."""

    def batch_score_partial(
        self,
        ys: torch.Tensor,
        next_tokens: torch.Tensor,
        states: List[Any],
        xs: torch.Tensor,
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Batched counterpart of :meth:`score_partial`: scores only the
        pre-pruned candidate tokens for each hypothesis in the batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
                and next states for ys
        """
        raise NotImplementedError
"""ScorerInterface implementation for CTC."""
import numpy as np
import torch
from espnet.nets.ctc_prefix_score import CTCPrefixScore, CTCPrefixScoreTH
from espnet.nets.scorer_interface import BatchPartialScorerInterface
class CTCPrefixScorer(BatchPartialScorerInterface):
    """Decoder interface wrapper for CTCPrefixScore."""

    def __init__(self, ctc: torch.nn.Module, eos: int):
        """Initialize class.

        Args:
            ctc (torch.nn.Module): The CTC implementation.
                For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
            eos (int): The end-of-sequence id.
        """
        self.ctc = ctc
        self.eos = eos
        # Actual scoring implementation (CTCPrefixScore or CTCPrefixScoreTH);
        # created lazily by init_state() / batch_init_state().
        self.impl = None

    def init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state: tuple of (accumulated score, impl state)
        """
        # Pre-compute the frame-wise log posteriors once per utterance,
        # moved to CPU because CTCPrefixScore works in numpy.
        logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
        # TODO(karita): use CTCPrefixScoreTH
        self.impl = CTCPrefixScore(logp, 0, self.eos, np)
        return 0, self.impl.initial_state()

    def select_state(self, state, i, new_id=None):
        """Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens; either a 2-tuple
                (score, CTCPrefixScore state) or a 5-tuple from
                CTCPrefixScoreTH.
            i (int): Index to select a state in the main beam search
            new_id (int): New label id to select a state if necessary
                (required for the CTCPrefixScoreTH layout)

        Returns:
            state: pruned state
        """
        # NOTE(review): tuple length distinguishes the two state layouts;
        # isinstance() would be the usual idiom but is kept as-is here.
        if type(state) == tuple:
            if len(state) == 2:  # for CTCPrefixScore
                sc, st = state
                return sc[i], st[i]
            else:  # for CTCPrefixScoreTH (need new_id > 0)
                r, log_psi, f_min, f_max, scoring_idmap = state
                # Broadcast the selected hypothesis score over the vocab dim.
                s = log_psi[i, new_id].expand(log_psi.size(1))
                if scoring_idmap is not None:
                    # scoring_idmap remaps vocab ids to pruned scoring ids.
                    return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
                else:
                    return r[:, :, i, new_id], s, f_min, f_max
        return None if state is None else state[i]

    def score_partial(self, y, ids, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D prefix token
            ids (torch.Tensor): torch.int64 next tokens to score
            state: decoder state for prefix tokens: (prev score, impl state)
            x (torch.Tensor): 2D encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(ids),)`
                and next state for ys
        """
        prev_score, state = state
        presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
        # Return only the incremental score relative to the previous prefix.
        tscore = torch.as_tensor(
            presub_score - prev_score, device=x.device, dtype=x.dtype
        )
        return tscore, (presub_score, new_st)

    def batch_init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state (None; state is created on first scoring call)
        """
        logp = self.ctc.log_softmax(x.unsqueeze(0))  # assuming batch_size = 1
        xlen = torch.tensor([logp.size(1)])
        self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
        return None

    def batch_score_partial(self, y, ids, state, x):
        """Score new token batch.

        Args:
            y (torch.Tensor): torch.int64 prefix tokens
            ids (torch.Tensor): torch.int64 next tokens to score
            state: list of per-hypothesis states from select_state(), or
                a list whose first element is None on the first call
            x (torch.Tensor): 2D encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y and the batched next state
        """
        # Re-batch the per-hypothesis states: r is stacked on a new beam
        # axis (dim=2), log_psi on dim=0; f_min/f_max are shared scalars.
        batch_state = (
            (
                torch.stack([s[0] for s in state], dim=2),
                torch.stack([s[1] for s in state]),
                state[0][2],
                state[0][3],
            )
            if state[0] is not None
            else None
        )
        return self.impl(y, batch_state, ids)

    def extend_prob(self, x: torch.Tensor):
        """Extend probs for decoding.

        This extension is for streaming decoding
        as in Eq (14) in https://arxiv.org/abs/2006.14941

        Args:
            x (torch.Tensor): The encoded feature tensor
        """
        logp = self.ctc.log_softmax(x.unsqueeze(0))
        self.impl.extend_prob(logp)

    def extend_state(self, state):
        """Extend state for decoding.

        This extension is for streaming decoding
        as in Eq (14) in https://arxiv.org/abs/2006.14941

        Args:
            state: The states of hyps

        Returns: extended state
        """
        new_state = []
        for s in state:
            new_state.append(self.impl.extend_state(s))
        return new_state
"""Length bonus module."""
from typing import Any, List, Tuple
import torch
from espnet.nets.scorer_interface import BatchScorerInterface
class LengthBonus(BatchScorerInterface):
    """Length bonus in beam search."""

    def __init__(self, n_vocab: int):
        """Initialize class.

        Args:
            n_vocab (int): The number of tokens in vocabulary for beam search
        """
        # Every token receives the same +1.0 bonus, so only the vocab size
        # needs to be remembered.
        self.n = n_vocab

    def score(self, y, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and None
        """
        bonus = torch.tensor([1.0], device=x.device, dtype=x.dtype)
        return bonus.expand(self.n), None

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        bonus = torch.tensor([1.0], device=xs.device, dtype=xs.dtype)
        return bonus.expand(ys.shape[0], self.n), None
"""Ngram lm implement."""
from abc import ABC
import kenlm
import torch
from espnet.nets.scorer_interface import BatchScorerInterface, PartialScorerInterface
class Ngrambase(ABC):
    """Ngram base implemented through ScorerInterface."""

    def __init__(self, ngram_model, token_list):
        """Initialize Ngrambase.

        Args:
            ngram_model: ngram model path
            token_list: token list from dict or model.json
        """
        # kenlm uses "</s>" as its end-of-sentence symbol.
        self.chardict = ["</s>" if token == "<eos>" else token for token in token_list]
        self.charlen = len(self.chardict)
        self.lm = kenlm.LanguageModel(ngram_model)
        # Scratch state reused when probing candidate tokens.
        self.tmpkenlmstate = kenlm.State()

    def init_state(self, x):
        """Initialize tmp state."""
        state = kenlm.State()
        self.lm.NullContextWrite(state)
        return state

    def score_partial_(self, y, next_token, state, x):
        """Score interface for both full and partial scorer.

        Args:
            y: previous char
            next_token: next token need to be score
            state: previous state
            x: encoded feature

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for each candidate in next_token
                and the advanced kenlm state.
        """
        out_state = kenlm.State()
        # Advance the LM by the newest character; "<s>" at sentence start.
        if y.shape[0] > 1:
            newest = self.chardict[y[-1]]
        else:
            newest = "<s>"
        self.lm.BaseScore(state, newest, out_state)
        scores = torch.empty_like(next_token, dtype=x.dtype, device=y.device)
        for idx, tok in enumerate(next_token):
            scores[idx] = self.lm.BaseScore(
                out_state, self.chardict[tok], self.tmpkenlmstate
            )
        return scores, out_state
class NgramFullScorer(Ngrambase, BatchScorerInterface):
    """Fullscorer for ngram."""

    def score(self, y, state, x):
        """Score all vocabulary tokens with the ngram LM.

        Args:
            y: previous char
            state: previous state
            x: encoded feature

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for every token in the vocabulary
                and the advanced kenlm state.
        """
        # Full scoring == partial scoring over the entire vocabulary.
        all_tokens = torch.arange(self.charlen)
        return self.score_partial_(y, all_tokens, state, x)
class NgramPartScorer(Ngrambase, PartialScorerInterface):
    """Partialscorer for ngram."""

    def score_partial(self, y, next_token, state, x):
        """Score the pre-pruned candidate tokens with the ngram LM.

        Args:
            y: previous char
            next_token: next token need to be score
            state: previous state
            x: encoded feature

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for each candidate in next_token
                and the advanced kenlm state.
        """
        return self.score_partial_(y, next_token, state, x)

    def select_state(self, state, i):
        """Empty select state for scorer interface."""
        # The kenlm state is shared across hypotheses, so nothing to prune.
        return state
"""ScorerInterface implementation for UASR."""
import numpy as np
import torch
from espnet.nets.ctc_prefix_score import CTCPrefixScore, CTCPrefixScoreTH
from espnet.nets.scorers.ctc import CTCPrefixScorer
class UASRPrefixScorer(CTCPrefixScorer):
    """Decoder interface wrapper for CTCPrefixScore."""

    def __init__(self, eos: int):
        """Initialize class.

        Args:
            eos (int): The end-of-sequence id.
        """
        # Unlike CTCPrefixScorer, no CTC module is held: the raw output
        # tensor is passed straight to init_state()/batch_init_state().
        self.eos = eos

    def init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        # NOTE(review): mutates x in place before taking the softmax.
        x[:, 0] = x[:, 0] - 100000000000  # simulate a no-blank CTC
        logp = torch.nn.functional.log_softmax(x, dim=1)
        self.logp = logp.detach().squeeze(0).cpu().numpy()
        # TODO(karita): use CTCPrefixScoreTH
        self.impl = CTCPrefixScore(self.logp, 0, self.eos, np)
        return 0, self.impl.initial_state()

    def batch_init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        x[:, 0] = x[:, 0] - 100000000000  # simulate a no-blank CTC
        # Add the batch axis by hand; assuming batch_size = 1.
        logp = torch.nn.functional.log_softmax(x, dim=1).unsqueeze(0)
        xlen = torch.tensor([logp.size(1)])
        self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
        return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment