Commit a7785cc6 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

delete soft link

parent 9a2a05ca
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
# 2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""
import math
from typing import Tuple, Optional
import torch
from torch import nn
import torch.nn.functional as F
from wenet.transformer.attention import MultiHeadedAttention
class GroupedRelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper:
https://arxiv.org/abs/1901.02860
https://arxiv.org/abs/2109.01163
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self, n_head, n_feat, dropout_rate, group_size=3):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
self.group_size = group_size
self.d_k = n_feat // n_head # for GroupedAttention
self.n_feat = n_feat
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k * self.group_size))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k * self.group_size))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
def rel_shift(self, x, zero_triu: bool = False):
"""Compute relative positinal encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, size).
zero_triu (bool): If true, return the lower triangular part of
the matrix.
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
device=x.device,
dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(x.size()[0],
x.size()[1],
x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)
if zero_triu:
ones = torch.ones((x.size(2), x.size(3)))
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def pad4group(self, Q, K, V, P, mask, group_size: int = 3):
"""
q: (#batch, time1, size) -> (#batch, head, time1, size/head)
k,v: (#batch, time2, size) -> (#batch, head, time2, size/head)
p: (#batch, time2, size)
"""
# Compute Overflows
overflow_Q = Q.size(2) % group_size
overflow_KV = K.size(2) % group_size
padding_Q = (group_size - overflow_Q) * int(
overflow_Q // (overflow_Q + 0.00000000000000001))
padding_KV = (group_size - overflow_KV) * int(
overflow_KV // (overflow_KV + 0.00000000000000001))
batch_size, _, seq_len_KV, _ = K.size()
# Input Padding (B, T, D) -> (B, T + P, D)
Q = F.pad(Q, (0, 0, 0, padding_Q), value=0.0)
K = F.pad(K, (0, 0, 0, padding_KV), value=0.0)
V = F.pad(V, (0, 0, 0, padding_KV), value=0.0)
if mask is not None and mask.size(2) > 0 : # time2 > 0:
mask = mask[:, ::group_size, ::group_size]
Q = Q.transpose(1, 2).contiguous().view(
batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
K = K.transpose(1, 2).contiguous().view(
batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
V = V.transpose(1, 2).contiguous().view(
batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
# process pos_emb
P_batch_size = P.size(0)
overflow_P = P.size(1) % group_size
padding_P = group_size - overflow_P if overflow_P else 0
P = F.pad(P, (0, 0, 0, padding_P), value=0.0)
P = P.view(P_batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
return Q, K, V, P, mask, padding_Q
def forward_attention(
self, value: torch.Tensor, scores: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
padding_q: Optional[int] = None
) -> torch.Tensor:
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value, size
(#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score, size
(#batch, n_head, time1, time2).
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
padding_q : for GroupedAttention in efficent conformer
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
# 1st chunk to ease the onnx export.]
# 2. pytorch training
if mask.size(2) > 0 : # time2 > 0
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
# For last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0) # (batch, head, time1, time2)
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
# n_feat!=h*d_k may be happened in GroupAttention
x = (x.transpose(1, 2).contiguous().view(n_batch, -1, self.n_feat)
) # (batch, time1, d_model)
if padding_q is not None:
# for GroupedAttention in efficent conformer
x = x[:, :x.size(1) - padding_q]
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2).
pos_emb (torch.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q = self.linear_q(query)
k = self.linear_k(key) # (#batch, time2, size)
v = self.linear_v(value)
p = self.linear_pos(pos_emb) # (#batch, time2, size)
batch_size, seq_len_KV, _ = k.size() # seq_len_KV = time2
# (#batch, time2, size) -> (#batch, head, time2, size/head)
q = q.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
k = k.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
v = v.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
if cache.size(0) > 0:
# use attention cache
key_cache, value_cache = torch.split(
cache, cache.size(-1) // 2, dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
new_cache = torch.cat((k, v), dim=-1)
# May be k and p does not match. eg. time2=18+18/2=27 > mask=36/2=18
if mask is not None and mask.size(2) > 0:
time2 = mask.size(2)
k = k[:, :, -time2:, :]
v = v[:, :, -time2:, :]
# q k v p: (batch, head, time1, d_k)
q, k, v, p, mask, padding_q = self.pad4group(q, k, v, p, mask, self.group_size)
# q_with_bias_u & q_with_bias_v = (batch, head, time1, d_k)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
# matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k * self.group_size) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask, padding_q), new_cache
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
# 2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""
from typing import Tuple
import torch
from torch import nn
from typeguard import check_argument_types
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model."""
def __init__(self,
channels: int,
kernel_size: int = 15,
activation: nn.Module = nn.ReLU(),
norm: str = "batch_norm",
causal: bool = False,
bias: bool = True,
stride: int = 1):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
causal (int): Whether use causal convolution or not
stride (int): Stride Convolution, for efficient Conformer
"""
assert check_argument_types()
super().__init__()
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0: it's a causal convolution, the input will be
# padded with self.lorder frames on the left in forward.
# else: it's a symmetrical convolution
if causal:
padding = 0
self.lorder = kernel_size - 1
else:
# kernel_size should be an odd number for none causal convolution
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.lorder = 0
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=stride, # for depthwise_conv in StrideConv
padding=padding,
groups=channels,
bias=bias,
)
assert norm in ['batch_norm', 'layer_norm']
if norm == "batch_norm":
self.use_layer_norm = False
self.norm = nn.BatchNorm1d(channels)
else:
self.use_layer_norm = True
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = activation
self.stride = stride
def forward(
self,
x: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
cache: torch.Tensor = torch.zeros((0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
(0, 0, 0) means fake mask.
cache (torch.Tensor): left context cache, it is only
used in causal convolution (#batch, channels, cache_t),
(0, 0, 0) meas fake cache.
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2) # (#batch, channels, time)
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
if self.lorder > 0:
if cache.size(2) == 0: # cache_t == 0
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
else:
# When export ONNX,the first cache is not None but all-zero,
# cause shape error in residual block,
# eg. cache14 + x9 = 23, 23-7+1=17 != 9
cache = cache[:, :, -self.lorder:]
assert cache.size(0) == x.size(0) # equal batch
assert cache.size(1) == x.size(1) # equal channel
x = torch.cat((cache, x), dim=2)
assert (x.size(2) > self.lorder)
new_cache = x[:, :, -self.lorder:]
else:
# It's better we just return None if no cache is requried,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.activation(self.norm(x))
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.pointwise_conv2(x)
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
if mask_pad.size(2) != x.size(2):
mask_pad = mask_pad[:, :, ::self.stride]
x.masked_fill_(~mask_pad, 0.0)
return x.transpose(1, 2), new_cache
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
# 2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from EfficientConformer(https://github.com/burchim/EfficientConformer)
# Paper(https://arxiv.org/abs/2109.01163)
"""Encoder definition."""
from typing import Tuple, Optional, List, Union
import torch
import logging
from typeguard import check_argument_types
import torch.nn.functional as F
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.embedding import NoPositionalEncoding
from wenet.transformer.subsampling import Conv2dSubsampling4
from wenet.transformer.subsampling import Conv2dSubsampling6
from wenet.transformer.subsampling import Conv2dSubsampling8
from wenet.transformer.subsampling import LinearNoSubsampling
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.attention import RelPositionMultiHeadedAttention
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.efficient_conformer.subsampling import Conv2dSubsampling2
from wenet.efficient_conformer.convolution import ConvolutionModule
from wenet.efficient_conformer.attention import GroupedRelPositionMultiHeadedAttention
from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer
from wenet.utils.common import get_activation
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
class EfficientConformerEncoder(torch.nn.Module):
"""Conformer encoder module."""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = "conv2d",
pos_enc_layer_type: str = "rel_pos",
normalize_before: bool = True,
concat_after: bool = False,
static_chunk_size: int = 0,
use_dynamic_chunk: bool = False,
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
macaron_style: bool = True,
activation_type: str = "swish",
use_cnn_module: bool = True,
cnn_module_kernel: int = 15,
causal: bool = False,
cnn_module_norm: str = "batch_norm",
stride_layer_idx: Optional[Union[int, List[int]]] = 3,
stride: Optional[Union[int, List[int]]] = 2,
group_layer_idx: Optional[Union[int, List[int], tuple]] = (0, 1, 2, 3),
group_size: int = 3,
stride_kernel: bool = True,
**kwargs
):
"""Construct Efficient Conformer Encoder
Args:
input_size to use_dynamic_chunk, see in BaseEncoder
macaron_style (bool): Whether to use macaron style for
positionwise layer.
activation_type (str): Encoder activation function type.
use_cnn_module (bool): Whether to use convolution module.
cnn_module_kernel (int): Kernel size of convolution module.
causal (bool): whether to use causal convolution or not.
stride_layer_idx (list): layer id with StrideConv, start from 0
stride (list): stride size of each StrideConv in efficient conformer
group_layer_idx (list): layer id with GroupedAttention, start from 0
group_size (int): group size of every GroupedAttention layer
stride_kernel (bool): default True. True: recompute cnn kernels with stride.
"""
assert check_argument_types()
super().__init__()
self._output_size = output_size
if pos_enc_layer_type == "abs_pos":
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "no_pos":
pos_enc_class = NoPositionalEncoding
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
if input_layer == "linear":
subsampling_class = LinearNoSubsampling
elif input_layer == "conv2d2":
subsampling_class = Conv2dSubsampling2
elif input_layer == "conv2d":
subsampling_class = Conv2dSubsampling4
elif input_layer == "conv2d6":
subsampling_class = Conv2dSubsampling6
elif input_layer == "conv2d8":
subsampling_class = Conv2dSubsampling8
else:
raise ValueError("unknown input_layer: " + input_layer)
logging.info(f"input_layer = {input_layer}, "
f"subsampling_class = {subsampling_class}")
self.global_cmvn = global_cmvn
self.embed = subsampling_class(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
self.input_layer = input_layer
self.normalize_before = normalize_before
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
activation = get_activation(activation_type)
self.num_blocks = num_blocks
self.attention_heads = attention_heads
self.cnn_module_kernel = cnn_module_kernel
self.global_chunk_size = 0
# efficient conformer configs
self.stride_layer_idx = [stride_layer_idx] \
if type(stride_layer_idx) == int else stride_layer_idx
self.stride = [stride] \
if type(stride) == int else stride
self.group_layer_idx = [group_layer_idx] \
if type(group_layer_idx) == int else group_layer_idx
self.grouped_size = group_size # group size of every GroupedAttention layer
assert len(self.stride) == len(self.stride_layer_idx)
self.cnn_module_kernels = [cnn_module_kernel] # kernel size of each StridedConv
for i in self.stride:
if stride_kernel:
self.cnn_module_kernels.append(self.cnn_module_kernels[-1] // i)
else:
self.cnn_module_kernels.append(self.cnn_module_kernels[-1])
logging.info(f"stride_layer_idx= {self.stride_layer_idx}, "
f"stride = {self.stride}, "
f"cnn_module_kernel = {self.cnn_module_kernels}, "
f"group_layer_idx = {self.group_layer_idx}, "
f"grouped_size = {self.grouped_size}")
# feed-forward module definition
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
activation,
)
# convolution module definition
convolution_layer = ConvolutionModule
# encoder definition
index = 0
layers = []
for i in range(num_blocks):
# self-attention module definition
if i in self.group_layer_idx:
encoder_selfattn_layer = GroupedRelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
self.grouped_size)
else:
if pos_enc_layer_type == "no_pos":
encoder_selfattn_layer = MultiHeadedAttention
else:
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate)
# conformer module definition
if i in self.stride_layer_idx:
# conformer block with downsampling
convolution_layer_args_stride = (
output_size, self.cnn_module_kernels[index], activation,
cnn_module_norm, causal, True, self.stride[index])
layers.append(StrideConformerEncoderLayer(
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(
*positionwise_layer_args) if macaron_style else None,
convolution_layer(
*convolution_layer_args_stride) if use_cnn_module else None,
torch.nn.AvgPool1d(
kernel_size=self.stride[index], stride=self.stride[index],
padding=0, ceil_mode=True,
count_include_pad=False), # pointwise_conv_layer
dropout_rate,
normalize_before,
concat_after,
))
index = index + 1
else:
# conformer block
convolution_layer_args_normal = (
output_size, self.cnn_module_kernels[index], activation,
cnn_module_norm, causal)
layers.append(ConformerEncoderLayer(
output_size,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(
*positionwise_layer_args) if macaron_style else None,
convolution_layer(
*convolution_layer_args_normal) if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
))
self.encoders = torch.nn.ModuleList(layers)
def set_global_chunk_size(self, chunk_size):
"""Used in ONNX export.
"""
logging.info(f"set global chunk size: {chunk_size}, default is 0.")
self.global_chunk_size = chunk_size
def output_size(self) -> int:
return self._output_size
def calculate_downsampling_factor(self, i: int) -> int:
factor = 1
for idx, stride_idx in enumerate(self.stride_layer_idx):
if i > stride_idx:
factor *= self.stride[idx]
return factor
def forward(self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Embed positions in tensor.
Args:
xs: padded input tensor (B, T, D)
xs_lens: input length (B)
decoding_chunk_size: decoding chunk size for dynamic chunk
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
encoder output tensor xs, and subsampled masks
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
masks: torch.Tensor batch padding mask after subsample
(B, 1, T' ~= T/subsample_rate)
"""
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks)
index = 0 # traverse stride
for i, layer in enumerate(self.encoders):
# layer return : x, mask, new_att_cache, new_cnn_cache
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if i in self.stride_layer_idx:
masks = masks[:, :, ::self.stride[index]]
chunk_masks = chunk_masks[:, ::self.stride[index],
::self.stride[index]]
mask_pad = masks
pos_emb = pos_emb[:, ::self.stride[index], :]
index = index + 1
if self.normalize_before:
xs = self.after_norm(xs)
# Here we assume the mask is not changed in encoder layers, so just
# return the masks before encoder layers, and the masks will be used
# for cross attention with decoder later
return xs, masks
def forward_chunk(
self,
xs: torch.Tensor,
offset: int,
required_cache_size: int,
att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" Forward just one chunk
Args:
xs (torch.Tensor): chunk input
offset (int): current offset in encoder output time stamp
required_cache_size (int): cache size required for next chunk
compuation
>=0: actual cache size
<0: means all history cache is required
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`.
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
(elayers, b=1, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`
att_mask : mask matrix of self attention
Returns:
torch.Tensor: output of current input xs
torch.Tensor: subsampling cache required for next chunk computation
List[torch.Tensor]: encoder layers output cache required for next
chunk computation
List[torch.Tensor]: conformer cnn cache
"""
assert xs.size(0) == 1
# using downsampling factor to recover offset
offset *= self.calculate_downsampling_factor(self.num_blocks + 1)
# tmp_masks is just for interface compatibility
tmp_masks = torch.ones(1,
xs.size(1),
device=xs.device,
dtype=torch.bool)
tmp_masks = tmp_masks.unsqueeze(1) # (1, 1, xs-time)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
chunk_size = xs.size(1)
attention_key_size = cache_t1 + chunk_size
pos_emb = self.embed.position_encoding(
offset=offset - cache_t1, size=attention_key_size)
# NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
# shape(pos_emb) = (b=1, chunk_size, emb_size=output_size=hidden-dim)
if required_cache_size < 0:
next_cache_start = 0
elif required_cache_size == 0:
next_cache_start = attention_key_size
else:
next_cache_start = max(attention_key_size - required_cache_size, 0)
# for ONNX export, padding xs to chunk_size
if self.global_chunk_size > 0:
real_len = xs.size(1)
xs = F.pad(xs, (0, 0, 0, self.global_chunk_size - real_len), value=0.0)
tmp_zeros = torch.zeros(att_mask.shape, dtype=torch.bool)
att_mask[:, :, required_cache_size + real_len + 1:] = \
tmp_zeros[:, :, required_cache_size + real_len + 1:]
r_att_cache = []
r_cnn_cache = []
mask_pad = torch.ones(1,
xs.size(1),
device=xs.device,
dtype=torch.bool)
mask_pad = mask_pad.unsqueeze(1) # batchPad (b=1, 1, time=chunk_size)
max_att_len, max_cnn_len = 0, 0 # for repeat_interleave of new_att_cache
for i, layer in enumerate(self.encoders):
factor = self.calculate_downsampling_factor(i)
# NOTE(xcsong): Before layer.forward
# shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
# shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
# shape(new_att_cache) = [ batch, head, time2, outdim//head * 2 ]
xs, _, new_att_cache, new_cnn_cache = layer(
xs, att_mask, pos_emb,
mask_pad=mask_pad,
att_cache=att_cache[i:i + 1, :, ::factor, :],
cnn_cache=cnn_cache[i, :, :, :]
if cnn_cache.size(0) > 0 else cnn_cache
)
if i in self.stride_layer_idx:
# compute time dimension for next block
efficient_index = self.stride_layer_idx.index(i)
att_mask = att_mask[:, ::self.stride[efficient_index],
::self.stride[efficient_index]]
mask_pad = mask_pad[:, ::self.stride[efficient_index],
::self.stride[efficient_index]]
pos_emb = pos_emb[:, ::self.stride[efficient_index], :]
# shape(new_att_cache) = [batch, head, time2, outdim]
new_att_cache = new_att_cache[:, :, next_cache_start // factor:, :]
# shape(new_cnn_cache) = [1, batch, outdim, cache_t2]
new_cnn_cache = new_cnn_cache.unsqueeze(0)
# use repeat_interleave to new_att_cache
new_att_cache = new_att_cache.repeat_interleave(repeats=factor, dim=2)
# padding new_cnn_cache to cnn.lorder for casual convolution
new_cnn_cache = F.pad(
new_cnn_cache,
(self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0))
if i == 0:
# record length for the first block as max length
max_att_len = new_att_cache.size(2)
max_cnn_len = new_cnn_cache.size(3)
# update real shape of att_cache and cnn_cache
r_att_cache.append(new_att_cache[:, :, -max_att_len:, :])
r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:])
if self.normalize_before:
xs = self.after_norm(xs)
# NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
# ? may be larger than cache_t1, it depends on required_cache_size
r_att_cache = torch.cat(r_att_cache, dim=0)
# NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
return xs, r_att_cache, r_cnn_cache
def forward_chunk_by_chunk(
self,
xs: torch.Tensor,
decoding_chunk_size: int,
num_decoding_left_chunks: int = -1,
use_onnx=False
) -> Tuple[torch.Tensor, torch.Tensor]:
""" Forward input chunk by chunk with chunk_size like a streaming
fashion
Here we should pay special attention to computation cache in the
streaming style forward chunk by chunk. Three things should be taken
into account for computation in the current network:
1. transformer/conformer encoder layers output cache
2. convolution in conformer
3. convolution in subsampling
However, we don't implement subsampling cache for:
1. We can control subsampling module to output the right result by
overlapping input instead of cache left context, even though it
wastes some computation, but subsampling only takes a very
small fraction of computation in the whole model.
2. Typically, there are several covolution layers with subsampling
in subsampling module, it is tricky and complicated to do cache
with different convolution layers with different subsampling
rate.
3. Currently, nn.Sequential is used to stack all the convolution
layers in subsampling, we need to rewrite it to make it work
with cache, which is not prefered.
Args:
xs (torch.Tensor): (1, max_len, dim)
decoding_chunk_size (int): decoding chunk size
num_decoding_left_chunks (int):
use_onnx (bool): True for simulating ONNX model inference.
"""
assert decoding_chunk_size > 0
# The model is trained by static or dynamic chunk
assert self.static_chunk_size > 0 or self.use_dynamic_chunk
subsampling = self.embed.subsampling_rate
context = self.embed.right_context + 1 # Add current frame
stride = subsampling * decoding_chunk_size
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.size(1)
outputs = []
offset = 0
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
if use_onnx:
logging.info("Simulating for ONNX runtime ...")
att_cache: torch.Tensor = torch.zeros(
(self.num_blocks, self.attention_heads, required_cache_size,
self.output_size() // self.attention_heads * 2),
device=xs.device)
cnn_cache: torch.Tensor = torch.zeros(
(self.num_blocks, 1, self.output_size(), self.cnn_module_kernel - 1),
device=xs.device)
self.set_global_chunk_size(chunk_size=18)
else:
logging.info("Simulating for JIT runtime ...")
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
# Feed forward overlap input step by step
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
logging.info(f"-->> frame chunk msg: cur={cur}, "
f"end={end}, num_frames={end-cur}, "
f"decoding_window={decoding_window}")
if use_onnx:
att_mask: torch.Tensor = torch.ones(
(1, 1, required_cache_size + decoding_chunk_size),
dtype=torch.bool, device=xs.device)
if cur == 0:
att_mask[:, :, :required_cache_size] = 0
else:
att_mask: torch.Tensor = torch.ones(
(0, 0, 0), dtype=torch.bool, device=xs.device)
chunk_xs = xs[:, cur:end, :]
(y, att_cache, cnn_cache) = \
self.forward_chunk(
chunk_xs, offset, required_cache_size,
att_cache, cnn_cache, att_mask)
outputs.append(y)
offset += y.size(1)
ys = torch.cat(outputs, 1)
masks = torch.ones(1, 1, ys.size(1), device=ys.device, dtype=torch.bool)
return ys, masks
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
# 2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder self-attention layer definition."""
from typing import Optional, Tuple
import torch
from torch import nn
class StrideConformerEncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
instance.
`PositionwiseFeedForward` instance can be used as the argument.
conv_module (torch.nn.Module): Convolution module instance.
`ConvlutionModule` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
concat_after (bool): Whether to concat attention layer's input and
output.
True: x -> x + linear(concat(x, att(x)))
False: x -> x + att(x)
"""
def __init__(
self,
size: int,
self_attn: torch.nn.Module,
feed_forward: Optional[nn.Module] = None,
feed_forward_macaron: Optional[nn.Module] = None,
conv_module: Optional[nn.Module] = None,
pointwise_conv_layer: Optional[nn.Module] = None,
dropout_rate: float = 0.1,
normalize_before: bool = True,
concat_after: bool = False,
):
"""Construct an EncoderLayer object."""
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.pointwise_conv_layer = pointwise_conv_layer
self.norm_ff = nn.LayerNorm(size, eps=1e-5) # for the FNN module
self.norm_mha = nn.LayerNorm(size, eps=1e-5) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = nn.LayerNorm(size,
eps=1e-5) # for the CNN module
self.norm_final = nn.LayerNorm(
size, eps=1e-5) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
self.concat_linear = nn.Linear(size + size, size)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute encoded features.
Args:
x (torch.Tensor): (#batch, time, size)
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
(0, 0, 0) means fake mask.
pos_emb (torch.Tensor): positional encoding, must not be None
for ConformerEncoderLayer.
mask_pad (torch.Tensor): batch padding mask used for conv module.
(#batch, 1,time), (0, 0, 0) means fake mask.
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in conformer layer
(#batch=1, size, cache_t2)
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time, time).
torch.Tensor: att_cache tensor,
(#batch=1, head, cache_t1 + time, d_k * 2).
torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
"""
# whether to use macaron style
if self.feed_forward_macaron is not None:
residual = x
if self.normalize_before:
x = self.norm_ff_macaron(x)
x = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(x))
if not self.normalize_before:
x = self.norm_ff_macaron(x)
# multi-headed self-attention module
residual = x
if self.normalize_before:
x = self.norm_mha(x)
x_att, new_att_cache = self.self_attn(
x, x, x, mask, pos_emb, att_cache)
if self.concat_after:
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm_mha(x)
# convolution module
# Fake new cnn cache here, and then change it in conv_module
new_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)
if self.conv_module is not None:
residual = x
if self.normalize_before:
x = self.norm_conv(x)
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
# add pointwise_conv for efficient conformer
# pointwise_conv_layer does not change shape
if self.pointwise_conv_layer is not None:
residual = residual.transpose(1, 2)
residual = self.pointwise_conv_layer(residual)
residual = residual.transpose(1, 2)
assert residual.size(0) == x.size(0)
assert residual.size(1) == x.size(1)
assert residual.size(2) == x.size(2)
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.norm_conv(x)
# feed forward module
residual = x
if self.normalize_before:
x = self.norm_ff(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm_ff(x)
if self.conv_module is not None:
x = self.norm_final(x)
return x, mask, new_att_cache, new_cnn_cache
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""
from typing import Tuple, Union
import torch
from wenet.transformer.subsampling import BaseSubsampling
class Conv2dSubsampling2(BaseSubsampling):
"""Convolutional 2D subsampling (to 1/4 length).
Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
"""
def __init__(self, idim: int, odim: int, dropout_rate: float,
pos_enc_class: torch.nn.Module):
"""Construct an Conv2dSubsampling4 object."""
super().__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 2),
torch.nn.ReLU()
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * ((idim - 1) // 2), odim))
self.pos_enc = pos_enc_class
# The right context for every conv layer is computed by:
# (kernel_size - 1) * frame_rate_of_this_layer
self.subsampling_rate = 2
# 2 = (3 - 1) * 1
self.right_context = 2
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
offset: Union[int, torch.Tensor] = 0
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, time', odim),
where time' = time // 2.
torch.Tensor: Subsampled mask (#batch, 1, time'),
where time' = time // 2.
torch.Tensor: positional encoding
"""
x = x.unsqueeze(1) # (b, c=1, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask[:, :, :-2:2]
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
# 2022 Ximalaya Inc. (Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""
import math
import torch
import torch.nn as nn
from wenet.transformer.attention import MultiHeadedAttention
from typing import Tuple
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def __init__(self, n_head, n_feat, dropout_rate,
do_rel_shift=False, adaptive_scale=False, init_weights=False):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate)
# linear transformation for positional encoding
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self.do_rel_shift = do_rel_shift
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
torch.nn.init.xavier_uniform_(self.pos_bias_u)
torch.nn.init.xavier_uniform_(self.pos_bias_v)
self.adaptive_scale = adaptive_scale
self.ada_scale = nn.Parameter(
torch.ones([1, 1, n_feat]), requires_grad=adaptive_scale)
self.ada_bias = nn.Parameter(
torch.zeros([1, 1, n_feat]), requires_grad=adaptive_scale)
if init_weights:
self.init_weights()
def init_weights(self):
input_max = (self.h * self.d_k) ** -0.5
torch.nn.init.uniform_(self.linear_q.weight, -input_max, input_max)
torch.nn.init.uniform_(self.linear_q.bias, -input_max, input_max)
torch.nn.init.uniform_(self.linear_k.weight, -input_max, input_max)
torch.nn.init.uniform_(self.linear_k.bias, -input_max, input_max)
torch.nn.init.uniform_(self.linear_v.weight, -input_max, input_max)
torch.nn.init.uniform_(self.linear_v.bias, -input_max, input_max)
torch.nn.init.uniform_(self.linear_pos.weight, -input_max, input_max)
torch.nn.init.uniform_(self.linear_out.weight, -input_max, input_max)
torch.nn.init.uniform_(self.linear_out.bias, -input_max, input_max)
def rel_shift(self, x, zero_triu: bool = False):
"""Compute relative positinal encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, size).
zero_triu (bool): If true, return the lower triangular part of
the matrix.
Returns:
torch.Tensor: Output tensor.
"""
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
device=x.device,
dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(x.size()[0],
x.size()[1],
x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)
if zero_triu:
ones = torch.ones((x.size(2), x.size(3)))
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
return x
def forward_attention(
self, value: torch.Tensor, scores: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
) -> torch.Tensor:
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value, size
(#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score, size
(#batch, n_head, time1, time2).
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
# 1st chunk to ease the onnx export.]
# 2. pytorch training
if mask.size(2) > 0: # time2 > 0
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
# For last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
# (batch, head, time1, time2)
attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
self.h * self.d_k)
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(self, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (torch.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
if self.adaptive_scale:
query = self.ada_scale * query + self.ada_bias
key = self.ada_scale * key + self.ada_bias
value = self.ada_scale * value + self.ada_bias
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
# NOTE(xcsong):
# when export onnx model, for 1st chunk, we feed
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
# and we will always do splitting and
# concatnation(this will simplify onnx export). Note that
# it's OK to concat & split zero-shaped tensors(see code below).
# when export jit model, for 1st chunk, we always feed
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
# >>> a = torch.ones((1, 2, 0, 4))
# >>> b = torch.ones((1, 2, 3, 4))
# >>> c = torch.cat((a, b), dim=2)
# >>> torch.equal(b, c) # True
# >>> d = torch.split(a, 2, dim=-1)
# >>> torch.equal(d[0], d[1]) # True
if cache.size(0) > 0:
key_cache, value_cache = torch.split(
cache, cache.size(-1) // 2, dim=-1)
k = torch.cat([key_cache, k], dim=2)
v = torch.cat([value_cache, v], dim=2)
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)
n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
# Remove rel_shift since it is useless in speech recognition,
# and it requires special attention for streaming.
if self.do_rel_shift:
matrix_bd = self.rel_shift(matrix_bd)
scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)
return self.forward_attention(v, scores, mask), new_cache
# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conv2d Module with Valid Padding"""
import torch.nn.functional as F
from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional
class Conv2dValid(_ConvNd):
"""
Conv2d operator for VALID mode padding.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 0,
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros', # TODO: refine this type
device=None,
dtype=None,
valid_trigx: bool = False,
valid_trigy: bool = False
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
kernel_size_ = _pair(kernel_size)
stride_ = _pair(stride)
padding_ = padding if isinstance(padding, str) else _pair(padding)
dilation_ = _pair(dilation)
super(Conv2dValid, self).__init__(
in_channels, out_channels, kernel_size_,
stride_, padding_, dilation_, False, _pair(0),
groups, bias, padding_mode, **factory_kwargs)
self.valid_trigx = valid_trigx
self.valid_trigy = valid_trigy
def _conv_forward(
self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
validx, validy = 0, 0
if self.valid_trigx:
validx = (input.size(-2) * (self.stride[-2] - 1) - 1
+ self.kernel_size[-2]) // 2
if self.valid_trigy:
validy = (input.size(-1) * (self.stride[-1] - 1) - 1
+ self.kernel_size[-1]) // 2
return F.conv2d(input, weight, bias, self.stride,
(validx, validy), self.dilation, self.groups)
def forward(self, input: Tensor) -> Tensor:
return self._conv_forward(input, self.weight, self.bias)
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
# 2022 Ximalaya Inc. (authors: Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""
from typing import Tuple
import torch
from torch import nn
from typeguard import check_argument_types
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model."""
def __init__(self,
channels: int,
kernel_size: int = 15,
activation: nn.Module = nn.ReLU(),
norm: str = "batch_norm",
causal: bool = False,
bias: bool = True,
adaptive_scale: bool = False,
init_weights: bool = False
):
"""Construct an ConvolutionModule object.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
causal (int): Whether use causal convolution or not
"""
assert check_argument_types()
super().__init__()
self.bias = bias
self.channels = channels
self.kernel_size = kernel_size
self.adaptive_scale = adaptive_scale
self.ada_scale = torch.nn.Parameter(
torch.ones([1, 1, channels]), requires_grad=adaptive_scale)
self.ada_bias = torch.nn.Parameter(
torch.zeros([1, 1, channels]), requires_grad=adaptive_scale)
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
# self.lorder is used to distinguish if it's a causal convolution,
# if self.lorder > 0: it's a causal convolution, the input will be
# padded with self.lorder frames on the left in forward.
# else: it's a symmetrical convolution
if causal:
padding = 0
self.lorder = kernel_size - 1
else:
# kernel_size should be an odd number for none causal convolution
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.lorder = 0
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=padding,
groups=channels,
bias=bias,
)
assert norm in ['batch_norm', 'layer_norm']
if norm == "batch_norm":
self.use_layer_norm = False
self.norm = nn.BatchNorm1d(channels)
else:
self.use_layer_norm = True
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = activation
if init_weights:
self.init_weights()
def init_weights(self):
pw_max = self.channels ** -0.5
dw_max = self.kernel_size ** -0.5
torch.nn.init.uniform_(self.pointwise_conv1.weight.data, -pw_max, pw_max)
if self.bias:
torch.nn.init.uniform_(self.pointwise_conv1.bias.data, -pw_max, pw_max)
torch.nn.init.uniform_(self.depthwise_conv.weight.data, -dw_max, dw_max)
if self.bias:
torch.nn.init.uniform_(self.depthwise_conv.bias.data, -dw_max, dw_max)
torch.nn.init.uniform_(self.pointwise_conv2.weight.data, -pw_max, pw_max)
if self.bias:
torch.nn.init.uniform_(self.pointwise_conv2.bias.data, -pw_max, pw_max)
def forward(
self,
x: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
cache: torch.Tensor = torch.zeros((0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
(0, 0, 0) means fake mask.
cache (torch.Tensor): left context cache, it is only
used in causal convolution (#batch, channels, cache_t),
(0, 0, 0) meas fake cache.
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
"""
if self.adaptive_scale:
x = self.ada_scale * x + self.ada_bias
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2) # (#batch, channels, time)
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
if self.lorder > 0:
if cache.size(2) == 0: # cache_t == 0
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
else:
assert cache.size(0) == x.size(0) # equal batch
assert cache.size(1) == x.size(1) # equal channel
x = torch.cat((cache, x), dim=2)
assert (x.size(2) > self.lorder)
new_cache = x[:, :, -self.lorder:]
else:
# It's better we just return None if no cache is required,
# However, for JIT export, here we just fake one tensor instead of
# None.
new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.activation(self.norm(x))
if self.use_layer_norm:
x = x.transpose(1, 2)
x = self.pointwise_conv2(x)
# mask batch padding
if mask_pad.size(2) > 0: # time > 0
x.masked_fill_(~mask_pad, 0.0)
return x.transpose(1, 2), new_cache
# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from Squeezeformer(https://github.com/kssteven418/Squeezeformer)
# Squeezeformer(https://github.com/upskyy/Squeezeformer)
# NeMo(https://github.com/NVIDIA/NeMo)
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional, List
from wenet.squeezeformer.subsampling \
import DepthwiseConv2dSubsampling4, TimeReductionLayer1D, \
TimeReductionLayer2D, TimeReductionLayerStream
from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.attention import MultiHeadedAttention
from wenet.squeezeformer.attention import RelPositionMultiHeadedAttention
from wenet.squeezeformer.positionwise_feed_forward \
import PositionwiseFeedForward
from wenet.squeezeformer.convolution import ConvolutionModule
from wenet.utils.mask import make_pad_mask, add_optional_chunk_mask
from wenet.utils.common import get_activation
class SqueezeformerEncoder(nn.Module):
def __init__(
self,
input_size: int = 80,
encoder_dim: int = 256,
output_size: int = 256,
attention_heads: int = 4,
num_blocks: int = 12,
reduce_idx: Optional[Union[int, List[int]]] = 5,
recover_idx: Optional[Union[int, List[int]]] = 11,
feed_forward_expansion_factor: int = 4,
dw_stride: bool = False,
input_dropout_rate: float = 0.1,
pos_enc_layer_type: str = "rel_pos",
time_reduction_layer_type: str = "conv1d",
do_rel_shift: bool = True,
feed_forward_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.1,
cnn_module_kernel: int = 31,
cnn_norm_type: str = "batch_norm",
dropout: float = 0.1,
causal: bool = False,
adaptive_scale: bool = True,
activation_type: str = "swish",
init_weights: bool = True,
global_cmvn: torch.nn.Module = None,
normalize_before: bool = False,
use_dynamic_chunk: bool = False,
concat_after: bool = False,
static_chunk_size: int = 0,
use_dynamic_left_chunk: bool = False
):
"""Construct SqueezeformerEncoder
Args:
input_size to use_dynamic_chunk, see in Transformer BaseEncoder.
encoder_dim (int): The hidden dimension of encoder layer.
output_size (int): The output dimension of final projection layer.
attention_heads (int): Num of attention head in attention module.
num_blocks (int): Num of encoder layers.
reduce_idx Optional[Union[int, List[int]]]:
reduce layer index, from 40ms to 80ms per frame.
recover_idx Optional[Union[int, List[int]]]:
recover layer index, from 80ms to 40ms per frame.
feed_forward_expansion_factor (int): Enlarge coefficient of FFN.
dw_stride (bool): Whether do depthwise convolution
on subsampling module.
input_dropout_rate (float): Dropout rate of input projection layer.
pos_enc_layer_type (str): Self attention type.
time_reduction_layer_type (str): Conv1d or Conv2d reduction layer.
do_rel_shift (bool): Whether to do relative shift
operation on rel-attention module.
cnn_module_kernel (int): Kernel size of CNN module.
activation_type (str): Encoder activation function type.
use_cnn_module (bool): Whether to use convolution module.
cnn_module_kernel (int): Kernel size of convolution module.
adaptive_scale (bool): Whether to use adaptive scale.
init_weights (bool): Whether to initialize weights.
causal (bool): whether to use causal convolution or not.
"""
super(SqueezeformerEncoder, self).__init__()
self.global_cmvn = global_cmvn
self.reduce_idx: Optional[Union[int, List[int]]] = [reduce_idx] \
if type(reduce_idx) == int else reduce_idx
self.recover_idx: Optional[Union[int, List[int]]] = [recover_idx] \
if type(recover_idx) == int else recover_idx
self.check_ascending_list()
if reduce_idx is None:
self.time_reduce = None
else:
if recover_idx is None:
self.time_reduce = 'normal' # no recovery at the end
else:
self.time_reduce = 'recover' # recovery at the end
assert len(self.reduce_idx) == len(self.recover_idx)
self.reduce_stride = 2
self._output_size = output_size
self.normalize_before = normalize_before
self.static_chunk_size = static_chunk_size
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
self.pos_enc_layer_type = pos_enc_layer_type
activation = get_activation(activation_type)
# self-attention module definition
if pos_enc_layer_type != "rel_pos":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
else:
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
encoder_dim,
attention_dropout_rate,
do_rel_shift,
adaptive_scale,
init_weights
)
# feed-forward module definition
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
encoder_dim,
encoder_dim * feed_forward_expansion_factor,
feed_forward_dropout_rate,
activation,
adaptive_scale,
init_weights
)
# convolution module definition
convolution_layer = ConvolutionModule
convolution_layer_args = (
encoder_dim, cnn_module_kernel, activation,
cnn_norm_type, causal, True, adaptive_scale, init_weights)
self.embed = DepthwiseConv2dSubsampling4(
1, encoder_dim,
RelPositionalEncoding(encoder_dim, dropout_rate=0.1),
dw_stride,
input_size,
input_dropout_rate,
init_weights
)
self.preln = nn.LayerNorm(encoder_dim)
self.encoders = torch.nn.ModuleList([SqueezeformerEncoderLayer(
encoder_dim,
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
convolution_layer(*convolution_layer_args),
positionwise_layer(*positionwise_layer_args),
normalize_before,
dropout,
concat_after) for _ in range(num_blocks)
])
if time_reduction_layer_type == 'conv1d':
time_reduction_layer = TimeReductionLayer1D
time_reduction_layer_args = {
'channel': encoder_dim,
'out_dim': encoder_dim,
}
elif time_reduction_layer_type == 'stream':
time_reduction_layer = TimeReductionLayerStream
time_reduction_layer_args = {
'channel': encoder_dim,
'out_dim': encoder_dim,
}
else:
time_reduction_layer = TimeReductionLayer2D
time_reduction_layer_args = {'encoder_dim': encoder_dim}
self.time_reduction_layer = time_reduction_layer(**time_reduction_layer_args)
self.time_recover_layer = nn.Linear(encoder_dim, encoder_dim)
self.final_proj = None
if output_size != encoder_dim:
self.final_proj = nn.Linear(encoder_dim, output_size)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
decoding_chunk_size: int = 0,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
xs, pos_emb, masks = self.embed(xs, masks)
mask_pad = masks # (B, 1, T/subsample_rate)
chunk_masks = add_optional_chunk_mask(xs, masks,
self.use_dynamic_chunk,
self.use_dynamic_left_chunk,
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks)
xs_lens = mask_pad.squeeze(1).sum(1)
xs = self.preln(xs)
recover_activations: \
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = []
index = 0
for i, layer in enumerate(self.encoders):
if self.reduce_idx is not None:
if self.time_reduce is not None and i in self.reduce_idx:
recover_activations.append((xs, chunk_masks, pos_emb, mask_pad))
xs, xs_lens, chunk_masks, mask_pad = \
self.time_reduction_layer(xs, xs_lens, chunk_masks, mask_pad)
pos_emb = pos_emb[:, ::2, :]
index += 1
if self.recover_idx is not None:
if self.time_reduce == 'recover' and i in self.recover_idx:
index -= 1
(recover_tensor, recover_chunk_masks,
recover_pos_emb, recover_mask_pad) \
= recover_activations[index]
# recover output length for ctc decode
xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
xs = self.time_recover_layer(xs)
recoverd_t = recover_tensor.size(1)
xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
chunk_masks = recover_chunk_masks
pos_emb = recover_pos_emb
mask_pad = recover_mask_pad
xs = xs.masked_fill(~mask_pad[:, 0, :].unsqueeze(-1), 0.0)
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
if self.final_proj is not None:
xs = self.final_proj(xs)
return xs, masks
def check_ascending_list(self):
if self.reduce_idx is not None:
assert self.reduce_idx == sorted(self.reduce_idx), \
"reduce_idx should be int or ascending list"
if self.recover_idx is not None:
assert self.recover_idx == sorted(self.recover_idx), \
"recover_idx should be int or ascending list"
def calculate_downsampling_factor(self, i: int) -> int:
if self.reduce_idx is None:
return 1
else:
reduce_exp, recover_exp = 0, 0
for exp, rd_idx in enumerate(self.reduce_idx):
if i >= rd_idx:
reduce_exp = exp + 1
if self.recover_idx is not None:
for exp, rc_idx in enumerate(self.recover_idx):
if i >= rc_idx:
recover_exp = exp + 1
return int(2 ** (reduce_exp - recover_exp))
def forward_chunk(
self,
xs: torch.Tensor,
offset: int,
required_cache_size: int,
att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" Forward just one chunk
Args:
xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
offset (int): current offset in encoder output time stamp
required_cache_size (int): cache size required for next chunk
compuation
>=0: actual cache size
<0: means all history cache is required
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(elayers, head, cache_t1, d_k * 2), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * num_decoding_left_chunks`.
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
(elayers, b=1, hidden-dim, cache_t2), where
`cache_t2 == cnn.lorder - 1`
Returns:
torch.Tensor: output of current input xs,
with shape (b=1, chunk_size, hidden-dim).
torch.Tensor: new attention cache required for next chunk, with
dynamic shape (elayers, head, ?, d_k * 2)
depending on required_cache_size.
torch.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
assert xs.size(0) == 1
# tmp_masks is just for interface compatibility
tmp_masks = torch.ones(1,
xs.size(1),
device=xs.device,
dtype=torch.bool)
tmp_masks = tmp_masks.unsqueeze(1)
if self.global_cmvn is not None:
xs = self.global_cmvn(xs)
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
# NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
chunk_size = xs.size(1)
attention_key_size = cache_t1 + chunk_size
pos_emb = self.embed.position_encoding(
offset=offset - cache_t1, size=attention_key_size)
if required_cache_size < 0:
next_cache_start = 0
elif required_cache_size == 0:
next_cache_start = attention_key_size
else:
next_cache_start = max(attention_key_size - required_cache_size, 0)
r_att_cache = []
r_cnn_cache = []
mask_pad = torch.ones(1,
xs.size(1),
device=xs.device,
dtype=torch.bool)
mask_pad = mask_pad.unsqueeze(1)
max_att_len: int = 0
recover_activations: \
List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = []
index = 0
xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
xs = self.preln(xs)
for i, layer in enumerate(self.encoders):
# NOTE(xcsong): Before layer.forward
# shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
# shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
if self.reduce_idx is not None:
if self.time_reduce is not None and i in self.reduce_idx:
recover_activations.append((xs, att_mask, pos_emb, mask_pad))
xs, xs_lens, att_mask, mask_pad = \
self.time_reduction_layer(xs, xs_lens, att_mask, mask_pad)
pos_emb = pos_emb[:, ::2, :]
index += 1
if self.recover_idx is not None:
if self.time_reduce == 'recover' and i in self.recover_idx:
index -= 1
(recover_tensor, recover_att_mask,
recover_pos_emb, recover_mask_pad) \
= recover_activations[index]
# recover output length for ctc decode
xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
xs = self.time_recover_layer(xs)
recoverd_t = recover_tensor.size(1)
xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
att_mask = recover_att_mask
pos_emb = recover_pos_emb
mask_pad = recover_mask_pad
if att_mask.size(1) != 0:
xs = xs.masked_fill(~att_mask[:, 0, :].unsqueeze(-1), 0.0)
factor = self.calculate_downsampling_factor(i)
xs, _, new_att_cache, new_cnn_cache = layer(
xs, att_mask, pos_emb,
att_cache=att_cache[i:i + 1][:, :, ::factor, :]
[:, :, :pos_emb.size(1) - xs.size(1), :] if
elayers > 0 else att_cache[:, :, ::factor, :],
cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache
)
# NOTE(xcsong): After layer.forward
# shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
# shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
cached_att \
= new_att_cache[:, :, next_cache_start // factor:, :]
cached_cnn = new_cnn_cache.unsqueeze(0)
cached_att = cached_att.unsqueeze(3).\
repeat(1, 1, 1, factor, 1).flatten(2, 3)
if i == 0:
# record length for the first block as max length
max_att_len = cached_att.size(2)
r_att_cache.append(cached_att[:, :, :max_att_len, :])
r_cnn_cache.append(cached_cnn)
# NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
# ? may be larger than cache_t1, it depends on required_cache_size
r_att_cache = torch.cat(r_att_cache, dim=0)
# NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
if self.final_proj is not None:
xs = self.final_proj(xs)
return (xs, r_att_cache, r_cnn_cache)
def forward_chunk_by_chunk(
self,
xs: torch.Tensor,
decoding_chunk_size: int,
num_decoding_left_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
""" Forward input chunk by chunk with chunk_size like a streaming
fashion
Here we should pay special attention to computation cache in the
streaming style forward chunk by chunk. Three things should be taken
into account for computation in the current network:
1. transformer/conformer encoder layers output cache
2. convolution in conformer
3. convolution in subsampling
However, we don't implement subsampling cache for:
1. We can control subsampling module to output the right result by
overlapping input instead of cache left context, even though it
wastes some computation, but subsampling only takes a very
small fraction of computation in the whole model.
2. Typically, there are several covolution layers with subsampling
in subsampling module, it is tricky and complicated to do cache
with different convolution layers with different subsampling
rate.
3. Currently, nn.Sequential is used to stack all the convolution
layers in subsampling, we need to rewrite it to make it work
with cache, which is not prefered.
Args:
xs (torch.Tensor): (1, max_len, dim)
chunk_size (int): decoding chunk size
"""
assert decoding_chunk_size > 0
# The model is trained by static or dynamic chunk
assert self.static_chunk_size > 0 or self.use_dynamic_chunk
subsampling = self.embed.subsampling_rate
context = self.embed.right_context + 1 # Add current frame
stride = subsampling * decoding_chunk_size
decoding_window = (decoding_chunk_size - 1) * subsampling + context
num_frames = xs.size(1)
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
outputs = []
offset = 0
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
# Feed forward overlap input step by step
for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :]
(y, att_cache, cnn_cache) = \
self.forward_chunk(
chunk_xs, offset, required_cache_size,
att_cache, cnn_cache)
outputs.append(y)
offset += y.size(1)
ys = torch.cat(outputs, 1)
masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool)
return ys, masks
# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SqueezeformerEncoderLayer definition."""
import torch
import torch.nn as nn
from typing import Optional, Tuple
class SqueezeformerEncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
instance can be used as the argument.
feed_forward1 (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
conv_module (torch.nn.Module): Convolution module instance.
`ConvlutionModule` instance can be used as the argument.
feed_forward2 (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool):
True: use layer_norm before each sub-block.
False: use layer_norm after each sub-block.
"""
def __init__(
self,
size: int,
self_attn: torch.nn.Module,
feed_forward1: Optional[nn.Module] = None,
conv_module: Optional[nn.Module] = None,
feed_forward2: Optional[nn.Module] = None,
normalize_before: bool = False,
dropout_rate: float = 0.1,
concat_after: bool = False,
):
super(SqueezeformerEncoderLayer, self).__init__()
self.size = size
self.self_attn = self_attn
self.layer_norm1 = nn.LayerNorm(size)
self.ffn1 = feed_forward1
self.layer_norm2 = nn.LayerNorm(size)
self.conv_module = conv_module
self.layer_norm3 = nn.LayerNorm(size)
self.ffn2 = feed_forward2
self.layer_norm4 = nn.LayerNorm(size)
self.normalize_before = normalize_before
self.dropout = nn.Dropout(dropout_rate)
self.concat_after = concat_after
if concat_after:
self.concat_linear = nn.Linear(size + size, size)
else:
self.concat_linear = nn.Identity()
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor,
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# self attention module
residual = x
if self.normalize_before:
x = self.layer_norm1(x)
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)
if self.concat_after:
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.layer_norm1(x)
# ffn module
residual = x
if self.normalize_before:
x = self.layer_norm2(x)
x = self.ffn1(x)
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.layer_norm2(x)
# conv module
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
residual = x
if self.normalize_before:
x = self.layer_norm3(x)
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.layer_norm3(x)
# ffn module
residual = x
if self.normalize_before:
x = self.layer_norm4(x)
x = self.ffn2(x)
# we do not use dropout here since it is inside feed forward function
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.layer_norm4(x)
return x, mask, new_att_cache, new_cnn_cache
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Ximalaya Inc (Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""
import torch
class PositionwiseFeedForward(torch.nn.Module):
"""Positionwise feed forward layer.
FeedForward are appied on each position of the sequence.
The output dim is same with the input dim.
Args:
idim (int): Input dimenstion.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
activation (torch.nn.Module): Activation function
"""
def __init__(self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU(),
adaptive_scale: bool = False,
init_weights: bool = False
):
"""Construct a PositionwiseFeedForward object."""
super(PositionwiseFeedForward, self).__init__()
self.idim = idim
self.hidden_units = hidden_units
self.w_1 = torch.nn.Linear(idim, hidden_units)
self.activation = activation
self.dropout = torch.nn.Dropout(dropout_rate)
self.w_2 = torch.nn.Linear(hidden_units, idim)
self.ada_scale = None
self.ada_bias = None
self.adaptive_scale = adaptive_scale
self.ada_scale = torch.nn.Parameter(
torch.ones([1, 1, idim]), requires_grad=adaptive_scale)
self.ada_bias = torch.nn.Parameter(
torch.zeros([1, 1, idim]), requires_grad=adaptive_scale)
if init_weights:
self.init_weights()
def init_weights(self):
ffn1_max = self.idim ** -0.5
ffn2_max = self.hidden_units ** -0.5
torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max)
torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max)
torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max)
torch.nn.init.uniform_(self.w_2.bias.data, -ffn2_max, ffn2_max)
def forward(self, xs: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
xs: input tensor (B, L, D)
Returns:
output tensor, (B, L, D)
"""
if self.adaptive_scale:
xs = self.ada_scale * xs + self.ada_bias
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment