Commit 60a2c57a authored by sunzhq2, committed by xuxo

update conformer

parent 4a699441
"""Utility functions for Transducer models."""
import os
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from espnet.nets.pytorch_backend.nets_utils import pad_list
from espnet.nets.transducer_decoder_interface import ExtendedHypothesis, Hypothesis
def get_decoder_input(
labels: torch.Tensor, blank_id: int, ignore_id: int
) -> torch.Tensor:
"""Prepare decoder input.
Args:
labels: Label ID sequences. (B, L)
blank_id: Blank symbol ID.
ignore_id: Padding symbol ID.
Returns:
decoder_input: Label ID sequences with blank prefix. (B, U)
"""
device = labels.device
labels_unpad = [label[label != ignore_id] for label in labels]
blank = labels[0].new([blank_id])
decoder_input = pad_list(
[torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
).to(device)
return decoder_input
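# Illustrative example (added for clarity, not part of the original module): with
# blank_id=0 and ignore_id=-1, a padded batch [[3, 5, -1], [7, -1, -1]] becomes
# [[0, 3, 5], [0, 7, 0]] -- each sequence is stripped of padding, prefixed with the
# blank symbol, and re-padded on the right with blank_id.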
def valid_aux_encoder_output_layers(
aux_layer_id: List[int],
enc_num_layers: int,
use_symm_kl_div_loss: bool,
subsample: List[int],
) -> List[int]:
"""Check whether provided auxiliary encoder layer IDs are valid.
Return the valid IDs sorted in ascending order.
Args:
aux_layer_id: Auxiliary encoder layer IDs.
enc_num_layers: Number of encoder layers.
use_symm_kl_div_loss: Whether symmetric KL divergence loss is used.
subsample: Subsampling rate per layer.
Returns:
valid: Valid list of auxiliary encoder layers.
"""
if (
not isinstance(aux_layer_id, list)
or not aux_layer_id
or not all(isinstance(layer, int) for layer in aux_layer_id)
):
raise ValueError(
"aux-transducer-loss-enc-output-layers option takes a list of layer IDs."
" Correct argument format is: '[0, 1]'"
)
sorted_list = sorted(aux_layer_id, key=int, reverse=False)
valid = list(filter(lambda x: 0 <= x < enc_num_layers, sorted_list))
if sorted_list != valid:
raise ValueError(
"Provided argument for aux-transducer-loss-enc-output-layers is incorrect."
" IDs should be between [0, %d]" % enc_num_layers
)
if use_symm_kl_div_loss:
sorted_list += [enc_num_layers]
for n in range(1, len(sorted_list)):
sub_range = subsample[(sorted_list[n - 1] + 1) : sorted_list[n] + 1]
valid_shape = [sub <= 1 for sub in sub_range]
if not all(valid_shape):
raise ValueError(
"Encoder layers %d and %d have different shape due to subsampling."
" Symmetric KL divergence loss doesn't cover such case for now."
% (sorted_list[n - 1], sorted_list[n])
)
return valid
def is_prefix(x: List[int], pref: List[int]) -> bool:
"""Check if pref is a prefix of x.
Args:
x: Label ID sequence.
pref: Prefix label ID sequence.
Returns:
: Whether pref is a prefix of x.
"""
if len(pref) >= len(x):
return False
for i in range(len(pref) - 1, -1, -1):
if pref[i] != x[i]:
return False
return True
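# Illustrative example (added for clarity, not part of the original module):
# is_prefix([1, 2, 3], [1, 2]) returns True, while is_prefix([1, 2, 3], [1, 2, 3])
# returns False because a sequence is not considered a prefix of itself
# (len(pref) must be strictly smaller than len(x)).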
def subtract(
x: List[ExtendedHypothesis], subset: List[ExtendedHypothesis]
) -> List[ExtendedHypothesis]:
"""Remove elements of subset if corresponding label ID sequence already exist in x.
Args:
x: Set of hypotheses.
subset: Subset of x.
Returns:
final: New set of hypotheses.
"""
final = []
for x_ in x:
if any(x_.yseq == sub.yseq for sub in subset):
continue
final.append(x_)
return final
def select_k_expansions(
hyps: List[ExtendedHypothesis],
topk_idxs: torch.Tensor,
topk_logps: torch.Tensor,
gamma: float,
) -> List[ExtendedHypothesis]:
"""Return K hypotheses candidates for expansion from a list of hypothesis.
K candidates are selected according to the extended hypotheses probabilities
and a prune-by-value method. Where K is equal to beam_size + beta.
Args:
hyps: Hypotheses.
topk_idxs: Indices of candidate hypotheses.
topk_logps: Log-probabilities for hypotheses expansions.
gamma: Allowed logp difference for prune-by-value method.
Return:
k_expansions: Best K expansion hypotheses candidates.
"""
k_expansions = []
for i, hyp in enumerate(hyps):
hyp_i = [
(int(k), hyp.score + float(v)) for k, v in zip(topk_idxs[i], topk_logps[i])
]
k_best_exp = max(hyp_i, key=lambda x: x[1])[1]
k_expansions.append(
sorted(
filter(lambda x: (k_best_exp - gamma) <= x[1], hyp_i),
key=lambda x: x[1],
reverse=True,
)
)
return k_expansions
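# Illustrative example (added for clarity, not part of the original module): if a
# hypothesis with score -1.0 has top-k expansions with log-probs [-0.1, -0.5, -3.0]
# and gamma=2.0, the candidate scores are [-1.1, -1.5, -4.0]; only the first two
# survive the prune-by-value step because -4.0 < (-1.1 - 2.0) = -3.1.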
def select_lm_state(
lm_states: Union[List[Any], Dict[str, Any]],
idx: int,
lm_layers: int,
is_wordlm: bool,
) -> Union[List[Any], Dict[str, Any]]:
"""Get ID state from LM hidden states.
Args:
lm_states: LM hidden states.
idx: LM state ID to extract.
lm_layers: Number of LM layers.
is_wordlm: Whether provided LM is a word-level LM.
Returns:
idx_state: LM hidden state for given ID.
"""
if is_wordlm:
idx_state = lm_states[idx]
else:
idx_state = {}
idx_state["c"] = [lm_states["c"][layer][idx] for layer in range(lm_layers)]
idx_state["h"] = [lm_states["h"][layer][idx] for layer in range(lm_layers)]
return idx_state
def create_lm_batch_states(
lm_states: Union[List[Any], Dict[str, Any]], lm_layers, is_wordlm: bool
) -> Union[List[Any], Dict[str, Any]]:
"""Create LM hidden states.
Args:
lm_states: LM hidden states.
lm_layers: Number of LM layers.
is_wordlm: Whether provided LM is a word-level LM.
Returns:
new_states: LM hidden states.
"""
if is_wordlm:
return lm_states
new_states = {}
new_states["c"] = [
torch.stack([state["c"][layer] for state in lm_states])
for layer in range(lm_layers)
]
new_states["h"] = [
torch.stack([state["h"][layer] for state in lm_states])
for layer in range(lm_layers)
]
return new_states
def init_lm_state(lm_model: torch.nn.Module):
"""Initialize LM hidden states.
Args:
lm_model: LM module.
Returns:
lm_state: Initial LM hidden states.
"""
lm_layers = len(lm_model.rnn)
lm_units_typ = lm_model.typ
lm_units = lm_model.n_units
p = next(lm_model.parameters())
h = [
torch.zeros(lm_units).to(device=p.device, dtype=p.dtype)
for _ in range(lm_layers)
]
lm_state = {"h": h}
if lm_units_typ == "lstm":
lm_state["c"] = [
torch.zeros(lm_units).to(device=p.device, dtype=p.dtype)
for _ in range(lm_layers)
]
return lm_state
def recombine_hyps(hyps: List[Hypothesis]) -> List[Hypothesis]:
"""Recombine hypotheses with same label ID sequence.
Args:
hyps: Hypotheses.
Returns:
final: Recombined hypotheses.
"""
final = []
for hyp in hyps:
seq_final = [f.yseq for f in final if f.yseq]
if hyp.yseq in seq_final:
seq_pos = seq_final.index(hyp.yseq)
final[seq_pos].score = np.logaddexp(final[seq_pos].score, hyp.score)
else:
final.append(hyp)
return final
def pad_sequence(labels: List[List[int]], pad_id: int) -> List[List[int]]:
"""Left pad label ID sequences.
Args:
labels: Label ID sequences.
pad_id: Padding symbol ID.
Returns:
final: Padded label ID sequences.
"""
maxlen = max(len(x) for x in labels)
final = [([pad_id] * (maxlen - len(x))) + x for x in labels]
return final
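# Illustrative example (added for clarity, not part of the original module):
# pad_sequence([[1, 2, 3], [4]], 0) returns [[1, 2, 3], [0, 0, 4]] -- shorter
# sequences are left-padded with pad_id up to the length of the longest one.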
def check_state(
state: List[Optional[torch.Tensor]], max_len: int, pad_id: int
) -> List[Optional[torch.Tensor]]:
"""Check decoder hidden states and left pad or trim if necessary.
Args:
state: Decoder hidden states. [N x (?, D_dec)]
max_len: maximum sequence length.
pad_id: Padding symbol ID.
Returns:
final: Decoder hidden states. [N x (1, max_len, D_dec)]
"""
if state is None or max_len < 1 or state[0].size(1) == max_len:
return state
curr_len = state[0].size(1)
if curr_len > max_len:
trim_val = int(state[0].size(1) - max_len)
for i, s in enumerate(state):
state[i] = s[:, trim_val:, :]
else:
layers = len(state)
ddim = state[0].size(2)
final_dims = (1, max_len, ddim)
final = [state[0].data.new(*final_dims).fill_(pad_id) for _ in range(layers)]
for i, s in enumerate(state):
final[i][:, (max_len - s.size(1)) : max_len, :] = s
return final
return state
def check_batch_states(states, max_len, pad_id):
"""Check decoder hidden states and left pad or trim if necessary.
Args:
states: Decoder hidden states. [N x (B, ?, D_dec)]
max_len: maximum sequence length.
pad_id: Padding symbol ID.
Returns:
final: Decoder hidden states. [N x (B, max_len, D_dec)]
"""
final_dims = (len(states), max_len, states[0].size(1))
final = states[0].data.new(*final_dims).fill_(pad_id)
for i, s in enumerate(states):
curr_len = s.size(0)
if curr_len < max_len:
final[i, (max_len - curr_len) : max_len, :] = s
else:
final[i, :, :] = s[(curr_len - max_len) :, :]
return final
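# Note (added for clarity, not part of the original module): check_batch_states
# mirrors check_state for batched decoding -- states shorter than max_len are
# left-padded with pad_id, while longer ones keep only their last max_len frames.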
def custom_torch_load(model_path: str, model: torch.nn.Module, training: bool = True):
"""Load Transducer model with training-only modules and parameters removed.
Args:
model_path: Model path.
model: Transducer model.
"""
if "snapshot" in os.path.basename(model_path):
model_state_dict = torch.load(
model_path, map_location=lambda storage, loc: storage
)["model"]
else:
model_state_dict = torch.load(
model_path, map_location=lambda storage, loc: storage
)
if not training:
task_keys = ("mlp", "ctc_lin", "kl_div", "lm_lin", "error_calculator")
model_state_dict = {
k: v
for k, v in model_state_dict.items()
if not any(mod in k for mod in task_keys)
}
model.load_state_dict(model_state_dict)
del model_state_dict
"""VGG2L module definition for custom encoder."""
from typing import Tuple, Union
import torch
class VGG2L(torch.nn.Module):
"""VGG2L module for custom encoder.
Args:
idim: Input dimension.
odim: Output dimension.
pos_enc: Positional encoding class.
"""
def __init__(self, idim: int, odim: int, pos_enc: torch.nn.Module = None):
"""Construct a VGG2L object."""
super().__init__()
self.vgg2l = torch.nn.Sequential(
torch.nn.Conv2d(1, 64, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d((3, 2)),
torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d((2, 2)),
)
if pos_enc is not None:
self.output = torch.nn.Sequential(
torch.nn.Linear(128 * ((idim // 2) // 2), odim), pos_enc
)
else:
self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim)
def forward(
self, feats: torch.Tensor, feats_mask: torch.Tensor
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor],
]:
"""Forward VGG2L bottleneck.
Args:
feats: Feature sequences. (B, F, D_feats)
feats_mask: Mask of feature sequences. (B, 1, F)
Returns:
vgg_output: VGG output sequences.
(B, sub(F), D_out) or ((B, sub(F), D_out), (B, sub(F), D_att))
vgg_mask: Mask of VGG output sequences. (B, 1, sub(F))
"""
feats = feats.unsqueeze(1)
vgg_output = self.vgg2l(feats)
b, c, t, f = vgg_output.size()
vgg_output = self.output(
vgg_output.transpose(1, 2).contiguous().view(b, t, c * f)
)
if feats_mask is not None:
vgg_mask = self.create_new_mask(feats_mask)
else:
vgg_mask = feats_mask
return vgg_output, vgg_mask
def create_new_mask(self, feats_mask: torch.Tensor) -> torch.Tensor:
"""Create a subsampled mask of feature sequences.
Args:
feats_mask: Mask of feature sequences. (B, 1, F)
Returns:
vgg_mask: Mask of VGG2L output sequences. (B, 1, sub(F))
"""
vgg1_t_len = feats_mask.size(2) - (feats_mask.size(2) % 3)
vgg_mask = feats_mask[:, :, :vgg1_t_len][:, :, ::3]
vgg2_t_len = vgg_mask.size(2) - (vgg_mask.size(2) % 2)
vgg_mask = vgg_mask[:, :, :vgg2_t_len][:, :, ::2]
return vgg_mask
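# Note (added for clarity, not part of the original module): the two max-pooling
# layers subsample time by 3 and then by 2 (an overall factor of 6, matching the
# ::3 and ::2 slicing in create_new_mask) and frequency by 2 and then by 2, which
# is why the output projection expects 128 * ((idim // 2) // 2) input features.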
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Unility functions for Transformer."""
import torch
def add_sos_eos(ys_pad, sos, eos, ignore_id):
"""Add <sos> and <eos> labels.
:param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
:param int sos: index of <sos>
:param int eos: index of <eos>
:param int ignore_id: index of padding
:return: padded tensor (B, Lmax)
:rtype: torch.Tensor
:return: padded tensor (B, Lmax)
:rtype: torch.Tensor
"""
from espnet.nets.pytorch_backend.nets_utils import pad_list
_sos = ys_pad.new([sos])
_eos = ys_pad.new([eos])
ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
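# Illustrative example (added for clarity, not part of the original module): with
# sos=1, eos=2, ignore_id=-1 and ys_pad = [[4, 5, -1]], the function returns
# ys_in = [[1, 4, 5]] (sos-prefixed, padded with eos) and
# ys_out = [[4, 5, 2]] (eos-appended, padded with ignore_id).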
# Copyright 2020 Hirofumi Inaguma
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Transformer common arguments."""
from distutils.util import strtobool
def add_arguments_transformer_common(group):
"""Add Transformer common arguments."""
group.add_argument(
"--transformer-init",
type=str,
default="pytorch",
choices=[
"pytorch",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
],
help="how to initialize transformer parameters",
)
group.add_argument(
"--transformer-input-layer",
type=str,
default="conv2d",
choices=["conv2d", "linear", "embed"],
help="transformer input layer type",
)
group.add_argument(
"--transformer-attn-dropout-rate",
default=None,
type=float,
help="dropout in transformer attention. use --dropout-rate if None is set",
)
group.add_argument(
"--transformer-lr",
default=10.0,
type=float,
help="Initial value of learning rate",
)
group.add_argument(
"--transformer-warmup-steps",
default=25000,
type=int,
help="optimizer warmup steps",
)
group.add_argument(
"--transformer-length-normalized-loss",
default=True,
type=strtobool,
help="normalize loss by length",
)
group.add_argument(
"--transformer-encoder-selfattn-layer-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"rel_selfattn",
"lightconv",
"lightconv2d",
"dynamicconv",
"dynamicconv2d",
"light-dynamicconv2d",
],
help="transformer encoder self-attention layer type",
)
group.add_argument(
"--transformer-decoder-selfattn-layer-type",
type=str,
default="selfattn",
choices=[
"selfattn",
"lightconv",
"lightconv2d",
"dynamicconv",
"dynamicconv2d",
"light-dynamicconv2d",
],
help="transformer decoder self-attention layer type",
)
# Lightweight/Dynamic convolution related parameters.
# See https://arxiv.org/abs/1912.11793v2
# and https://arxiv.org/abs/1901.10430 for detail of the method.
# Configurations used in the first paper are in
# egs/{csj, librispeech}/asr1/conf/tuning/ld_conv/
group.add_argument(
"--wshare",
default=4,
type=int,
help="Number of parameter shargin for lightweight convolution",
)
group.add_argument(
"--ldconv-encoder-kernel-length",
default="21_23_25_27_29_31_33_35_37_39_41_43",
type=str,
help="kernel size for lightweight/dynamic convolution: "
'Encoder side. For example, "21_23_25" means kernel length 21 for '
"First layer, 23 for Second layer and so on.",
)
group.add_argument(
"--ldconv-decoder-kernel-length",
default="11_13_15_17_19_21",
type=str,
help="kernel size for lightweight/dynamic convolution: "
'Decoder side. For example, "21_23_25" means kernel length 21 for '
"First layer, 23 for Second layer and so on.",
)
group.add_argument(
"--ldconv-usebias",
type=strtobool,
default=False,
help="use bias term in lightweight/dynamic convolution",
)
group.add_argument(
"--dropout-rate",
default=0.0,
type=float,
help="Dropout rate for the encoder",
)
group.add_argument(
"--intermediate-ctc-weight",
default=0.0,
type=float,
help="Weight of intermediate CTC weight",
)
group.add_argument(
"--intermediate-ctc-layer",
default="",
type=str,
help="Position of intermediate CTC layer. {int} or {int},{int},...,{int}",
)
group.add_argument(
"--self-conditioning",
default=False,
type=strtobool,
help="use self-conditioning at intermediate CTC layers",
)
# Encoder
group.add_argument(
"--elayers",
default=4,
type=int,
help="Number of encoder layers (for shared recognition part "
"in multi-speaker asr mode)",
)
group.add_argument(
"--eunits",
"-u",
default=300,
type=int,
help="Number of encoder hidden units",
)
# Attention
group.add_argument(
"--adim",
default=320,
type=int,
help="Number of attention transformation dimensions",
)
group.add_argument(
"--aheads",
default=4,
type=int,
help="Number of heads for multi head attention",
)
group.add_argument(
"--stochastic-depth-rate",
default=0.0,
type=float,
help="Skip probability of stochastic layer regularization",
)
# Decoder
group.add_argument(
"--dlayers", default=1, type=int, help="Number of decoder layers"
)
group.add_argument(
"--dunits", default=320, type=int, help="Number of decoder hidden units"
)
return group
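# Usage sketch (an assumption added for illustration, not taken from this commit):
# the function is meant to extend an argparse argument group, e.g.
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   group = parser.add_argument_group("transformer model setting")
#   add_arguments_transformer_common(group)
#   args = parser.parse_args(["--elayers", "12", "--adim", "256"])
#
# after which args.elayers == 12 and args.adim == 256.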
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Multi-Head Attention layer definition."""
import math
import torch
from torch import nn
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self, n_head, n_feat, dropout_rate):
"""Construct an MultiHeadedAttention object."""
super(MultiHeadedAttention, self).__init__()
assert n_feat % n_head == 0
# We assume d_v always equals d_k
self.d_k = n_feat // n_head
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat)
self.linear_k = nn.Linear(n_feat, n_feat)
self.linear_v = nn.Linear(n_feat, n_feat)
self.linear_out = nn.Linear(n_feat, n_feat)
self.attn = None
self.dropout = nn.Dropout(p=dropout_rate)
def forward_qkv(self, query, key, value):
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
Returns:
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)
return q, k, v
def forward_attention(self, value, scores, mask):
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
min_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask, min_value)
self.attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(self.attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, query, key, value, mask):
"""Compute scaled dot product attention.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask)
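# Shape walkthrough (added for clarity, not part of the original module): for a
# query of shape (B, T1, n_feat) and key/value of shape (B, T2, n_feat),
# forward_qkv yields (B, h, T1, d_k) and (B, h, T2, d_k) tensors, the score
# matrix is (B, h, T1, T2), and forward_attention folds the heads back into a
# (B, T1, n_feat) output through linear_out.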
class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, head, time1, time2).
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)
if self.zero_triu:
ones = torch.ones((x.size(2), x.size(3)))
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def forward(self, query, key, value, pos_emb, mask):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time1)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k
) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask)
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding (new implementation).
Details can be found in https://github.com/espnet/espnet/pull/2816.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
"""
def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
self.zero_triu = zero_triu
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x):
"""Compute relative positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)[
:, :, :, : x.size(-1) // 2 + 1
] # only keep the positions from 0 to time2
if self.zero_triu:
ones = torch.ones((x.size(2), x.size(3)), device=x.device)
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
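# Note (added for clarity, not part of the original module): rel_shift applies the
# Transformer-XL style trick -- pad a zero column, view the last two dimensions
# with their sizes swapped, and drop the first row -- so that each query row of
# the (time1, 2*time1-1) position-score matrix is shifted to align with its own
# relative offsets; the trailing slice then keeps x.size(-1) // 2 + 1 columns,
# i.e. the valid key positions (0 to time2), matching the in-line comment above.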
def forward(self, query, key, value, pos_emb, mask):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
pos_emb (torch.Tensor): Positional embedding tensor
(#batch, 2*time1-1, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, 2*time1-1)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k
) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Emiru Tsunoo
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder self-attention layer definition."""
import torch
from torch import nn
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
class ContextualBlockEncoderLayer(nn.Module):
"""Contexutal Block Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
dropout_rate (float): Dropout rate.
total_layer_num (int): Total number of layers
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
size,
self_attn,
feed_forward,
dropout_rate,
total_layer_num,
normalize_before=True,
concat_after=False,
):
"""Construct an EncoderLayer object."""
super(ContextualBlockEncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
self.total_layer_num = total_layer_num
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
def forward(
self,
x,
mask,
infer_mode=False,
past_ctx=None,
next_ctx=None,
is_short_segment=False,
layer_idx=0,
cache=None,
):
"""Calculate forward propagation."""
if self.training or not infer_mode:
return self.forward_train(x, mask, past_ctx, next_ctx, layer_idx, cache)
else:
return self.forward_infer(
x, mask, past_ctx, next_ctx, is_short_segment, layer_idx, cache
)
def forward_train(
self, x, mask, past_ctx=None, next_ctx=None, layer_idx=0, cache=None
):
"""Compute encoded features.
Args:
x (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
past_ctx (torch.Tensor): Previous contextual vector.
next_ctx (torch.Tensor): Next contextual vector.
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, 1, time).
cur_ctx (torch.Tensor): Current contextual vector.
next_ctx (torch.Tensor): Next contextual vector.
layer_idx (int): Layer index number.
"""
nbatch = x.size(0)
nblock = x.size(1)
if past_ctx is not None:
if next_ctx is None:
# store all context vectors in one tensor
next_ctx = past_ctx.new_zeros(
nbatch, nblock, self.total_layer_num, x.size(-1)
)
else:
x[:, :, 0] = past_ctx[:, :, layer_idx]
# reshape ( nbatch, nblock, block_size + 2, dim )
# -> ( nbatch * nblock, block_size + 2, dim )
x = x.view(-1, x.size(-2), x.size(-1))
if mask is not None:
mask = mask.view(-1, mask.size(-2), mask.size(-1))
residual = x
if self.normalize_before:
x = self.norm1(x)
if cache is None:
x_q = x
else:
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
layer_idx += 1
# reshape ( nbatch * nblock, block_size + 2, dim )
# -> ( nbatch, nblock, block_size + 2, dim )
x = x.view(nbatch, -1, x.size(-2), x.size(-1)).squeeze(1)
if mask is not None:
mask = mask.view(nbatch, -1, mask.size(-2), mask.size(-1)).squeeze(1)
if next_ctx is not None and layer_idx < self.total_layer_num:
next_ctx[:, 0, layer_idx, :] = x[:, 0, -1, :]
next_ctx[:, 1:, layer_idx, :] = x[:, 0:-1, -1, :]
return x, mask, False, next_ctx, next_ctx, False, layer_idx
def forward_infer(
self,
x,
mask,
past_ctx=None,
next_ctx=None,
is_short_segment=False,
layer_idx=0,
cache=None,
):
"""Compute encoded features.
Args:
x (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
past_ctx (torch.Tensor): Previous contextual vector.
next_ctx (torch.Tensor): Next contextual vector.
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, 1, time).
cur_ctx (torch.Tensor): Current contextual vector.
next_ctx (torch.Tensor): Next contextual vector.
layer_idx (int): Layer index number.
"""
nbatch = x.size(0)
nblock = x.size(1)
# if layer_idx == 0, next_ctx has to be None
if layer_idx == 0:
assert next_ctx is None
next_ctx = x.new_zeros(nbatch, self.total_layer_num, x.size(-1))
# reshape ( nbatch, nblock, block_size + 2, dim )
# -> ( nbatch * nblock, block_size + 2, dim )
x = x.view(-1, x.size(-2), x.size(-1))
if mask is not None:
mask = mask.view(-1, mask.size(-2), mask.size(-1))
residual = x
if self.normalize_before:
x = self.norm1(x)
if cache is None:
x_q = x
else:
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
# reshape ( nbatch * nblock, block_size + 2, dim )
# -> ( nbatch, nblock, block_size + 2, dim )
x = x.view(nbatch, nblock, x.size(-2), x.size(-1))
if mask is not None:
mask = mask.view(nbatch, nblock, mask.size(-2), mask.size(-1))
# Propagate context information (the last frame of each block)
# to the first frame of the next block
if not is_short_segment:
if past_ctx is None:
# First block of an utterance
x[:, 0, 0, :] = x[:, 0, -1, :]
else:
x[:, 0, 0, :] = past_ctx[:, layer_idx, :]
if nblock > 1:
x[:, 1:, 0, :] = x[:, 0:-1, -1, :]
next_ctx[:, layer_idx, :] = x[:, -1, -1, :]
else:
next_ctx = None
return x, mask, True, past_ctx, next_ctx, is_short_segment, layer_idx + 1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Decoder definition."""
import logging
from typing import Any, List, Tuple
import torch
from espnet.nets.pytorch_backend.nets_utils import rename_state_dict
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer
from espnet.nets.pytorch_backend.transformer.dynamic_conv import DynamicConvolution
from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import DynamicConvolution2D
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.lightconv import LightweightConvolution
from espnet.nets.pytorch_backend.transformer.lightconv2d import LightweightConvolution2D
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
PositionwiseFeedForward,
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.scorer_interface import BatchScorerInterface
def _pre_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
# https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
rename_state_dict(prefix + "output_norm.", prefix + "after_norm.", state_dict)
class Decoder(BatchScorerInterface, torch.nn.Module):
"""Transfomer decoder module.
Args:
odim (int): Output diminsion.
self_attention_layer_type (str): Self-attention layer type.
attention_dim (int): Dimension of attention.
attention_heads (int): The number of heads of multi head attention.
conv_wshare (int): The number of kernel of convolution. Only used in
self_attention_layer_type == "lightconv*" or "dynamiconv*".
conv_kernel_length (Union[int, str]): Kernel size str of convolution
(e.g. 71_71_71_71_71_71). Only used in self_attention_layer_type
== "lightconv*" or "dynamiconv*".
conv_usebias (bool): Whether to use bias in convolution. Only used in
self_attention_layer_type == "lightconv*" or "dynamiconv*".
linear_units (int): The number of units of position-wise feed forward.
num_blocks (int): The number of decoder blocks.
dropout_rate (float): Dropout rate.
positional_dropout_rate (float): Dropout rate after adding positional encoding.
self_attention_dropout_rate (float): Dropout rate in self-attention.
src_attention_dropout_rate (float): Dropout rate in source-attention.
input_layer (Union[str, torch.nn.Module]): Input layer type.
use_output_layer (bool): Whether to use output layer.
pos_enc_class (torch.nn.Module): Positional encoding module class.
`PositionalEncoding `or `ScaledPositionalEncoding`
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
odim,
selfattention_layer_type="selfattn",
attention_dim=256,
attention_heads=4,
conv_wshare=4,
conv_kernel_length=11,
conv_usebias=False,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
self_attention_dropout_rate=0.0,
src_attention_dropout_rate=0.0,
input_layer="embed",
use_output_layer=True,
pos_enc_class=PositionalEncoding,
normalize_before=True,
concat_after=False,
):
"""Construct an Decoder object."""
torch.nn.Module.__init__(self)
self._register_load_state_dict_pre_hook(_pre_hook)
if input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(odim, attention_dim),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(odim, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer, pos_enc_class(attention_dim, positional_dropout_rate)
)
else:
raise NotImplementedError("only `embed` or torch.nn.Module is supported.")
self.normalize_before = normalize_before
# self-attention module definition
if selfattention_layer_type == "selfattn":
logging.info("decoder self-attention layer type = self-attention")
decoder_selfattn_layer = MultiHeadedAttention
decoder_selfattn_layer_args = [
(
attention_heads,
attention_dim,
self_attention_dropout_rate,
)
] * num_blocks
elif selfattention_layer_type == "lightconv":
logging.info("decoder self-attention layer type = lightweight convolution")
decoder_selfattn_layer = LightweightConvolution
decoder_selfattn_layer_args = [
(
conv_wshare,
attention_dim,
self_attention_dropout_rate,
int(conv_kernel_length.split("_")[lnum]),
True,
conv_usebias,
)
for lnum in range(num_blocks)
]
elif selfattention_layer_type == "lightconv2d":
logging.info(
"decoder self-attention layer "
"type = lightweight convolution 2-dimensional"
)
decoder_selfattn_layer = LightweightConvolution2D
decoder_selfattn_layer_args = [
(
conv_wshare,
attention_dim,
self_attention_dropout_rate,
int(conv_kernel_length.split("_")[lnum]),
True,
conv_usebias,
)
for lnum in range(num_blocks)
]
elif selfattention_layer_type == "dynamicconv":
logging.info("decoder self-attention layer type = dynamic convolution")
decoder_selfattn_layer = DynamicConvolution
decoder_selfattn_layer_args = [
(
conv_wshare,
attention_dim,
self_attention_dropout_rate,
int(conv_kernel_length.split("_")[lnum]),
True,
conv_usebias,
)
for lnum in range(num_blocks)
]
elif selfattention_layer_type == "dynamicconv2d":
logging.info(
"decoder self-attention layer type = dynamic convolution 2-dimensional"
)
decoder_selfattn_layer = DynamicConvolution2D
decoder_selfattn_layer_args = [
(
conv_wshare,
attention_dim,
self_attention_dropout_rate,
int(conv_kernel_length.split("_")[lnum]),
True,
conv_usebias,
)
for lnum in range(num_blocks)
]
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
decoder_selfattn_layer(*decoder_selfattn_layer_args[lnum]),
MultiHeadedAttention(
attention_heads, attention_dim, src_attention_dropout_rate
),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
normalize_before,
concat_after,
),
)
self.selfattention_layer_type = selfattention_layer_type
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
if use_output_layer:
self.output_layer = torch.nn.Linear(attention_dim, odim)
else:
self.output_layer = None
def forward(self, tgt, tgt_mask, memory, memory_mask):
"""Forward decoder.
Args:
tgt (torch.Tensor): Input token ids, int64 (#batch, maxlen_out) if
input_layer == "embed". In the other case, input tensor
(#batch, maxlen_out, odim).
tgt_mask (torch.Tensor): Input token mask (#batch, maxlen_out).
dtype=torch.uint8 before PyTorch 1.2, dtype=torch.bool in PyTorch 1.2
and later.
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, feat).
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
dtype=torch.uint8 before PyTorch 1.2, dtype=torch.bool in PyTorch 1.2
and later.
Returns:
torch.Tensor: Decoded token score before softmax (#batch, maxlen_out, odim)
if use_output_layer is True. In the other case, final block outputs
(#batch, maxlen_out, attention_dim).
torch.Tensor: Score mask before softmax (#batch, maxlen_out).
"""
x = self.embed(tgt)
x, tgt_mask, memory, memory_mask = self.decoders(
x, tgt_mask, memory, memory_mask
)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
x = self.output_layer(x)
return x, tgt_mask
def forward_one_step(self, tgt, tgt_mask, memory, cache=None):
"""Forward one step.
Args:
tgt (torch.Tensor): Input token ids, int64 (#batch, maxlen_out).
tgt_mask (torch.Tensor): Input token mask (#batch, maxlen_out).
dtype=torch.uint8 before PyTorch 1.2, dtype=torch.bool in PyTorch 1.2
and later.
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, feat).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).
Returns:
torch.Tensor: Output tensor (batch, maxlen_out, odim).
List[torch.Tensor]: List of cache tensors of each decoder layer.
"""
x = self.embed(tgt)
if cache is None:
cache = [None] * len(self.decoders)
new_cache = []
for c, decoder in zip(cache, self.decoders):
x, tgt_mask, memory, memory_mask = decoder(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(x)
if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.output_layer is not None:
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache
# beam search API (see ScorerInterface)
def score(self, ys, state, x):
"""Score."""
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
if self.selfattention_layer_type != "selfattn":
# TODO(karita): implement cache
logging.warning(
f"{self.selfattention_layer_type} does not support cached decoding."
)
state = None
logp, state = self.forward_one_step(
ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
)
return logp.squeeze(0), state
# batch beam search API (see BatchScorerInterface)
def batch_score(
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch (required).
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
n_layers = len(self.decoders)
if states[0] is None:
batch_state = None
else:
# transpose state of [batch, layer] into [layer, batch]
batch_state = [
torch.stack([states[b][i] for b in range(n_batch)])
for i in range(n_layers)
]
# batch decoding
ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
return logp, state_list
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Decoder self-attention layer definition."""
import torch
from torch import nn
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
class DecoderLayer(nn.Module):
"""Single decoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
src_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
size,
self_attn,
src_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.norm3 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
"""Compute decoded features.
Args:
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
cache (List[torch.Tensor]): List of cached tensors.
Each tensor shape should be (#batch, maxlen_out - 1, size).
Returns:
torch.Tensor: Output tensor(#batch, maxlen_out, size).
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = None
if tgt_mask is not None:
tgt_q_mask = tgt_mask[:, -1:, :]
if self.concat_after:
tgt_concat = torch.cat(
(tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
)
x = residual + self.concat_linear1(tgt_concat)
else:
x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
if self.concat_after:
x_concat = torch.cat(
(x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
)
x = residual + self.concat_linear2(x_concat)
else:
x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
if not self.normalize_before:
x = self.norm2(x)
residual = x
if self.normalize_before:
x = self.norm3(x)
x = residual + self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm3(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, tgt_mask, memory, memory_mask
"""Dynamic Convolution module."""
import numpy
import torch
import torch.nn.functional as F
from torch import nn
MIN_VALUE = float(numpy.finfo(numpy.float32).min)
class DynamicConvolution(nn.Module):
"""Dynamic Convolution layer.
This implementation is based on
https://github.com/pytorch/fairseq/tree/master/fairseq
Args:
wshare (int): the number of kernels of convolution
n_feat (int): the number of features
dropout_rate (float): dropout_rate
kernel_size (int): kernel size (length)
use_kernel_mask (bool): Use causal mask or not for convolution kernel
use_bias (bool): Use bias term or not.
"""
def __init__(
self,
wshare,
n_feat,
dropout_rate,
kernel_size,
use_kernel_mask=False,
use_bias=False,
):
"""Construct Dynamic Convolution layer."""
super(DynamicConvolution, self).__init__()
assert n_feat % wshare == 0
self.wshare = wshare
self.use_kernel_mask = use_kernel_mask
self.dropout_rate = dropout_rate
self.kernel_size = kernel_size
self.attn = None
# linear -> GLU -- -> lightconv -> linear
# \ /
# Linear
self.linear1 = nn.Linear(n_feat, n_feat * 2)
self.linear2 = nn.Linear(n_feat, n_feat)
self.linear_weight = nn.Linear(n_feat, self.wshare * 1 * kernel_size)
nn.init.xavier_uniform_(self.linear_weight.weight)
self.act = nn.GLU()
# dynamic conv related
self.use_bias = use_bias
if self.use_bias:
self.bias = nn.Parameter(torch.Tensor(n_feat))
def forward(self, query, key, value, mask):
"""Forward of 'Dynamic Convolution'.
This function takes query, key and value but uses only query.
This is just for compatibility with self-attention layer (attention.py)
Args:
query (torch.Tensor): (batch, time1, d_model) input tensor
key (torch.Tensor): (batch, time2, d_model) NOT USED
value (torch.Tensor): (batch, time2, d_model) NOT USED
mask (torch.Tensor): (batch, time1, time2) mask
Return:
x (torch.Tensor): (batch, time1, d_model) output
"""
# linear -> GLU -- -> lightconv -> linear
# \ /
# Linear
x = query
B, T, C = x.size()
H = self.wshare
k = self.kernel_size
# first linear layer
x = self.linear1(x)
# GLU activation
x = self.act(x)
# get kernel of convolution
weight = self.linear_weight(x) # B x T x kH
weight = F.dropout(weight, self.dropout_rate, training=self.training)
weight = weight.view(B, T, H, k).transpose(1, 2).contiguous() # B x H x T x k
weight_new = torch.zeros(B * H * T * (T + k - 1), dtype=weight.dtype)
weight_new = weight_new.view(B, H, T, T + k - 1).fill_(float("-inf"))
weight_new = weight_new.to(x.device) # B x H x T x T+k-1
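# Descriptive note (added for clarity): the (-inf)-filled B x H x T x (T+k-1)
# buffer is written through as_strided so that row t holds its k predicted kernel
# weights on a diagonal band; narrowing to T columns then yields a T x T
# convolution matrix that can be applied with a single bmm below.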
weight_new.as_strided(
(B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1)
).copy_(weight)
weight_new = weight_new.narrow(-1, int((k - 1) / 2), T) # B x H x T x T(k)
if self.use_kernel_mask:
kernel_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0)
weight_new = weight_new.masked_fill(kernel_mask == 0.0, float("-inf"))
weight_new = F.softmax(weight_new, dim=-1)
self.attn = weight_new
weight_new = weight_new.view(B * H, T, T)
# convolution
x = x.transpose(1, 2).contiguous() # B x C x T
x = x.view(B * H, int(C / H), T).transpose(1, 2)
x = torch.bmm(weight_new, x) # BH x T x C/H
x = x.transpose(1, 2).contiguous().view(B, C, T)
if self.use_bias:
x = x + self.bias.view(1, -1, 1)
x = x.transpose(1, 2) # B x T x C
if mask is not None and not self.use_kernel_mask:
mask = mask.transpose(-1, -2)
x = x.masked_fill(mask == 0, 0.0)
# second linear layer
x = self.linear2(x)
return x
"""Dynamic 2-Dimensional Convolution module."""
import numpy
import torch
import torch.nn.functional as F
from torch import nn
MIN_VALUE = float(numpy.finfo(numpy.float32).min)
class DynamicConvolution2D(nn.Module):
"""Dynamic 2-Dimensional Convolution layer.
This implementation is based on
https://github.com/pytorch/fairseq/tree/master/fairseq
Args:
wshare (int): the number of kernels of convolution
n_feat (int): the number of features
dropout_rate (float): dropout_rate
kernel_size (int): kernel size (length)
use_kernel_mask (bool): Use causal mask or not for convolution kernel
use_bias (bool): Use bias term or not.
"""
def __init__(
self,
wshare,
n_feat,
dropout_rate,
kernel_size,
use_kernel_mask=False,
use_bias=False,
):
"""Construct Dynamic 2-Dimensional Convolution layer."""
super(DynamicConvolution2D, self).__init__()
assert n_feat % wshare == 0
self.wshare = wshare
self.use_kernel_mask = use_kernel_mask
self.dropout_rate = dropout_rate
self.kernel_size = kernel_size
self.padding_size = int(kernel_size / 2)
self.attn_t = None
self.attn_f = None
# linear -> GLU -- -> lightconv -> linear
# \ /
# Linear
self.linear1 = nn.Linear(n_feat, n_feat * 2)
self.linear2 = nn.Linear(n_feat * 2, n_feat)
self.linear_weight = nn.Linear(n_feat, self.wshare * 1 * kernel_size)
nn.init.xavier_uniform_(self.linear_weight.weight)
self.linear_weight_f = nn.Linear(n_feat, kernel_size)
nn.init.xavier_uniform_(self.linear_weight_f.weight)
self.act = nn.GLU()
# dynamic conv related
self.use_bias = use_bias
if self.use_bias:
self.bias = nn.Parameter(torch.Tensor(n_feat))
def forward(self, query, key, value, mask):
"""Forward of 'Dynamic 2-Dimensional Convolution'.
This function takes query, key and value but uses only query.
This is just for compatibility with self-attention layer (attention.py)
Args:
query (torch.Tensor): (batch, time1, d_model) input tensor
key (torch.Tensor): (batch, time2, d_model) NOT USED
value (torch.Tensor): (batch, time2, d_model) NOT USED
mask (torch.Tensor): (batch, time1, time2) mask
Return:
x (torch.Tensor): (batch, time1, d_model) output
"""
# linear -> GLU -- -> lightconv -> linear
# \ /
# Linear
x = query
B, T, C = x.size()
H = self.wshare
k = self.kernel_size
# first linear layer
x = self.linear1(x)
# GLU activation
x = self.act(x)
# convolution of frequency axis
weight_f = self.linear_weight_f(x).view(B * T, 1, k) # B x T x k
self.attn_f = weight_f.view(B, T, k).unsqueeze(1)
xf = F.conv1d(
x.view(1, B * T, C), weight_f, padding=self.padding_size, groups=B * T
)
xf = xf.view(B, T, C)
# get kernel of convolution
weight = self.linear_weight(x) # B x T x kH
weight = F.dropout(weight, self.dropout_rate, training=self.training)
weight = weight.view(B, T, H, k).transpose(1, 2).contiguous() # B x H x T x k
weight_new = torch.zeros(B * H * T * (T + k - 1), dtype=weight.dtype)
weight_new = weight_new.view(B, H, T, T + k - 1).fill_(float("-inf"))
weight_new = weight_new.to(x.device) # B x H x T x T+k-1
weight_new.as_strided(
(B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1)
).copy_(weight)
weight_new = weight_new.narrow(-1, int((k - 1) / 2), T) # B x H x T x T(k)
if self.use_kernel_mask:
kernel_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0)
weight_new = weight_new.masked_fill(kernel_mask == 0.0, float("-inf"))
weight_new = F.softmax(weight_new, dim=-1)
self.attn_t = weight_new
weight_new = weight_new.view(B * H, T, T)
# convolution
x = x.transpose(1, 2).contiguous() # B x C x T
x = x.view(B * H, int(C / H), T).transpose(1, 2)
x = torch.bmm(weight_new, x)
x = x.transpose(1, 2).contiguous().view(B, C, T)
if self.use_bias:
x = x + self.bias.view(1, -1, 1)
x = x.transpose(1, 2) # B x T x C
x = torch.cat((x, xf), -1) # B x T x Cx2
if mask is not None and not self.use_kernel_mask:
mask = mask.transpose(-1, -2)
x = x.masked_fill(mask == 0, 0.0)
# second linear layer
x = self.linear2(x)
return x
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Positional Encoding Module."""
import math
import torch
def _pre_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""Perform pre-hook in load_state_dict for backward compatibility.
Note:
We used to save self.pe up to v0.5.2, but it is no longer stored in later versions.
Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
"""
k = prefix + "pe"
if k in state_dict:
state_dict.pop(k)
class PositionalEncoding(torch.nn.Module):
"""Positional encoding.
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
reverse (bool): Whether to reverse the input position. Only for
the class LegacyRelPositionalEncoding. We remove it in the current
class RelPositionalEncoding.
"""
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
"""Construct an PositionalEncoding object."""
super(PositionalEncoding, self).__init__()
self.d_model = d_model
self.reverse = reverse
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
self._register_load_state_dict_pre_hook(_pre_hook)
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1)]
return self.dropout(x)
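# Note (added for clarity, not part of the original module): extend_pe builds the
# standard sinusoidal table PE(pos, 2i) = sin(pos / 10000^(2i / d_model)) and
# PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)); forward scales the input by
# sqrt(d_model) before adding it, so the embedding and the encoding have
# comparable magnitude.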
class ScaledPositionalEncoding(PositionalEncoding):
"""Scaled positional encoding module.
See Sec. 3.2 https://arxiv.org/abs/1809.08895
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Initialize class."""
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
def reset_parameters(self):
"""Reset parameters."""
self.alpha.data = torch.tensor(1.0)
def forward(self, x):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x)
x = x + self.alpha * self.pe[:, : x.size(1)]
return self.dropout(x)
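
# Illustrative usage sketch (not part of the original module): the scaled
# variant adds alpha * pe instead of scaling the input, and alpha is a
# trainable parameter. Sizes below are assumptions for demonstration only.
def _example_scaled_positional_encoding():
    """Verify that the learnable scale alpha receives a gradient."""
    pos_enc = ScaledPositionalEncoding(d_model=256, dropout_rate=0.0)
    x = torch.randn(2, 50, 256)  # (batch, time, d_model)
    pos_enc(x).sum().backward()
    assert pos_enc.alpha.grad is not None
    return pos_enc.alpha
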
class LearnableFourierPosEnc(torch.nn.Module):
"""Learnable Fourier Features for Positional Encoding.
See https://arxiv.org/pdf/2106.02795.pdf
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
        gamma (float): Init parameter for the positional kernel variance;
            see https://arxiv.org/pdf/2106.02795.pdf.
        apply_scaling (bool): Whether to scale the input before adding the pos encoding.
        hidden_dim (int): If not None, we modulate the pos encodings with
            an MLP whose hidden layer has hidden_dim neurons.
"""
def __init__(
self,
d_model,
dropout_rate=0.0,
max_len=5000,
gamma=1.0,
apply_scaling=False,
hidden_dim=None,
):
"""Initialize class."""
super(LearnableFourierPosEnc, self).__init__()
self.d_model = d_model
if apply_scaling:
self.xscale = math.sqrt(self.d_model)
else:
self.xscale = 1.0
self.dropout = torch.nn.Dropout(dropout_rate)
self.max_len = max_len
self.gamma = gamma
if self.gamma is None:
self.gamma = self.d_model // 2
assert (
d_model % 2 == 0
), "d_model should be divisible by two in order to use this layer."
self.w_r = torch.nn.Parameter(torch.empty(1, d_model // 2))
self._reset() # init the weights
self.hidden_dim = hidden_dim
if self.hidden_dim is not None:
self.mlp = torch.nn.Sequential(
torch.nn.Linear(d_model, hidden_dim),
torch.nn.GELU(),
torch.nn.Linear(hidden_dim, d_model),
)
def _reset(self):
self.w_r.data = torch.normal(
0, (1 / math.sqrt(self.gamma)), (1, self.d_model // 2)
)
def extend_pe(self, x):
"""Reset the positional encodings."""
position_v = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1).to(x)
cosine = torch.cos(torch.matmul(position_v, self.w_r))
sine = torch.sin(torch.matmul(position_v, self.w_r))
pos_enc = torch.cat((cosine, sine), -1)
pos_enc /= math.sqrt(self.d_model)
if self.hidden_dim is None:
return pos_enc.unsqueeze(0)
else:
return self.mlp(pos_enc.unsqueeze(0))
def forward(self, x: torch.Tensor):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
pe = self.extend_pe(x)
x = x * self.xscale + pe
return self.dropout(x)
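
# Illustrative usage sketch (not part of the original module): the learnable
# Fourier features produce a (1, time, d_model) encoding added to the input.
# The hidden_dim=128 MLP below is an assumption for demonstration only.
def _example_learnable_fourier_pos_enc():
    """Run LearnableFourierPosEnc on a dummy batch."""
    pos_enc = LearnableFourierPosEnc(d_model=256, dropout_rate=0.1, hidden_dim=128)
    x = torch.randn(2, 50, 256)  # (batch, time, d_model)
    y = pos_enc(x)  # x (optionally scaled) + MLP(Fourier features of positions)
    assert y.shape == (2, 50, 256)
    return y
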
class LegacyRelPositionalEncoding(PositionalEncoding):
"""Relative positional encoding module (old version).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Initialize class."""
super().__init__(
d_model=d_model,
dropout_rate=dropout_rate,
max_len=max_len,
reverse=True,
)
def forward(self, x):
"""Compute positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Positional embedding tensor (1, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[:, : x.size(1)]
return self.dropout(x), self.dropout(pos_emb)
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module (new implementation).
Details can be found in https://github.com/espnet/espnet/pull/2816.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Construct an PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x):
"""Reset the positional encodings."""
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.dtype != x.dtype or self.pe.device != x.device:
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
        # Suppose `i` is the position of the query vector and `j` is the
        # position of the key vector. We use positive relative positions when
        # the key is to the left (i > j) and negative relative positions
        # otherwise (i < j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x)
x = x * self.xscale
pos_emb = self.pe[
:,
self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
]
return self.dropout(x), self.dropout(pos_emb)
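
# Illustrative usage sketch (not part of the original module): unlike the
# absolute variants above, this module returns the scaled input together with
# a relative positional embedding of length 2 * time - 1, as consumed by
# relative position-aware self-attention. Shapes are assumptions only.
def _example_rel_positional_encoding():
    """Run RelPositionalEncoding on a dummy batch."""
    pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
    x = torch.randn(2, 50, 256)  # (batch, time, d_model)
    x_scaled, pos_emb = pos_enc(x)
    assert x_scaled.shape == (2, 50, 256)
    assert pos_emb.shape == (1, 2 * 50 - 1, 256)
    return x_scaled, pos_emb
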
class StreamPositionalEncoding(torch.nn.Module):
"""Streaming Positional encoding.
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def __init__(self, d_model, dropout_rate, max_len=5000):
"""Construct an PositionalEncoding object."""
super(StreamPositionalEncoding, self).__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = torch.nn.Dropout(p=dropout_rate)
self.pe = None
self.tmp = torch.tensor(0.0).expand(1, max_len)
self.extend_pe(self.tmp.size(1), self.tmp.device, self.tmp.dtype)
self._register_load_state_dict_pre_hook(_pre_hook)
def extend_pe(self, length, device, dtype):
"""Reset the positional encodings."""
if self.pe is not None:
if self.pe.size(1) >= length:
if self.pe.dtype != dtype or self.pe.device != device:
self.pe = self.pe.to(dtype=dtype, device=device)
return
pe = torch.zeros(length, self.d_model)
position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.pe = pe.to(device=device, dtype=dtype)
def forward(self, x: torch.Tensor, start_idx: int = 0):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
"""
self.extend_pe(x.size(1) + start_idx, x.device, x.dtype)
x = x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
return self.dropout(x)
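
# Illustrative usage sketch (not part of the original module): the streaming
# variant accepts start_idx so encodings of consecutive chunks line up with
# the full-utterance result. Chunk sizes are assumptions for demonstration;
# eval() disables dropout so the comparison is exact.
def _example_stream_positional_encoding():
    """Check that chunk-wise encoding matches full-sequence encoding."""
    pos_enc = StreamPositionalEncoding(d_model=256, dropout_rate=0.1).eval()
    x = torch.randn(1, 40, 256)  # (batch, time, d_model)
    full = pos_enc(x)
    chunked = torch.cat(
        [pos_enc(x[:, :20], start_idx=0), pos_enc(x[:, 20:], start_idx=20)], dim=1
    )
    assert torch.allclose(full, chunked)
    return full
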
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder definition."""
import logging
import torch
from espnet.nets.pytorch_backend.nets_utils import rename_state_dict
from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.dynamic_conv import DynamicConvolution
from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import DynamicConvolution2D
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
from espnet.nets.pytorch_backend.transformer.lightconv import LightweightConvolution
from espnet.nets.pytorch_backend.transformer.lightconv2d import LightweightConvolution2D
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import (
Conv1dLinear,
MultiLayeredConv1d,
)
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
PositionwiseFeedForward,
)
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
Conv2dSubsampling,
Conv2dSubsampling6,
Conv2dSubsampling8,
)
def _pre_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
# https://github.com/espnet/espnet/commit/21d70286c354c66c0350e65dc098d2ee236faccc#diff-bffb1396f038b317b2b64dd96e6d3563
rename_state_dict(prefix + "input_layer.", prefix + "embed.", state_dict)
# https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
rename_state_dict(prefix + "norm.", prefix + "after_norm.", state_dict)
class Encoder(torch.nn.Module):
"""Transformer encoder module.
Args:
idim (int): Input dimension.
attention_dim (int): Dimension of attention.
attention_heads (int): The number of heads of multi head attention.
        conv_wshare (int): The number of convolution kernels (weight-sharing
            groups). Only used when selfattention_layer_type is "lightconv*"
            or "dynamicconv*".
        conv_kernel_length (Union[int, str]): Kernel sizes of the convolutions
            as an underscore-separated string (e.g. "71_71_71_71_71_71").
            Only used when selfattention_layer_type is "lightconv*" or
            "dynamicconv*".
        conv_usebias (bool): Whether to use bias in convolution. Only used
            when selfattention_layer_type is "lightconv*" or "dynamicconv*".
linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of encoder blocks.
dropout_rate (float): Dropout rate.
positional_dropout_rate (float): Dropout rate after adding positional encoding.
attention_dropout_rate (float): Dropout rate in attention.
input_layer (Union[str, torch.nn.Module]): Input layer type.
pos_enc_class (torch.nn.Module): Positional encoding module class.
            `PositionalEncoding` or `ScaledPositionalEncoding`.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
selfattention_layer_type (str): Encoder attention layer type.
padding_idx (int): Padding idx for input_layer=embed.
stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
        intermediate_layers (Union[List[int], None]): Indices of intermediate CTC
            layers (1-based). If not None, intermediate outputs are also returned,
            which changes the return signature of forward().
"""
def __init__(
self,
idim,
attention_dim=256,
attention_heads=4,
conv_wshare=4,
conv_kernel_length="11",
conv_usebias=False,
linear_units=2048,
num_blocks=6,
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer="conv2d",
pos_enc_class=PositionalEncoding,
normalize_before=True,
concat_after=False,
positionwise_layer_type="linear",
positionwise_conv_kernel_size=1,
selfattention_layer_type="selfattn",
padding_idx=-1,
stochastic_depth_rate=0.0,
intermediate_layers=None,
ctc_softmax=None,
conditioning_layer_dim=None,
):
"""Construct an Encoder object."""
super(Encoder, self).__init__()
self._register_load_state_dict_pre_hook(_pre_hook)
self.conv_subsampling_factor = 1
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(idim, attention_dim),
torch.nn.LayerNorm(attention_dim),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
self.conv_subsampling_factor = 4
elif input_layer == "conv2d-scaled-pos-enc":
self.embed = Conv2dSubsampling(
idim,
attention_dim,
dropout_rate,
pos_enc_class(attention_dim, positional_dropout_rate),
)
self.conv_subsampling_factor = 4
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(idim, attention_dim, dropout_rate)
self.conv_subsampling_factor = 6
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(idim, attention_dim, dropout_rate)
self.conv_subsampling_factor = 8
elif input_layer == "vgg2l":
self.embed = VGG2L(idim, attention_dim)
self.conv_subsampling_factor = 4
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(attention_dim, positional_dropout_rate),
)
elif input_layer is None:
self.embed = torch.nn.Sequential(
pos_enc_class(attention_dim, positional_dropout_rate)
)
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
positionwise_layer, positionwise_layer_args = self.get_positionwise_layer(
positionwise_layer_type,
attention_dim,
linear_units,
dropout_rate,
positionwise_conv_kernel_size,
)
if selfattention_layer_type in [
"selfattn",
"rel_selfattn",
"legacy_rel_selfattn",
]:
logging.info("encoder self-attention layer type = self-attention")
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = [
(
attention_heads,
attention_dim,
attention_dropout_rate,
)
] * num_blocks
elif selfattention_layer_type == "lightconv":
logging.info("encoder self-attention layer type = lightweight convolution")
encoder_selfattn_layer = LightweightConvolution
encoder_selfattn_layer_args = [
(
conv_wshare,
attention_dim,
attention_dropout_rate,
int(conv_kernel_length.split("_")[lnum]),
False,
conv_usebias,
)
for lnum in range(num_blocks)
]
elif selfattention_layer_type == "lightconv2d":
logging.info(
"encoder self-attention layer "
"type = lightweight convolution 2-dimensional"
)
encoder_selfattn_layer = LightweightConvolution2D
encoder_selfattn_layer_args = [
(
conv_wshare,
attention_dim,
attention_dropout_rate,
int(conv_kernel_length.split("_")[lnum]),
False,
conv_usebias,
)
for lnum in range(num_blocks)
]
elif selfattention_layer_type == "dynamicconv":
logging.info("encoder self-attention layer type = dynamic convolution")
encoder_selfattn_layer = DynamicConvolution
encoder_selfattn_layer_args = [
(
conv_wshare,
attention_dim,
attention_dropout_rate,
int(conv_kernel_length.split("_")[lnum]),
False,
conv_usebias,
)
for lnum in range(num_blocks)
]
elif selfattention_layer_type == "dynamicconv2d":
logging.info(
"encoder self-attention layer type = dynamic convolution 2-dimensional"
)
encoder_selfattn_layer = DynamicConvolution2D
encoder_selfattn_layer_args = [
(
conv_wshare,
attention_dim,
attention_dropout_rate,
int(conv_kernel_length.split("_")[lnum]),
False,
conv_usebias,
)
for lnum in range(num_blocks)
]
else:
raise NotImplementedError(selfattention_layer_type)
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
attention_dim,
encoder_selfattn_layer(*encoder_selfattn_layer_args[lnum]),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
stochastic_depth_rate * float(1 + lnum) / num_blocks,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(attention_dim)
self.intermediate_layers = intermediate_layers
        self.use_conditioning = ctc_softmax is not None
if self.use_conditioning:
self.ctc_softmax = ctc_softmax
self.conditioning_layer = torch.nn.Linear(
conditioning_layer_dim, attention_dim
)
def get_positionwise_layer(
self,
positionwise_layer_type="linear",
attention_dim=256,
linear_units=2048,
dropout_rate=0.1,
positionwise_conv_kernel_size=1,
):
"""Define positionwise layer."""
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
attention_dim,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
attention_dim,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
            raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.")
return positionwise_layer, positionwise_layer_args
def forward(self, xs, masks):
"""Encode input sequence.
Args:
xs (torch.Tensor): Input tensor (#batch, time, idim).
masks (torch.Tensor): Mask tensor (#batch, 1, time).
Returns:
torch.Tensor: Output tensor (#batch, time, attention_dim).
torch.Tensor: Mask tensor (#batch, 1, time).
"""
if isinstance(
self.embed,
(Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L),
):
xs, masks = self.embed(xs, masks)
else:
xs = self.embed(xs)
if self.intermediate_layers is None:
xs, masks = self.encoders(xs, masks)
else:
intermediate_outputs = []
for layer_idx, encoder_layer in enumerate(self.encoders):
xs, masks = encoder_layer(xs, masks)
if (
self.intermediate_layers is not None
and layer_idx + 1 in self.intermediate_layers
):
encoder_output = xs
# intermediate branches also require normalization.
if self.normalize_before:
encoder_output = self.after_norm(encoder_output)
intermediate_outputs.append(encoder_output)
if self.use_conditioning:
intermediate_result = self.ctc_softmax(encoder_output)
xs = xs + self.conditioning_layer(intermediate_result)
if self.normalize_before:
xs = self.after_norm(xs)
if self.intermediate_layers is not None:
return xs, masks, intermediate_outputs
return xs, masks
def forward_one_step(self, xs, masks, cache=None):
"""Encode input frame.
Args:
xs (torch.Tensor): Input tensor.
masks (torch.Tensor): Mask tensor.
cache (List[torch.Tensor]): List of cache tensors.
Returns:
torch.Tensor: Output tensor.
torch.Tensor: Mask tensor.
List[torch.Tensor]: List of new cache tensors.
"""
if isinstance(self.embed, Conv2dSubsampling):
xs, masks = self.embed(xs, masks)
else:
xs = self.embed(xs)
if cache is None:
cache = [None for _ in range(len(self.encoders))]
new_cache = []
for c, e in zip(cache, self.encoders):
xs, masks = e(xs, masks, cache=c)
new_cache.append(xs)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks, new_cache
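
# Illustrative usage sketch (not part of the original module): building a small
# Transformer encoder and running a padded batch through it. The mask below is
# built by hand; in ESPnet recipes it is usually derived from the input lengths
# (e.g. via make_non_pad_mask). All sizes are assumptions for demonstration.
def _example_encoder():
    """Encode a dummy batch of 83-dim features with conv2d subsampling."""
    encoder = Encoder(idim=83, attention_dim=256, attention_heads=4, num_blocks=6)
    xs = torch.randn(2, 100, 83)  # (batch, time, idim)
    ilens = torch.tensor([100, 80])
    masks = (torch.arange(100)[None, :] < ilens[:, None]).unsqueeze(1)  # (B, 1, T)
    hs, hs_masks = encoder(xs, masks)
    # the conv2d input layer subsamples the time axis by a factor of about 4
    assert hs.size(0) == 2 and hs.size(2) == 256
    return hs, hs_masks
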
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder self-attention layer definition."""
import torch
from torch import nn
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
class EncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
During training, the layer may skip residual computation and return input
as-is with given probability.
"""
def __init__(
self,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
def forward(self, x, mask, cache=None):
"""Compute encoded features.
Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, 1, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if cache is None:
x_q = x
else:
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(x_q, x, x, mask)
)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
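
# Illustrative usage sketch (not part of the original module): wiring a single
# EncoderLayer with multi-headed self-attention and a position-wise
# feed-forward network, with stochastic depth enabled. The imports mirror the
# ones used by the full encoder; all sizes are assumptions for demonstration.
def _example_encoder_layer():
    """Run one encoder layer on a dummy batch."""
    from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
    from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
        PositionwiseFeedForward,
    )

    layer = EncoderLayer(
        size=256,
        self_attn=MultiHeadedAttention(4, 256, 0.0),
        feed_forward=PositionwiseFeedForward(256, 1024, 0.1),
        dropout_rate=0.1,
        stochastic_depth_rate=0.1,  # ~10% of training steps skip this layer
    )
    x = torch.randn(2, 50, 256)  # (batch, time, size)
    mask = torch.ones(2, 1, 50, dtype=torch.bool)
    y, y_mask = layer(x, mask)
    assert y.shape == x.shape
    return y, y_mask
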
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder Mix definition."""
import torch
from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
class EncoderMix(Encoder, torch.nn.Module):
"""Transformer encoder module.
:param int idim: input dim
:param int attention_dim: dimension of attention
:param int attention_heads: the number of heads of multi head attention
:param int linear_units: the number of units of position-wise feed forward
    :param int num_blocks_sd: the number of speaker-differentiating encoder blocks
    :param int num_blocks_rec: the number of shared recognition encoder blocks
:param float dropout_rate: dropout rate
:param float attention_dropout_rate: dropout rate in attention
:param float positional_dropout_rate: dropout rate after adding positional encoding
:param str or torch.nn.Module input_layer: input layer type
:param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
:param bool normalize_before: whether to use layer_norm before the first block
:param bool concat_after: whether to concat attention layer's input and output
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
    :param str positionwise_layer_type: linear or conv1d
:param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
:param int padding_idx: padding_idx for input_layer=embed
"""
def __init__(
self,
idim,
attention_dim=256,
attention_heads=4,
linear_units=2048,
num_blocks_sd=4,
num_blocks_rec=8,
dropout_rate=0.1,
positional_dropout_rate=0.1,
attention_dropout_rate=0.0,
input_layer="conv2d",
pos_enc_class=PositionalEncoding,
normalize_before=True,
concat_after=False,
positionwise_layer_type="linear",
positionwise_conv_kernel_size=1,
padding_idx=-1,
num_spkrs=2,
):
"""Construct an Encoder object."""
super(EncoderMix, self).__init__(
idim=idim,
selfattention_layer_type="selfattn",
attention_dim=attention_dim,
attention_heads=attention_heads,
linear_units=linear_units,
num_blocks=num_blocks_rec,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
input_layer=input_layer,
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
concat_after=concat_after,
positionwise_layer_type=positionwise_layer_type,
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
padding_idx=padding_idx,
)
positionwise_layer, positionwise_layer_args = self.get_positionwise_layer(
positionwise_layer_type,
attention_dim,
linear_units,
dropout_rate,
positionwise_conv_kernel_size,
)
self.num_spkrs = num_spkrs
self.encoders_sd = torch.nn.ModuleList(
[
repeat(
num_blocks_sd,
lambda lnum: EncoderLayer(
attention_dim,
MultiHeadedAttention(
attention_heads, attention_dim, attention_dropout_rate
),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
for i in range(num_spkrs)
]
)
def forward(self, xs, masks):
"""Encode input sequence.
:param torch.Tensor xs: input tensor
:param torch.Tensor masks: input mask
:return: position embedded tensor and mask
:rtype Tuple[torch.Tensor, torch.Tensor]:
"""
if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
xs, masks = self.embed(xs, masks)
else:
xs = self.embed(xs)
xs_sd, masks_sd = [None] * self.num_spkrs, [None] * self.num_spkrs
for ns in range(self.num_spkrs):
xs_sd[ns], masks_sd[ns] = self.encoders_sd[ns](xs, masks)
xs_sd[ns], masks_sd[ns] = self.encoders(xs_sd[ns], masks_sd[ns]) # Enc_rec
if self.normalize_before:
xs_sd[ns] = self.after_norm(xs_sd[ns])
return xs_sd, masks_sd
def forward_one_step(self, xs, masks, cache=None):
"""Encode input frame.
:param torch.Tensor xs: input tensor
:param torch.Tensor masks: input mask
:param List[torch.Tensor] cache: cache tensors
:return: position embedded tensor, mask and new cache
:rtype Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""
if isinstance(self.embed, Conv2dSubsampling):
xs, masks = self.embed(xs, masks)
else:
xs = self.embed(xs)
new_cache_sd = []
for ns in range(self.num_spkrs):
            if cache is None:
                cache = [
                    None
                    for _ in range(len(self.encoders_sd[ns]) + len(self.encoders))
                ]
            new_cache = []
            # speaker-differentiating (SD) encoder stack for speaker `ns`
            for c, e in zip(cache[: len(self.encoders_sd[ns])], self.encoders_sd[ns]):
                xs, masks = e(xs, masks, cache=c)
                new_cache.append(xs)
            # shared recognition encoder stack (self.encoders)
            for c, e in zip(cache[len(self.encoders_sd[ns]) :], self.encoders):
                xs, masks = e(xs, masks, cache=c)
                new_cache.append(xs)
new_cache_sd.append(new_cache)
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks, new_cache_sd
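
# Illustrative usage sketch (not part of the original module): EncoderMix runs
# one speaker-differentiating stack per speaker followed by the shared
# recognition stack, returning one encoded sequence per speaker. Sizes are
# assumptions for demonstration only.
def _example_encoder_mix():
    """Encode a dummy two-speaker mixture."""
    encoder = EncoderMix(idim=83, num_blocks_sd=2, num_blocks_rec=2, num_spkrs=2)
    xs = torch.randn(2, 100, 83)  # (batch, time, idim)
    masks = torch.ones(2, 1, 100, dtype=torch.bool)
    xs_sd, masks_sd = encoder(xs, masks)
    assert len(xs_sd) == 2 and xs_sd[0].size(-1) == 256
    return xs_sd, masks_sd
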
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Parameter initialization."""
import torch
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
def initialize(model, init_type="pytorch"):
"""Initialize Transformer module.
:param torch.nn.Module model: transformer instance
:param str init_type: initialization type
"""
if init_type == "pytorch":
return
# weight init
for p in model.parameters():
if p.dim() > 1:
if init_type == "xavier_uniform":
torch.nn.init.xavier_uniform_(p.data)
elif init_type == "xavier_normal":
torch.nn.init.xavier_normal_(p.data)
elif init_type == "kaiming_uniform":
torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
elif init_type == "kaiming_normal":
torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
else:
raise ValueError("Unknown initialization: " + init_type)
# bias init
for p in model.parameters():
if p.dim() == 1:
p.data.zero_()
# reset some modules with default init
for m in model.modules():
if isinstance(m, (torch.nn.Embedding, LayerNorm)):
m.reset_parameters()
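
# Illustrative usage sketch (not part of the original module): applying the
# initializer above to a freshly constructed module. The toy model here is an
# assumption for demonstration; in practice the argument is a full transformer
# model and init_type comes from the training configuration.
def _example_initialize():
    """Initialize a toy model with Xavier-uniform weights and zero biases."""
    model = torch.nn.Sequential(
        torch.nn.Linear(80, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 256),
    )
    initialize(model, init_type="xavier_uniform")
    # biases (1-dim parameters) are zeroed by the loop above
    assert all((p == 0).all() for p in model.parameters() if p.dim() == 1)
    return model
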
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Label smoothing module."""
import torch
from torch import nn
class LabelSmoothingLoss(nn.Module):
"""Label-smoothing loss.
:param int size: the number of class
:param int padding_idx: ignored class id
:param float smoothing: smoothing rate (0.0 means the conventional CE)
:param bool normalize_length: normalize loss by sequence length if True
:param torch.nn.Module criterion: loss function to be smoothed
"""
def __init__(
self,
size,
padding_idx,
smoothing,
normalize_length=False,
criterion=nn.KLDivLoss(reduction="none"),
):
"""Construct an LabelSmoothingLoss object."""
super(LabelSmoothingLoss, self).__init__()
self.criterion = criterion
self.padding_idx = padding_idx
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.size = size
self.true_dist = None
self.normalize_length = normalize_length
def forward(self, x, target):
"""Compute loss between x and target.
:param torch.Tensor x: prediction (batch, seqlen, class)
:param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
:return: scalar float value
:rtype torch.Tensor
"""
assert x.size(2) == self.size
batch_size = x.size(0)
x = x.view(-1, self.size)
target = target.view(-1)
with torch.no_grad():
true_dist = x.clone()
true_dist.fill_(self.smoothing / (self.size - 1))
ignore = target == self.padding_idx # (B,)
total = len(target) - ignore.sum().item()
target = target.masked_fill(ignore, 0) # avoid -1 index
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
denom = total if self.normalize_length else batch_size
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
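
# Illustrative usage sketch (not part of the original module): computing the
# smoothed cross-entropy for a dummy batch. Vocabulary size, padding index,
# and smoothing rate below are assumptions for demonstration only.
def _example_label_smoothing_loss():
    """Compute LabelSmoothingLoss on random logits and padded targets."""
    criterion = LabelSmoothingLoss(size=500, padding_idx=-1, smoothing=0.1)
    logits = torch.randn(2, 7, 500)  # (batch, seqlen, class)
    target = torch.randint(0, 500, (2, 7))
    target[1, 5:] = -1  # padded positions are ignored in the loss
    loss = criterion(logits, target)
    assert loss.dim() == 0  # scalar, normalized by batch size by default
    return loss
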
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Layer normalization module."""
import torch
class LayerNorm(torch.nn.LayerNorm):
"""Layer normalization module.
Args:
nout (int): Output dim size.
dim (int): Dimension to be normalized.
"""
def __init__(self, nout, dim=-1):
"""Construct an LayerNorm object."""
super(LayerNorm, self).__init__(nout, eps=1e-12)
self.dim = dim
def forward(self, x):
"""Apply layer normalization.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Normalized tensor.
"""
if self.dim == -1:
return super(LayerNorm, self).forward(x)
return (
super(LayerNorm, self)
.forward(x.transpose(self.dim, -1))
.transpose(self.dim, -1)
)
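
# Illustrative usage sketch (not part of the original module): normalizing over
# a non-trailing dimension, e.g. the channel axis of a (batch, channel, time)
# tensor, by passing dim=1. Shapes are assumptions for demonstration only.
def _example_layer_norm():
    """Normalize the channel dimension of a (B, C, T) tensor."""
    norm = LayerNorm(nout=256, dim=1)
    x = torch.randn(2, 256, 100)  # (batch, channel, time)
    y = norm(x)
    assert y.shape == x.shape
    # the normalized (channel) axis has approximately zero mean
    assert torch.allclose(y.mean(dim=1), torch.zeros(2, 100), atol=1e-5)
    return y
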
"""Lightweight Convolution Module."""
import numpy
import torch
import torch.nn.functional as F
from torch import nn
MIN_VALUE = float(numpy.finfo(numpy.float32).min)
class LightweightConvolution(nn.Module):
"""Lightweight Convolution layer.
This implementation is based on
https://github.com/pytorch/fairseq/tree/master/fairseq
Args:
        wshare (int): the number of convolution kernels (weight-sharing groups)
n_feat (int): the number of features
dropout_rate (float): dropout_rate
kernel_size (int): kernel size (length)
use_kernel_mask (bool): Use causal mask or not for convolution kernel
use_bias (bool): Use bias term or not.
"""
def __init__(
self,
wshare,
n_feat,
dropout_rate,
kernel_size,
use_kernel_mask=False,
use_bias=False,
):
"""Construct Lightweight Convolution layer."""
super(LightweightConvolution, self).__init__()
assert n_feat % wshare == 0
self.wshare = wshare
self.use_kernel_mask = use_kernel_mask
self.dropout_rate = dropout_rate
self.kernel_size = kernel_size
self.padding_size = int(kernel_size / 2)
# linear -> GLU -> lightconv -> linear
self.linear1 = nn.Linear(n_feat, n_feat * 2)
self.linear2 = nn.Linear(n_feat, n_feat)
self.act = nn.GLU()
# lightconv related
self.weight = nn.Parameter(
torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1)
)
self.use_bias = use_bias
if self.use_bias:
self.bias = nn.Parameter(torch.Tensor(n_feat))
# mask of kernel
kernel_mask0 = torch.zeros(self.wshare, int(kernel_size / 2))
kernel_mask1 = torch.ones(self.wshare, int(kernel_size / 2 + 1))
self.kernel_mask = torch.cat((kernel_mask1, kernel_mask0), dim=-1).unsqueeze(1)
def forward(self, query, key, value, mask):
"""Forward of 'Lightweight Convolution'.
This function takes query, key and value but uses only query.
This is just for compatibility with self-attention layer (attention.py)
Args:
query (torch.Tensor): (batch, time1, d_model) input tensor
key (torch.Tensor): (batch, time2, d_model) NOT USED
value (torch.Tensor): (batch, time2, d_model) NOT USED
mask (torch.Tensor): (batch, time1, time2) mask
Return:
x (torch.Tensor): (batch, time1, d_model) output
"""
# linear -> GLU -> lightconv -> linear
x = query
B, T, C = x.size()
H = self.wshare
        # first linear layer
x = self.linear1(x)
# GLU activation
x = self.act(x)
# lightconv
x = x.transpose(1, 2).contiguous().view(-1, H, T) # B x C x T
weight = F.dropout(self.weight, self.dropout_rate, training=self.training)
if self.use_kernel_mask:
self.kernel_mask = self.kernel_mask.to(x.device)
weight = weight.masked_fill(self.kernel_mask == 0.0, float("-inf"))
weight = F.softmax(weight, dim=-1)
x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view(
B, C, T
)
if self.use_bias:
x = x + self.bias.view(1, -1, 1)
x = x.transpose(1, 2) # B x T x C
if mask is not None and not self.use_kernel_mask:
mask = mask.transpose(-1, -2)
x = x.masked_fill(mask == 0, 0.0)
# second linear layer
x = self.linear2(x)
return x
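
# Illustrative usage sketch (not part of the original module): the lightweight
# convolution is called with the (query, key, value, mask) signature of
# self-attention but only uses the query; the mask has the (batch, 1, time)
# shape passed by the Transformer encoder. Sizes are assumptions only.
def _example_lightweight_convolution():
    """Run LightweightConvolution as a drop-in self-attention replacement."""
    conv = LightweightConvolution(
        wshare=4, n_feat=256, dropout_rate=0.1, kernel_size=11
    )
    x = torch.randn(2, 50, 256)  # (batch, time, n_feat)
    mask = torch.ones(2, 1, 50, dtype=torch.bool)
    y = conv(x, x, x, mask)
    assert y.shape == (2, 50, 256)
    return y
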
"""Lightweight 2-Dimensional Convolution module."""
import numpy
import torch
import torch.nn.functional as F
from torch import nn
MIN_VALUE = float(numpy.finfo(numpy.float32).min)
class LightweightConvolution2D(nn.Module):
"""Lightweight 2-Dimensional Convolution layer.
This implementation is based on
https://github.com/pytorch/fairseq/tree/master/fairseq
Args:
        wshare (int): the number of convolution kernels (weight-sharing groups)
n_feat (int): the number of features
dropout_rate (float): dropout_rate
kernel_size (int): kernel size (length)
use_kernel_mask (bool): Use causal mask or not for convolution kernel
use_bias (bool): Use bias term or not.
"""
def __init__(
self,
wshare,
n_feat,
dropout_rate,
kernel_size,
use_kernel_mask=False,
use_bias=False,
):
"""Construct Lightweight 2-Dimensional Convolution layer."""
super(LightweightConvolution2D, self).__init__()
assert n_feat % wshare == 0
self.wshare = wshare
self.use_kernel_mask = use_kernel_mask
self.dropout_rate = dropout_rate
self.kernel_size = kernel_size
self.padding_size = int(kernel_size / 2)
# linear -> GLU -> lightconv -> linear
self.linear1 = nn.Linear(n_feat, n_feat * 2)
self.linear2 = nn.Linear(n_feat * 2, n_feat)
self.act = nn.GLU()
# lightconv related
self.weight = nn.Parameter(
torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1)
)
self.weight_f = nn.Parameter(torch.Tensor(1, 1, kernel_size).uniform_(0, 1))
self.use_bias = use_bias
if self.use_bias:
self.bias = nn.Parameter(torch.Tensor(n_feat))
# mask of kernel
kernel_mask0 = torch.zeros(self.wshare, int(kernel_size / 2))
kernel_mask1 = torch.ones(self.wshare, int(kernel_size / 2 + 1))
self.kernel_mask = torch.cat((kernel_mask1, kernel_mask0), dim=-1).unsqueeze(1)
def forward(self, query, key, value, mask):
"""Forward of 'Lightweight 2-Dimensional Convolution'.
This function takes query, key and value but uses only query.
This is just for compatibility with self-attention layer (attention.py)
Args:
query (torch.Tensor): (batch, time1, d_model) input tensor
key (torch.Tensor): (batch, time2, d_model) NOT USED
value (torch.Tensor): (batch, time2, d_model) NOT USED
mask (torch.Tensor): (batch, time1, time2) mask
Return:
x (torch.Tensor): (batch, time1, d_model) output
"""
# linear -> GLU -> lightconv -> linear
x = query
B, T, C = x.size()
H = self.wshare
        # first linear layer
x = self.linear1(x)
# GLU activation
x = self.act(x)
# convolution along frequency axis
weight_f = F.softmax(self.weight_f, dim=-1)
weight_f = F.dropout(weight_f, self.dropout_rate, training=self.training)
weight_new = torch.zeros(
B * T, 1, self.kernel_size, device=x.device, dtype=x.dtype
).copy_(weight_f)
xf = F.conv1d(
x.view(1, B * T, C), weight_new, padding=self.padding_size, groups=B * T
).view(B, T, C)
# lightconv
x = x.transpose(1, 2).contiguous().view(-1, H, T) # B x C x T
weight = F.dropout(self.weight, self.dropout_rate, training=self.training)
if self.use_kernel_mask:
self.kernel_mask = self.kernel_mask.to(x.device)
weight = weight.masked_fill(self.kernel_mask == 0.0, float("-inf"))
weight = F.softmax(weight, dim=-1)
x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view(
B, C, T
)
if self.use_bias:
x = x + self.bias.view(1, -1, 1)
x = x.transpose(1, 2) # B x T x C
x = torch.cat((x, xf), -1) # B x T x Cx2
if mask is not None and not self.use_kernel_mask:
mask = mask.transpose(-1, -2)
x = x.masked_fill(mask == 0, 0.0)
# second linear layer
x = self.linear2(x)
return x