ModelZoo / Conformer_pytorch · Commits

Commit a7785cc6, authored Mar 26, 2024 by Sugon_ldc

    delete soft link

Parent: 9a2a05ca
Changes: 162

Showing 20 changed files with 2334 additions and 0 deletions (+2334, -0):
examples/aishell/s0/wenet/efficient_conformer/__pycache__/encoder_layer.cpython-38.pyc   +0    -0
examples/aishell/s0/wenet/efficient_conformer/__pycache__/subsampling.cpython-38.pyc     +0    -0
examples/aishell/s0/wenet/efficient_conformer/attention.py                               +245  -0
examples/aishell/s0/wenet/efficient_conformer/convolution.py                             +156  -0
examples/aishell/s0/wenet/efficient_conformer/encoder.py                                 +543  -0
examples/aishell/s0/wenet/efficient_conformer/encoder_layer.py                           +178  -0
examples/aishell/s0/wenet/efficient_conformer/subsampling.py                             +77   -0
examples/aishell/s0/wenet/squeezeformer/__pycache__/attention.cpython-38.pyc             +0    -0
examples/aishell/s0/wenet/squeezeformer/__pycache__/conv2d.cpython-38.pyc                +0    -0
examples/aishell/s0/wenet/squeezeformer/__pycache__/convolution.cpython-38.pyc           +0    -0
examples/aishell/s0/wenet/squeezeformer/__pycache__/encoder.cpython-38.pyc               +0    -0
examples/aishell/s0/wenet/squeezeformer/__pycache__/encoder_layer.cpython-38.pyc         +0    -0
examples/aishell/s0/wenet/squeezeformer/__pycache__/positionwise_feed_forward.cpython-38.pyc  +0  -0
examples/aishell/s0/wenet/squeezeformer/__pycache__/subsampling.cpython-38.pyc           +0    -0
examples/aishell/s0/wenet/squeezeformer/attention.py                                     +222  -0
examples/aishell/s0/wenet/squeezeformer/conv2d.py                                        +66   -0
examples/aishell/s0/wenet/squeezeformer/convolution.py                                   +174  -0
examples/aishell/s0/wenet/squeezeformer/encoder.py                                       +473  -0
examples/aishell/s0/wenet/squeezeformer/encoder_layer.py                                 +121  -0
examples/aishell/s0/wenet/squeezeformer/positionwise_feed_forward.py                     +79   -0
examples/aishell/s0/wenet/efficient_conformer/__pycache__/encoder_layer.cpython-38.pyc (new file, 0 → 100644)
File added

examples/aishell/s0/wenet/efficient_conformer/__pycache__/subsampling.cpython-38.pyc (new file, 0 → 100644)
File added
examples/aishell/s0/wenet/efficient_conformer/attention.py (new file, 0 → 100644)

# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#               2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""

import math
from typing import Tuple, Optional

import torch
from torch import nn
import torch.nn.functional as F

from wenet.transformer.attention import MultiHeadedAttention


class GroupedRelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper:
        https://arxiv.org/abs/1901.02860
        https://arxiv.org/abs/2109.01163
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """
    def __init__(self, n_head, n_feat, dropout_rate, group_size=3):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        self.group_size = group_size
        self.d_k = n_feat // n_head  # for GroupedAttention
        self.n_feat = n_feat
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(
            torch.Tensor(self.h, self.d_k * self.group_size))
        self.pos_bias_v = nn.Parameter(
            torch.Tensor(self.h, self.d_k * self.group_size))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x, zero_triu: bool = False):
        """Compute relative positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, size).
            zero_triu (bool): If true, return the lower triangular part of
                the matrix.
        Returns:
            torch.Tensor: Output tensor.
        """
        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
                               device=x.device,
                               dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(x.size()[0],
                                 x.size()[1],
                                 x.size(3) + 1,
                                 x.size(2))
        x = x_padded[:, :, 1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(2), x.size(3)))
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def pad4group(self, Q, K, V, P, mask, group_size: int = 3):
        """
        q: (#batch, time1, size) -> (#batch, head, time1, size/head)
        k,v: (#batch, time2, size) -> (#batch, head, time2, size/head)
        p: (#batch, time2, size)
        """
        # Compute Overflows
        overflow_Q = Q.size(2) % group_size
        overflow_KV = K.size(2) % group_size

        padding_Q = (group_size - overflow_Q) * int(
            overflow_Q // (overflow_Q + 0.00000000000000001))
        padding_KV = (group_size - overflow_KV) * int(
            overflow_KV // (overflow_KV + 0.00000000000000001))

        batch_size, _, seq_len_KV, _ = K.size()

        # Input Padding (B, T, D) -> (B, T + P, D)
        Q = F.pad(Q, (0, 0, 0, padding_Q), value=0.0)
        K = F.pad(K, (0, 0, 0, padding_KV), value=0.0)
        V = F.pad(V, (0, 0, 0, padding_KV), value=0.0)

        if mask is not None and mask.size(2) > 0:  # time2 > 0:
            mask = mask[:, ::group_size, ::group_size]

        Q = Q.transpose(1, 2).contiguous().view(
            batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
        K = K.transpose(1, 2).contiguous().view(
            batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)
        V = V.transpose(1, 2).contiguous().view(
            batch_size, -1, self.h, self.d_k * group_size).transpose(1, 2)

        # process pos_emb
        P_batch_size = P.size(0)
        overflow_P = P.size(1) % group_size
        padding_P = group_size - overflow_P if overflow_P else 0
        P = F.pad(P, (0, 0, 0, padding_P), value=0.0)
        P = P.view(P_batch_size, -1, self.h,
                   self.d_k * group_size).transpose(1, 2)

        return Q, K, V, P, mask, padding_Q

    def forward_attention(
        self, value: torch.Tensor, scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        padding_q: Optional[int] = None
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            padding_q : for GroupedAttention in efficient conformer

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #           1st chunk to ease the onnx export.]
        #   2. pytorch training
        if mask.size(2) > 0:  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0)  # (batch, head, time1, time2)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
        #   1. onnx(16/-1, -1/-1, 16/0)
        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        # n_feat != h * d_k may happen in GroupedAttention
        x = (x.transpose(1, 2).contiguous().view(
            n_batch, -1, self.n_feat))  # (batch, time1, d_model)
        if padding_q is not None:
            # for GroupedAttention in efficient conformer
            x = x[:, :x.size(1) - padding_q]

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query: torch.Tensor,
                key: torch.Tensor, value: torch.Tensor,
                mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                pos_emb: torch.Tensor = torch.empty(0),
                cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q = self.linear_q(query)
        k = self.linear_k(key)  # (#batch, time2, size)
        v = self.linear_v(value)
        p = self.linear_pos(pos_emb)  # (#batch, time2, size)

        batch_size, seq_len_KV, _ = k.size()  # seq_len_KV = time2

        # (#batch, time2, size) -> (#batch, head, time2, size/head)
        q = q.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        k = k.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        v = v.view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        if cache.size(0) > 0:
            # use attention cache
            key_cache, value_cache = torch.split(
                cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        new_cache = torch.cat((k, v), dim=-1)

        # k and p may not match, e.g. time2 = 18 + 18/2 = 27 > mask = 36/2 = 18
        if mask is not None and mask.size(2) > 0:
            time2 = mask.size(2)
            k = k[:, :, -time2:, :]
            v = v[:, :, -time2:, :]

        # q k v p: (batch, head, time1, d_k)
        q, k, v, p, mask, padding_q = self.pad4group(q, k, v, p, mask,
                                                     self.group_size)

        # q_with_bias_u & q_with_bias_v = (batch, head, time1, d_k)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))

        # Remove rel_shift since it is useless in speech recognition,
        #   and it requires special attention for streaming.
        # matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k * self.group_size)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask, padding_q), new_cache
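A minimal smoke-test sketch for the grouped attention above (this sketch is not part of the commit; it assumes the wenet package added here is importable and that a toy configuration with n_feat=256, 4 heads and group_size=3 is acceptable):

import torch
from wenet.efficient_conformer.attention import GroupedRelPositionMultiHeadedAttention

# group_size=3 concatenates 3 consecutive frames per head before scoring,
# so the attention matrix is computed over time1/3 positions.
attn = GroupedRelPositionMultiHeadedAttention(n_head=4, n_feat=256,
                                              dropout_rate=0.0, group_size=3)
x = torch.randn(2, 18, 256)               # (batch, time1, size), 18 % 3 == 0
pos_emb = torch.randn(1, 18, 256)         # (batch, time2, size)
mask = torch.ones(2, 18, 18, dtype=torch.bool)
out, cache = attn(x, x, x, mask, pos_emb)
print(out.shape)    # torch.Size([2, 18, 256]); time1 is restored after pad4group
print(cache.shape)  # torch.Size([2, 4, 18, 128]); concatenated K/V cache (d_k * 2)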
examples/aishell/s0/wenet/efficient_conformer/convolution.py (new file, 0 → 100644)

# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""

from typing import Tuple

import torch
from torch import nn
from typeguard import check_argument_types


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model."""
    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True,
                 stride: int = 1):
        """Construct a ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            causal (int): Whether use causal convolution or not
            stride (int): Stride Convolution, for efficient Conformer
        """
        assert check_argument_types()
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution,
        # if self.lorder > 0: it's a causal convolution, the input will be
        #    padded with self.lorder frames on the left in forward.
        # else: it's a symmetrical convolution
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for non-causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=stride,  # for depthwise_conv in StrideConv
            padding=padding,
            groups=channels,
            bias=bias,
        )

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation
        self.stride = stride

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) means fake cache.
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                # When exporting ONNX, the first cache is not None but all-zero,
                # which causes a shape error in the residual block,
                # e.g. cache14 + x9 = 23, 23-7+1=17 != 9
                cache = cache[:, :, -self.lorder:]
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            if mask_pad.size(2) != x.size(2):
                mask_pad = mask_pad[:, :, ::self.stride]
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
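A hedged usage sketch for the stride behaviour (not part of the commit; it assumes typeguard and the wenet package are installed and picks an illustrative kernel size of 7): with stride=2 the depthwise convolution halves the time axis, which is what the StrideConformer layers added later in this commit rely on.

import torch
from wenet.efficient_conformer.convolution import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=7, stride=2)
x = torch.randn(4, 100, 256)       # (batch, time, channels)
y, new_cache = conv(x)             # default fake mask/cache tensors
print(y.shape)                     # torch.Size([4, 50, 256]) -- time halved by the stride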
examples/aishell/s0/wenet/efficient_conformer/encoder.py (new file, 0 → 100644)

# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#               2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from EfficientConformer(https://github.com/burchim/EfficientConformer)
#               Paper(https://arxiv.org/abs/2109.01163)
"""Encoder definition."""

from typing import Tuple, Optional, List, Union

import torch
import logging
from typeguard import check_argument_types
import torch.nn.functional as F

from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.embedding import NoPositionalEncoding
from wenet.transformer.subsampling import Conv2dSubsampling4
from wenet.transformer.subsampling import Conv2dSubsampling6
from wenet.transformer.subsampling import Conv2dSubsampling8
from wenet.transformer.subsampling import LinearNoSubsampling
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.attention import RelPositionMultiHeadedAttention
from wenet.transformer.encoder_layer import ConformerEncoderLayer

from wenet.efficient_conformer.subsampling import Conv2dSubsampling2
from wenet.efficient_conformer.convolution import ConvolutionModule
from wenet.efficient_conformer.attention import GroupedRelPositionMultiHeadedAttention
from wenet.efficient_conformer.encoder_layer import StrideConformerEncoderLayer

from wenet.utils.common import get_activation
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask


class EfficientConformerEncoder(torch.nn.Module):
    """Conformer encoder module."""
    def __init__(self,
                 input_size: int,
                 output_size: int = 256,
                 attention_heads: int = 4,
                 linear_units: int = 2048,
                 num_blocks: int = 6,
                 dropout_rate: float = 0.1,
                 positional_dropout_rate: float = 0.1,
                 attention_dropout_rate: float = 0.0,
                 input_layer: str = "conv2d",
                 pos_enc_layer_type: str = "rel_pos",
                 normalize_before: bool = True,
                 concat_after: bool = False,
                 static_chunk_size: int = 0,
                 use_dynamic_chunk: bool = False,
                 global_cmvn: torch.nn.Module = None,
                 use_dynamic_left_chunk: bool = False,
                 macaron_style: bool = True,
                 activation_type: str = "swish",
                 use_cnn_module: bool = True,
                 cnn_module_kernel: int = 15,
                 causal: bool = False,
                 cnn_module_norm: str = "batch_norm",
                 stride_layer_idx: Optional[Union[int, List[int]]] = 3,
                 stride: Optional[Union[int, List[int]]] = 2,
                 group_layer_idx: Optional[Union[int, List[int], tuple]] = (0, 1, 2, 3),
                 group_size: int = 3,
                 stride_kernel: bool = True,
                 **kwargs):
        """Construct Efficient Conformer Encoder

        Args:
            input_size to use_dynamic_chunk, see in BaseEncoder
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            activation_type (str): Encoder activation function type.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
            causal (bool): whether to use causal convolution or not.
            stride_layer_idx (list): layer id with StrideConv, start from 0
            stride (list): stride size of each StrideConv in efficient conformer
            group_layer_idx (list): layer id with GroupedAttention, start from 0
            group_size (int): group size of every GroupedAttention layer
            stride_kernel (bool): default True. True: recompute cnn kernels with stride.
        """
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "no_pos":
            pos_enc_class = NoPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            subsampling_class = LinearNoSubsampling
        elif input_layer == "conv2d2":
            subsampling_class = Conv2dSubsampling2
        elif input_layer == "conv2d":
            subsampling_class = Conv2dSubsampling4
        elif input_layer == "conv2d6":
            subsampling_class = Conv2dSubsampling6
        elif input_layer == "conv2d8":
            subsampling_class = Conv2dSubsampling8
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        logging.info(f"input_layer = {input_layer}, "
                     f"subsampling_class = {subsampling_class}")

        self.global_cmvn = global_cmvn
        self.embed = subsampling_class(
            input_size,
            output_size,
            dropout_rate,
            pos_enc_class(output_size, positional_dropout_rate),
        )
        self.input_layer = input_layer
        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk

        activation = get_activation(activation_type)
        self.num_blocks = num_blocks
        self.attention_heads = attention_heads
        self.cnn_module_kernel = cnn_module_kernel
        self.global_chunk_size = 0

        # efficient conformer configs
        self.stride_layer_idx = [stride_layer_idx] \
            if type(stride_layer_idx) == int else stride_layer_idx
        self.stride = [stride] \
            if type(stride) == int else stride
        self.group_layer_idx = [group_layer_idx] \
            if type(group_layer_idx) == int else group_layer_idx
        self.grouped_size = group_size  # group size of every GroupedAttention layer

        assert len(self.stride) == len(self.stride_layer_idx)
        self.cnn_module_kernels = [cnn_module_kernel]  # kernel size of each StridedConv
        for i in self.stride:
            if stride_kernel:
                self.cnn_module_kernels.append(self.cnn_module_kernels[-1] // i)
            else:
                self.cnn_module_kernels.append(self.cnn_module_kernels[-1])

        logging.info(f"stride_layer_idx = {self.stride_layer_idx}, "
                     f"stride = {self.stride}, "
                     f"cnn_module_kernel = {self.cnn_module_kernels}, "
                     f"group_layer_idx = {self.group_layer_idx}, "
                     f"grouped_size = {self.grouped_size}")

        # feed-forward module definition
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer = ConvolutionModule

        # encoder definition
        index = 0
        layers = []
        for i in range(num_blocks):
            # self-attention module definition
            if i in self.group_layer_idx:
                encoder_selfattn_layer = GroupedRelPositionMultiHeadedAttention
                encoder_selfattn_layer_args = (attention_heads,
                                               output_size,
                                               attention_dropout_rate,
                                               self.grouped_size)
            else:
                if pos_enc_layer_type == "no_pos":
                    encoder_selfattn_layer = MultiHeadedAttention
                else:
                    encoder_selfattn_layer = RelPositionMultiHeadedAttention
                encoder_selfattn_layer_args = (attention_heads,
                                               output_size,
                                               attention_dropout_rate)

            # conformer module definition
            if i in self.stride_layer_idx:
                # conformer block with downsampling
                convolution_layer_args_stride = (
                    output_size, self.cnn_module_kernels[index], activation,
                    cnn_module_norm, causal, True, self.stride[index])
                layers.append(StrideConformerEncoderLayer(
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                    convolution_layer(*convolution_layer_args_stride) if use_cnn_module else None,
                    torch.nn.AvgPool1d(kernel_size=self.stride[index],
                                       stride=self.stride[index],
                                       padding=0,
                                       ceil_mode=True,
                                       count_include_pad=False),  # pointwise_conv_layer
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ))
                index = index + 1
            else:
                # conformer block
                convolution_layer_args_normal = (
                    output_size, self.cnn_module_kernels[index], activation,
                    cnn_module_norm, causal)
                layers.append(ConformerEncoderLayer(
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                    convolution_layer(*convolution_layer_args_normal) if use_cnn_module else None,
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ))

        self.encoders = torch.nn.ModuleList(layers)

    def set_global_chunk_size(self, chunk_size):
        """Used in ONNX export.
        """
        logging.info(f"set global chunk size: {chunk_size}, default is 0.")
        self.global_chunk_size = chunk_size

    def output_size(self) -> int:
        return self._output_size

    def calculate_downsampling_factor(self, i: int) -> int:
        factor = 1
        for idx, stride_idx in enumerate(self.stride_layer_idx):
            if i > stride_idx:
                factor *= self.stride[idx]
        return factor

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.
        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        index = 0  # traverse stride
        for i, layer in enumerate(self.encoders):
            # layer return : x, mask, new_att_cache, new_cnn_cache
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
            if i in self.stride_layer_idx:
                masks = masks[:, :, ::self.stride[index]]
                chunk_masks = chunk_masks[:, ::self.stride[index], ::self.stride[index]]
                mask_pad = masks
                pos_emb = pos_emb[:, ::self.stride[index], :]
                index = index + 1

        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """ Forward just one chunk
        Args:
            xs (torch.Tensor): chunk input
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            att_mask : mask matrix of self attention

        Returns:
            torch.Tensor: output of current input xs
            torch.Tensor: subsampling cache required for next chunk computation
            List[torch.Tensor]: encoder layers output cache required for next
                chunk computation
            List[torch.Tensor]: conformer cnn cache
        """
        assert xs.size(0) == 1

        # using downsampling factor to recover offset
        offset *= self.calculate_downsampling_factor(self.num_blocks + 1)

        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1,
                               xs.size(1),
                               device=xs.device,
                               dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)  # (1, 1, xs-time)

        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)

        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
        elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
        chunk_size = xs.size(1)
        attention_key_size = cache_t1 + chunk_size

        pos_emb = self.embed.position_encoding(
            offset=offset - cache_t1, size=attention_key_size)

        # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
        # shape(pos_emb) = (b=1, chunk_size, emb_size=output_size=hidden-dim)

        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = attention_key_size
        else:
            next_cache_start = max(attention_key_size - required_cache_size, 0)

        # for ONNX export, padding xs to chunk_size
        if self.global_chunk_size > 0:
            real_len = xs.size(1)
            xs = F.pad(xs, (0, 0, 0, self.global_chunk_size - real_len), value=0.0)
            tmp_zeros = torch.zeros(att_mask.shape, dtype=torch.bool)
            att_mask[:, :, required_cache_size + real_len + 1:] = \
                tmp_zeros[:, :, required_cache_size + real_len + 1:]

        r_att_cache = []
        r_cnn_cache = []
        mask_pad = torch.ones(1,
                              xs.size(1),
                              device=xs.device,
                              dtype=torch.bool)
        mask_pad = mask_pad.unsqueeze(1)  # batchPad (b=1, 1, time=chunk_size)
        max_att_len, max_cnn_len = 0, 0  # for repeat_interleave of new_att_cache
        for i, layer in enumerate(self.encoders):
            factor = self.calculate_downsampling_factor(i)
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            # shape(new_att_cache) = [ batch, head, time2, outdim//head * 2 ]
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs, att_mask, pos_emb,
                mask_pad=mask_pad,
                att_cache=att_cache[i:i + 1, :, ::factor, :],
                cnn_cache=cnn_cache[i, :, :, :]
                if cnn_cache.size(0) > 0 else cnn_cache
            )
            if i in self.stride_layer_idx:
                # compute time dimension for next block
                efficient_index = self.stride_layer_idx.index(i)
                att_mask = att_mask[:, ::self.stride[efficient_index],
                                    ::self.stride[efficient_index]]
                mask_pad = mask_pad[:, ::self.stride[efficient_index],
                                    ::self.stride[efficient_index]]
                pos_emb = pos_emb[:, ::self.stride[efficient_index], :]

            # shape(new_att_cache) = [batch, head, time2, outdim]
            new_att_cache = new_att_cache[:, :, next_cache_start // factor:, :]
            # shape(new_cnn_cache) = [1, batch, outdim, cache_t2]
            new_cnn_cache = new_cnn_cache.unsqueeze(0)

            # use repeat_interleave to new_att_cache
            new_att_cache = new_att_cache.repeat_interleave(repeats=factor, dim=2)
            # padding new_cnn_cache to cnn.lorder for causal convolution
            new_cnn_cache = F.pad(
                new_cnn_cache,
                (self.cnn_module_kernel - 1 - new_cnn_cache.size(3), 0))

            if i == 0:
                # record length for the first block as max length
                max_att_len = new_att_cache.size(2)
                max_cnn_len = new_cnn_cache.size(3)

            # update real shape of att_cache and cnn_cache
            r_att_cache.append(new_att_cache[:, :, -max_att_len:, :])
            r_cnn_cache.append(new_cnn_cache[:, :, :, -max_cnn_len:])

        if self.normalize_before:
            xs = self.after_norm(xs)

        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
        #   ? may be larger than cache_t1, it depends on required_cache_size
        r_att_cache = torch.cat(r_att_cache, dim=0)
        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=0)

        return xs, r_att_cache, r_cnn_cache

    def forward_chunk_by_chunk(
        self,
        xs: torch.Tensor,
        decoding_chunk_size: int,
        num_decoding_left_chunks: int = -1,
        use_onnx=False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Forward input chunk by chunk with chunk_size like a streaming
            fashion

        Here we should pay special attention to computation cache in the
        streaming style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling

        However, we don't implement subsampling cache for:
            1. We can control subsampling module to output the right result by
               overlapping input instead of cache left context, even though it
               wastes some computation, but subsampling only takes a very
               small fraction of computation in the whole model.
            2. Typically, there are several convolution layers with subsampling
               in the subsampling module, so it is tricky and complicated to do
               cache with different convolution layers and different subsampling
               rates.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling, we need to rewrite it to make it work
               with cache, which is not preferred.
        Args:
            xs (torch.Tensor): (1, max_len, dim)
            decoding_chunk_size (int): decoding chunk size
            num_decoding_left_chunks (int):
            use_onnx (bool): True for simulating ONNX model inference.
        """
        assert decoding_chunk_size > 0
        # The model is trained by static or dynamic chunk
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.size(1)

        outputs = []
        offset = 0
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
        if use_onnx:
            logging.info("Simulating for ONNX runtime ...")
            att_cache: torch.Tensor = torch.zeros(
                (self.num_blocks, self.attention_heads, required_cache_size,
                 self.output_size() // self.attention_heads * 2),
                device=xs.device)
            cnn_cache: torch.Tensor = torch.zeros(
                (self.num_blocks, 1, self.output_size(),
                 self.cnn_module_kernel - 1),
                device=xs.device)
            self.set_global_chunk_size(chunk_size=18)
        else:
            logging.info("Simulating for JIT runtime ...")
            att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
            cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)

        # Feed forward overlap input step by step
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            logging.info(f"-->> frame chunk msg: cur={cur}, "
                         f"end={end}, num_frames={end - cur}, "
                         f"decoding_window={decoding_window}")
            if use_onnx:
                att_mask: torch.Tensor = torch.ones(
                    (1, 1, required_cache_size + decoding_chunk_size),
                    dtype=torch.bool, device=xs.device)
                if cur == 0:
                    att_mask[:, :, :required_cache_size] = 0
            else:
                att_mask: torch.Tensor = torch.ones(
                    (0, 0, 0), dtype=torch.bool, device=xs.device)

            chunk_xs = xs[:, cur:end, :]
            (y, att_cache, cnn_cache) = \
                self.forward_chunk(chunk_xs, offset, required_cache_size,
                                   att_cache, cnn_cache, att_mask)
            outputs.append(y)
            offset += y.size(1)

        ys = torch.cat(outputs, 1)
        masks = torch.ones(1, 1, ys.size(1), device=ys.device, dtype=torch.bool)
        return ys, masks
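A rough usage sketch for the encoder above (an assumption of this review, not part of the commit: the wenet.transformer/wenet.utils modules referenced by the imports are available, typeguard is installed, and an 80-dim fbank-like input is used):

import torch
from wenet.efficient_conformer.encoder import EfficientConformerEncoder

encoder = EfficientConformerEncoder(input_size=80, output_size=256,
                                    attention_heads=4, num_blocks=6,
                                    stride_layer_idx=3, stride=2,
                                    group_layer_idx=(0, 1, 2, 3), group_size=3,
                                    use_dynamic_chunk=True)
feats = torch.randn(2, 200, 80)          # (B, T, D) acoustic features
feats_lens = torch.tensor([200, 160])
out, out_mask = encoder(feats, feats_lens)
# conv2d subsampling (/4) plus one StrideConv layer (/2) -> roughly T/8 frames
print(out.shape, out_mask.shape)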
examples/aishell/s0/wenet/efficient_conformer/encoder_layer.py (new file, 0 → 100644)

# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#               2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder self-attention layer definition."""

from typing import Optional, Tuple

import torch
from torch import nn


class StrideConformerEncoderLayer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
        concat_after (bool): Whether to concat attention layer's input and
            output.
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)
    """
    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        pointwise_conv_layer: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.pointwise_conv_layer = pointwise_conv_layer
        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-5)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        self.concat_linear = nn.Linear(size + size, size)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
        """
        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)

            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)

            # add pointwise_conv for efficient conformer
            # pointwise_conv_layer does not change shape
            if self.pointwise_conv_layer is not None:
                residual = residual.transpose(1, 2)
                residual = self.pointwise_conv_layer(residual)
                residual = residual.transpose(1, 2)
                assert residual.size(0) == x.size(0)
                assert residual.size(1) == x.size(1)
                assert residual.size(2) == x.size(2)

            x = residual + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)

        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
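A small illustrative sketch of the residual-path trick used in this layer (illustrative only, not part of the commit): when the convolution module halves the time axis, the residual is pooled with an AvgPool1d using ceil_mode=True so the two branches stay shape-compatible before they are summed.

import torch

stride = 2
pool = torch.nn.AvgPool1d(kernel_size=stride, stride=stride,
                          padding=0, ceil_mode=True, count_include_pad=False)
residual = torch.randn(4, 256, 101)   # (batch, size, time), odd length on purpose
print(pool(residual).shape)           # torch.Size([4, 256, 51]) = ceil(101 / 2)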
examples/aishell/s0/wenet/efficient_conformer/subsampling.py (new file, 0 → 100644)

# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 58.com(Wuba) Inc AI Lab.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""

from typing import Tuple, Union

import torch

from wenet.transformer.subsampling import BaseSubsampling


class Conv2dSubsampling2(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/2 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
    """
    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling2 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU()
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * ((idim - 1) // 2), odim))
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 2
        # 2 = (3 - 1) * 1
        self.right_context = 2

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 2.
            torch.Tensor: positional encoding
        """
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, :-2:2]
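A short sketch of the new subsampling layer (assumptions of this review: wenet's RelPositionalEncoding is importable and an 80-dim input is used); Conv2dSubsampling2 halves the time axis with a single stride-2 Conv2d:

import torch
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.efficient_conformer.subsampling import Conv2dSubsampling2

subsample = Conv2dSubsampling2(idim=80, odim=256, dropout_rate=0.1,
                               pos_enc_class=RelPositionalEncoding(256, 0.1))
x = torch.randn(2, 100, 80)
mask = torch.ones(2, 1, 100, dtype=torch.bool)
y, pos_emb, y_mask = subsample(x, mask)
print(y.shape, y_mask.shape)   # (2, 49, 256) and (2, 1, 49)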
examples/aishell/s0/wenet/squeezeformer/__pycache__/attention.cpython-38.pyc (new file, 0 → 100644)
File added

examples/aishell/s0/wenet/squeezeformer/__pycache__/conv2d.cpython-38.pyc (new file, 0 → 100644)
File added

examples/aishell/s0/wenet/squeezeformer/__pycache__/convolution.cpython-38.pyc (new file, 0 → 100644)
File added

examples/aishell/s0/wenet/squeezeformer/__pycache__/encoder.cpython-38.pyc (new file, 0 → 100644)
File added

examples/aishell/s0/wenet/squeezeformer/__pycache__/encoder_layer.cpython-38.pyc (new file, 0 → 100644)
File added

examples/aishell/s0/wenet/squeezeformer/__pycache__/positionwise_feed_forward.cpython-38.pyc (new file, 0 → 100644)
File added

examples/aishell/s0/wenet/squeezeformer/__pycache__/subsampling.cpython-38.pyc (new file, 0 → 100644)
File added
examples/aishell/s0/wenet/squeezeformer/attention.py
0 → 100644
View file @
a7785cc6
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
# 2022 Ximalaya Inc. (Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""
import
math
import
torch
import
torch.nn
as
nn
from
wenet.transformer.attention
import
MultiHeadedAttention
from
typing
import
Tuple
class
RelPositionMultiHeadedAttention
(
MultiHeadedAttention
):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""
def
__init__
(
self
,
n_head
,
n_feat
,
dropout_rate
,
do_rel_shift
=
False
,
adaptive_scale
=
False
,
init_weights
=
False
):
"""Construct an RelPositionMultiHeadedAttention object."""
super
().
__init__
(
n_head
,
n_feat
,
dropout_rate
)
# linear transformation for positional encoding
self
.
linear_pos
=
nn
.
Linear
(
n_feat
,
n_feat
,
bias
=
False
)
# these two learnable bias are used in matrix c and matrix d
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
self
.
do_rel_shift
=
do_rel_shift
self
.
pos_bias_u
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
h
,
self
.
d_k
))
self
.
pos_bias_v
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
h
,
self
.
d_k
))
torch
.
nn
.
init
.
xavier_uniform_
(
self
.
pos_bias_u
)
torch
.
nn
.
init
.
xavier_uniform_
(
self
.
pos_bias_v
)
self
.
adaptive_scale
=
adaptive_scale
self
.
ada_scale
=
nn
.
Parameter
(
torch
.
ones
([
1
,
1
,
n_feat
]),
requires_grad
=
adaptive_scale
)
self
.
ada_bias
=
nn
.
Parameter
(
torch
.
zeros
([
1
,
1
,
n_feat
]),
requires_grad
=
adaptive_scale
)
if
init_weights
:
self
.
init_weights
()
def
init_weights
(
self
):
input_max
=
(
self
.
h
*
self
.
d_k
)
**
-
0.5
torch
.
nn
.
init
.
uniform_
(
self
.
linear_q
.
weight
,
-
input_max
,
input_max
)
torch
.
nn
.
init
.
uniform_
(
self
.
linear_q
.
bias
,
-
input_max
,
input_max
)
torch
.
nn
.
init
.
uniform_
(
self
.
linear_k
.
weight
,
-
input_max
,
input_max
)
torch
.
nn
.
init
.
uniform_
(
self
.
linear_k
.
bias
,
-
input_max
,
input_max
)
torch
.
nn
.
init
.
uniform_
(
self
.
linear_v
.
weight
,
-
input_max
,
input_max
)
torch
.
nn
.
init
.
uniform_
(
self
.
linear_v
.
bias
,
-
input_max
,
input_max
)
torch
.
nn
.
init
.
uniform_
(
self
.
linear_pos
.
weight
,
-
input_max
,
input_max
)
torch
.
nn
.
init
.
uniform_
(
self
.
linear_out
.
weight
,
-
input_max
,
input_max
)
torch
.
nn
.
init
.
uniform_
(
self
.
linear_out
.
bias
,
-
input_max
,
input_max
)
def
rel_shift
(
self
,
x
,
zero_triu
:
bool
=
False
):
"""Compute relative positinal encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, size).
zero_triu (bool): If true, return the lower triangular part of
the matrix.
Returns:
torch.Tensor: Output tensor.
"""
zero_pad
=
torch
.
zeros
((
x
.
size
()[
0
],
x
.
size
()[
1
],
x
.
size
()[
2
],
1
),
device
=
x
.
device
,
dtype
=
x
.
dtype
)
x_padded
=
torch
.
cat
([
zero_pad
,
x
],
dim
=-
1
)
x_padded
=
x_padded
.
view
(
x
.
size
()[
0
],
x
.
size
()[
1
],
x
.
size
(
3
)
+
1
,
x
.
size
(
2
))
x
=
x_padded
[:,
:,
1
:].
view_as
(
x
)
if
zero_triu
:
ones
=
torch
.
ones
((
x
.
size
(
2
),
x
.
size
(
3
)))
x
=
x
*
torch
.
tril
(
ones
,
x
.
size
(
3
)
-
x
.
size
(
2
))[
None
,
None
,
:,
:]
return
x
def
forward_attention
(
self
,
value
:
torch
.
Tensor
,
scores
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
=
torch
.
ones
((
0
,
0
,
0
),
dtype
=
torch
.
bool
)
)
->
torch
.
Tensor
:
"""Compute attention context vector.
Args:
value (torch.Tensor): Transformed value, size
(#batch, n_head, time2, d_k).
scores (torch.Tensor): Attention score, size
(#batch, n_head, time1, time2).
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch
=
value
.
size
(
0
)
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
# 1st chunk to ease the onnx export.]
# 2. pytorch training
if
mask
.
size
(
2
)
>
0
:
# time2 > 0
mask
=
mask
.
unsqueeze
(
1
).
eq
(
0
)
# (batch, 1, *, time2)
# For last chunk, time2 might be larger than scores.size(-1)
mask
=
mask
[:,
:,
:,
:
scores
.
size
(
-
1
)]
# (batch, 1, *, time2)
scores
=
scores
.
masked_fill
(
mask
,
-
float
(
'inf'
))
# (batch, head, time1, time2)
attn
=
torch
.
softmax
(
scores
,
dim
=-
1
).
masked_fill
(
mask
,
0.0
)
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
else
:
attn
=
torch
.
softmax
(
scores
,
dim
=-
1
)
# (batch, head, time1, time2)
p_attn
=
self
.
dropout
(
attn
)
x
=
torch
.
matmul
(
p_attn
,
value
)
# (batch, head, time1, d_k)
x
=
(
x
.
transpose
(
1
,
2
).
contiguous
().
view
(
n_batch
,
-
1
,
self
.
                 h * self.d_k))  # (batch, time1, d_model)
        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query: torch.Tensor,
                key: torch.Tensor, value: torch.Tensor,
                mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                pos_emb: torch.Tensor = torch.empty(0),
                cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        if self.adaptive_scale:
            query = self.ada_scale * query + self.ada_bias
            key = self.ada_scale * key + self.ada_bias
            value = self.ada_scale * value + self.ada_bias
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # NOTE(xcsong):
        #   when exporting an onnx model, for the 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will always be `True`
        #       and we will always do splitting and
        #       concatenation (this simplifies onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors (see code below).
        #   when exporting a jit model, for the 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branches.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(
                cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # Remove rel_shift since it is useless in speech recognition,
        #   and it requires special attention for streaming.
        if self.do_rel_shift:
            matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache
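The NOTE above relies on concatenation and splitting of zero-sized tensors being no-ops, which is what lets the ONNX export always take the cache branch even for the first chunk. A minimal, self-contained sketch of that behaviour (illustrative only, shapes are hypothetical, not part of the commit):

import torch

# Zero-length cache along the time axis acts as an identity under torch.cat,
# so the first decoding chunk can reuse the same split/concat code path.
cache = torch.zeros((1, 4, 0, 128))       # (batch, head, cache_t=0, d_k * 2)
k_new = torch.rand((1, 4, 16, 64))        # keys of the current chunk
v_new = torch.rand((1, 4, 16, 64))        # values of the current chunk

key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
k = torch.cat([key_cache, k_new], dim=2)  # still (1, 4, 16, 64)
v = torch.cat([value_cache, v_new], dim=2)

assert torch.equal(k, k_new) and torch.equal(v, v_new)
new_cache = torch.cat((k, v), dim=-1)     # (1, 4, 16, 128), fed to the next chunk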
examples/aishell/s0/wenet/squeezeformer/conv2d.py
0 → 100644
# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conv2d Module with Valid Padding"""

import torch.nn.functional as F
from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional


class Conv2dValid(_ConvNd):
    """
    Conv2d operator for VALID mode padding.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: _size_2_t,
            stride: _size_2_t = 1,
            padding: Union[str, _size_2_t] = 0,
            dilation: _size_2_t = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',  # TODO: refine this type
            device=None,
            dtype=None,
            valid_trigx: bool = False,
            valid_trigy: bool = False
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super(Conv2dValid, self).__init__(
            in_channels, out_channels, kernel_size_, stride_, padding_,
            dilation_, False, _pair(0), groups, bias, padding_mode,
            **factory_kwargs)
        self.valid_trigx = valid_trigx
        self.valid_trigy = valid_trigy

    def _conv_forward(self, input: Tensor, weight: Tensor,
                      bias: Optional[Tensor]):
        validx, validy = 0, 0
        if self.valid_trigx:
            validx = (input.size(-2) * (self.stride[-2] - 1) - 1
                      + self.kernel_size[-2]) // 2
        if self.valid_trigy:
            validy = (input.size(-1) * (self.stride[-1] - 1) - 1
                      + self.kernel_size[-1]) // 2
        return F.conv2d(input, weight, bias, self.stride,
                        (validx, validy), self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight, self.bias)
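A small usage sketch of Conv2dValid (illustrative only; the shapes and layer sizes are hypothetical). With both valid_trig flags left off, the extra padding stays (0, 0) and the layer behaves like an unpadded nn.Conv2d:

import torch

# Hypothetical shapes: batch=2, channels=1, 64 frames x 80 features.
x = torch.rand(2, 1, 64, 80)
conv = Conv2dValid(in_channels=1, out_channels=8, kernel_size=3, stride=2)
y = conv(x)
# Without valid_trigx/valid_trigy the padding is (0, 0):
#   H_out = (64 - 3) // 2 + 1 = 31, W_out = (80 - 3) // 2 + 1 = 39
print(y.shape)  # torch.Size([2, 8, 31, 39])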
examples/aishell/s0/wenet/squeezeformer/convolution.py
0 → 100644
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2022 Ximalaya Inc. (authors: Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""

from typing import Tuple

import torch
from torch import nn
from typeguard import check_argument_types


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model."""

    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True,
                 adaptive_scale: bool = False,
                 init_weights: bool = False):
        """Construct a ConvolutionModule object.
        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            causal (bool): Whether to use causal convolution or not.
        """
        assert check_argument_types()
        super().__init__()
        self.bias = bias
        self.channels = channels
        self.kernel_size = kernel_size
        self.adaptive_scale = adaptive_scale
        self.ada_scale = torch.nn.Parameter(
            torch.ones([1, 1, channels]), requires_grad=adaptive_scale)
        self.ada_bias = torch.nn.Parameter(
            torch.zeros([1, 1, channels]), requires_grad=adaptive_scale)

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution,
        # if self.lorder > 0: it's a causal convolution, the input will be
        #    padded with self.lorder frames on the left in forward.
        # else: it's a symmetrical convolution
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for non-causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation
        if init_weights:
            self.init_weights()

    def init_weights(self):
        pw_max = self.channels ** -0.5
        dw_max = self.kernel_size ** -0.5
        torch.nn.init.uniform_(self.pointwise_conv1.weight.data,
                               -pw_max, pw_max)
        if self.bias:
            torch.nn.init.uniform_(self.pointwise_conv1.bias.data,
                                   -pw_max, pw_max)
        torch.nn.init.uniform_(self.depthwise_conv.weight.data,
                               -dw_max, dw_max)
        if self.bias:
            torch.nn.init.uniform_(self.depthwise_conv.bias.data,
                                   -dw_max, dw_max)
        torch.nn.init.uniform_(self.pointwise_conv2.weight.data,
                               -pw_max, pw_max)
        if self.bias:
            torch.nn.init.uniform_(self.pointwise_conv2.bias.data,
                                   -pw_max, pw_max)

    def forward(
            self,
            x: torch.Tensor,
            mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
            cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) means fake cache.
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        if self.adaptive_scale:
            x = self.ada_scale * x + self.ada_bias
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
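A short sketch of how the causal branch threads its left-context cache between calls (illustrative only; batch size, chunk length, and channel count are hypothetical):

import torch

# A causal depthwise conv keeps lorder = kernel_size - 1 frames of left context.
conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)

chunk1 = torch.rand(1, 16, 256)                # (batch, time, channels)
empty_cache = torch.zeros((0, 0, 0))           # fake cache for the first chunk
out1, cache = conv(chunk1, cache=empty_cache)  # cache: (1, 256, 14)

chunk2 = torch.rand(1, 16, 256)
out2, cache = conv(chunk2, cache=cache)        # reuse the 14 cached frames
print(out1.shape, out2.shape, cache.shape)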
examples/aishell/s0/wenet/squeezeformer/encoder.py
0 → 100644
# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from Squeezeformer(https://github.com/kssteven418/Squeezeformer)
#               Squeezeformer(https://github.com/upskyy/Squeezeformer)
#               NeMo(https://github.com/NVIDIA/NeMo)

import torch
import torch.nn as nn
from typing import Tuple, Union, Optional, List
from wenet.squeezeformer.subsampling \
    import DepthwiseConv2dSubsampling4, TimeReductionLayer1D, \
    TimeReductionLayer2D, TimeReductionLayerStream
from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.attention import MultiHeadedAttention
from wenet.squeezeformer.attention import RelPositionMultiHeadedAttention
from wenet.squeezeformer.positionwise_feed_forward \
    import PositionwiseFeedForward
from wenet.squeezeformer.convolution import ConvolutionModule
from wenet.utils.mask import make_pad_mask, add_optional_chunk_mask
from wenet.utils.common import get_activation


class SqueezeformerEncoder(nn.Module):
    def __init__(
            self,
            input_size: int = 80,
            encoder_dim: int = 256,
            output_size: int = 256,
            attention_heads: int = 4,
            num_blocks: int = 12,
            reduce_idx: Optional[Union[int, List[int]]] = 5,
            recover_idx: Optional[Union[int, List[int]]] = 11,
            feed_forward_expansion_factor: int = 4,
            dw_stride: bool = False,
            input_dropout_rate: float = 0.1,
            pos_enc_layer_type: str = "rel_pos",
            time_reduction_layer_type: str = "conv1d",
            do_rel_shift: bool = True,
            feed_forward_dropout_rate: float = 0.1,
            attention_dropout_rate: float = 0.1,
            cnn_module_kernel: int = 31,
            cnn_norm_type: str = "batch_norm",
            dropout: float = 0.1,
            causal: bool = False,
            adaptive_scale: bool = True,
            activation_type: str = "swish",
            init_weights: bool = True,
            global_cmvn: torch.nn.Module = None,
            normalize_before: bool = False,
            use_dynamic_chunk: bool = False,
            concat_after: bool = False,
            static_chunk_size: int = 0,
            use_dynamic_left_chunk: bool = False
    ):
        """Construct SqueezeformerEncoder

        Args:
            input_size to use_dynamic_chunk: see Transformer BaseEncoder.
            encoder_dim (int): The hidden dimension of encoder layer.
            output_size (int): The output dimension of final projection layer.
            attention_heads (int): Num of attention heads in attention module.
            num_blocks (int): Num of encoder layers.
            reduce_idx (Optional[Union[int, List[int]]]):
                reduce layer index, from 40ms to 80ms per frame.
            recover_idx (Optional[Union[int, List[int]]]):
                recover layer index, from 80ms to 40ms per frame.
            feed_forward_expansion_factor (int): Enlarge coefficient of FFN.
            dw_stride (bool): Whether to do depthwise convolution
                on the subsampling module.
            input_dropout_rate (float): Dropout rate of input projection layer.
            pos_enc_layer_type (str): Self attention type.
            time_reduction_layer_type (str): Conv1d or Conv2d reduction layer.
            do_rel_shift (bool): Whether to do relative shift
                operation on rel-attention module.
            cnn_module_kernel (int): Kernel size of convolution module.
            cnn_norm_type (str): Norm type of convolution module.
            activation_type (str): Encoder activation function type.
            adaptive_scale (bool): Whether to use adaptive scale.
            init_weights (bool): Whether to initialize weights.
            causal (bool): Whether to use causal convolution or not.
        """
        super(SqueezeformerEncoder, self).__init__()
        self.global_cmvn = global_cmvn
        self.reduce_idx: Optional[Union[int, List[int]]] = [reduce_idx] \
            if type(reduce_idx) == int else reduce_idx
        self.recover_idx: Optional[Union[int, List[int]]] = [recover_idx] \
            if type(recover_idx) == int else recover_idx
        self.check_ascending_list()
        if reduce_idx is None:
            self.time_reduce = None
        else:
            if recover_idx is None:
                self.time_reduce = 'normal'  # no recovery at the end
            else:
                self.time_reduce = 'recover'  # recovery at the end
                assert len(self.reduce_idx) == len(self.recover_idx)
            self.reduce_stride = 2
        self._output_size = output_size
        self.normalize_before = normalize_before
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.pos_enc_layer_type = pos_enc_layer_type
        activation = get_activation(activation_type)

        # self-attention module definition
        if pos_enc_layer_type != "rel_pos":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, output_size,
                                           attention_dropout_rate)
        else:
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (attention_heads, encoder_dim,
                                           attention_dropout_rate,
                                           do_rel_shift, adaptive_scale,
                                           init_weights)

        # feed-forward module definition
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            encoder_dim,
            encoder_dim * feed_forward_expansion_factor,
            feed_forward_dropout_rate,
            activation,
            adaptive_scale,
            init_weights)

        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (encoder_dim, cnn_module_kernel, activation,
                                  cnn_norm_type, causal, True,
                                  adaptive_scale, init_weights)

        self.embed = DepthwiseConv2dSubsampling4(
            1, encoder_dim,
            RelPositionalEncoding(encoder_dim, dropout_rate=0.1),
            dw_stride, input_size, input_dropout_rate, init_weights)

        self.preln = nn.LayerNorm(encoder_dim)
        self.encoders = torch.nn.ModuleList([SqueezeformerEncoderLayer(
            encoder_dim,
            encoder_selfattn_layer(*encoder_selfattn_layer_args),
            positionwise_layer(*positionwise_layer_args),
            convolution_layer(*convolution_layer_args),
            positionwise_layer(*positionwise_layer_args),
            normalize_before,
            dropout,
            concat_after) for _ in range(num_blocks)
        ])
        if time_reduction_layer_type == 'conv1d':
            time_reduction_layer = TimeReductionLayer1D
            time_reduction_layer_args = {
                'channel': encoder_dim,
                'out_dim': encoder_dim,
            }
        elif time_reduction_layer_type == 'stream':
            time_reduction_layer = TimeReductionLayerStream
            time_reduction_layer_args = {
                'channel': encoder_dim,
                'out_dim': encoder_dim,
            }
        else:
            time_reduction_layer = TimeReductionLayer2D
            time_reduction_layer_args = {'encoder_dim': encoder_dim}

        self.time_reduction_layer = time_reduction_layer(
            **time_reduction_layer_args)
        self.time_recover_layer = nn.Linear(encoder_dim, encoder_dim)
        self.final_proj = None
        if output_size != encoder_dim:
            self.final_proj = nn.Linear(encoder_dim, output_size)

    def output_size(self) -> int:
        return self._output_size

    def forward(
            self,
            xs: torch.Tensor,
            xs_lens: torch.Tensor,
            decoding_chunk_size: int = 0,
            num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        xs_lens = mask_pad.squeeze(1).sum(1)
        xs = self.preln(xs)
        recover_activations: \
            List[Tuple[torch.Tensor, torch.Tensor,
                       torch.Tensor, torch.Tensor]] = []
        index = 0
        for i, layer in enumerate(self.encoders):
            if self.reduce_idx is not None:
                if self.time_reduce is not None and i in self.reduce_idx:
                    recover_activations.append(
                        (xs, chunk_masks, pos_emb, mask_pad))
                    xs, xs_lens, chunk_masks, mask_pad = \
                        self.time_reduction_layer(xs, xs_lens,
                                                  chunk_masks, mask_pad)
                    pos_emb = pos_emb[:, ::2, :]
                    index += 1

            if self.recover_idx is not None:
                if self.time_reduce == 'recover' and i in self.recover_idx:
                    index -= 1
                    (recover_tensor, recover_chunk_masks,
                     recover_pos_emb, recover_mask_pad) \
                        = recover_activations[index]
                    # recover output length for ctc decode
                    xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
                    xs = self.time_recover_layer(xs)
                    recoverd_t = recover_tensor.size(1)
                    xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
                    chunk_masks = recover_chunk_masks
                    pos_emb = recover_pos_emb
                    mask_pad = recover_mask_pad
                    xs = xs.masked_fill(~mask_pad[:, 0, :].unsqueeze(-1), 0.0)

            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)

        if self.final_proj is not None:
            xs = self.final_proj(xs)
        return xs, masks

    def check_ascending_list(self):
        if self.reduce_idx is not None:
            assert self.reduce_idx == sorted(self.reduce_idx), \
                "reduce_idx should be int or ascending list"
        if self.recover_idx is not None:
            assert self.recover_idx == sorted(self.recover_idx), \
                "recover_idx should be int or ascending list"

    def calculate_downsampling_factor(self, i: int) -> int:
        if self.reduce_idx is None:
            return 1
        else:
            reduce_exp, recover_exp = 0, 0
            for exp, rd_idx in enumerate(self.reduce_idx):
                if i >= rd_idx:
                    reduce_exp = exp + 1
            if self.recover_idx is not None:
                for exp, rc_idx in enumerate(self.recover_idx):
                    if i >= rc_idx:
                        recover_exp = exp + 1
            return int(2 ** (reduce_exp - recover_exp))

    def forward_chunk(
            self,
            xs: torch.Tensor,
            offset: int,
            required_cache_size: int,
            att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
            cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
            att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """ Forward just one chunk

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`

        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                dynamic shape (elayers, head, ?, d_k * 2)
                depending on required_cache_size.
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
        """
        assert xs.size(0) == 1
        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1,
                               xs.size(1),
                               device=xs.device,
                               dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
        # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
        elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
        chunk_size = xs.size(1)
        attention_key_size = cache_t1 + chunk_size
        pos_emb = self.embed.position_encoding(
            offset=offset - cache_t1, size=attention_key_size)
        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = attention_key_size
        else:
            next_cache_start = max(attention_key_size - required_cache_size, 0)

        r_att_cache = []
        r_cnn_cache = []

        mask_pad = torch.ones(1,
                              xs.size(1),
                              device=xs.device,
                              dtype=torch.bool)
        mask_pad = mask_pad.unsqueeze(1)
        max_att_len: int = 0
        recover_activations: \
            List[Tuple[torch.Tensor, torch.Tensor,
                       torch.Tensor, torch.Tensor]] = []
        index = 0
        xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
        xs = self.preln(xs)
        for i, layer in enumerate(self.encoders):
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            if self.reduce_idx is not None:
                if self.time_reduce is not None and i in self.reduce_idx:
                    recover_activations.append(
                        (xs, att_mask, pos_emb, mask_pad))
                    xs, xs_lens, att_mask, mask_pad = \
                        self.time_reduction_layer(xs, xs_lens,
                                                  att_mask, mask_pad)
                    pos_emb = pos_emb[:, ::2, :]
                    index += 1

            if self.recover_idx is not None:
                if self.time_reduce == 'recover' and i in self.recover_idx:
                    index -= 1
                    (recover_tensor, recover_att_mask,
                     recover_pos_emb, recover_mask_pad) \
                        = recover_activations[index]
                    # recover output length for ctc decode
                    xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
                    xs = self.time_recover_layer(xs)
                    recoverd_t = recover_tensor.size(1)
                    xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
                    att_mask = recover_att_mask
                    pos_emb = recover_pos_emb
                    mask_pad = recover_mask_pad
                    if att_mask.size(1) != 0:
                        xs = xs.masked_fill(
                            ~att_mask[:, 0, :].unsqueeze(-1), 0.0)

            factor = self.calculate_downsampling_factor(i)

            xs, _, new_att_cache, new_cnn_cache = layer(
                xs, att_mask, pos_emb,
                att_cache=att_cache[i:i + 1][:, :, ::factor, :]
                [:, :, :pos_emb.size(1) - xs.size(1), :]
                if elayers > 0 else att_cache[:, :, ::factor, :],
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache
            )
            # NOTE(xcsong): After layer.forward
            #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
            #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
            cached_att \
                = new_att_cache[:, :, next_cache_start // factor:, :]
            cached_cnn = new_cnn_cache.unsqueeze(0)
            cached_att = cached_att.unsqueeze(3). \
                repeat(1, 1, 1, factor, 1).flatten(2, 3)
            if i == 0:
                # record length for the first block as max length
                max_att_len = cached_att.size(2)
            r_att_cache.append(cached_att[:, :, :max_att_len, :])
            r_cnn_cache.append(cached_cnn)
        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
        #   ? may be larger than cache_t1, it depends on required_cache_size
        r_att_cache = torch.cat(r_att_cache, dim=0)
        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=0)

        if self.final_proj is not None:
            xs = self.final_proj(xs)
        return (xs, r_att_cache, r_cnn_cache)

    def forward_chunk_by_chunk(
            self,
            xs: torch.Tensor,
            decoding_chunk_size: int,
            num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Forward input chunk by chunk with chunk_size in a streaming
            fashion

        Here we should pay special attention to the computation cache in the
        streaming style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling

        However, we don't implement subsampling cache because:
            1. We can control the subsampling module to output the right
               result by overlapping the input instead of caching left
               context, even though it wastes some computation; subsampling
               only takes a very small fraction of the computation in the
               whole model.
            2. Typically, there are several convolution layers with
               subsampling in the subsampling module, so it is tricky and
               complicated to cache convolution layers with different
               subsampling rates.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling; we would need to rewrite it to make it
               work with cache, which is not preferred.
        Args:
            xs (torch.Tensor): (1, max_len, dim)
            decoding_chunk_size (int): decoding chunk size
        """
        assert decoding_chunk_size > 0
        # The model is trained by static or dynamic chunk
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.size(1)
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        outputs = []
        offset = 0
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        # Feed forward overlapping input step by step
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, att_cache, cnn_cache) = \
                self.forward_chunk(chunk_xs, offset,
                                   required_cache_size,
                                   att_cache, cnn_cache)
            outputs.append(y)
            offset += y.size(1)
        ys = torch.cat(outputs, 1)
        masks = torch.ones((1, 1, ys.size(1)),
                           device=ys.device,
                           dtype=torch.bool)
        return ys, masks
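A minimal sketch of driving the encoder offline and chunk by chunk (illustrative only; it assumes the wenet.squeezeformer modules from this commit are importable, uses hypothetical feature shapes and config values, and for sensible streaming results a model would normally be trained with chunk masking, causal convolutions, and the 'stream' time-reduction layer):

import torch

# Hypothetical configuration; a real run loads these values from the yaml
# config together with trained weights.
encoder = SqueezeformerEncoder(input_size=80, encoder_dim=256, output_size=256,
                               num_blocks=12, reduce_idx=5, recover_idx=11,
                               static_chunk_size=16)
encoder.eval()

feats = torch.rand(1, 640, 80)        # (batch=1, frames, mel-dim)
feats_lens = torch.tensor([640])

with torch.no_grad():
    # Offline (whole-utterance) forward.
    out, out_mask = encoder(feats, feats_lens)
    # Chunk-by-chunk forward: 16 output frames per chunk, unlimited left context.
    out_stream, _ = encoder.forward_chunk_by_chunk(
        feats, decoding_chunk_size=16, num_decoding_left_chunks=-1)

print(out.shape, out_stream.shape)    # roughly (1, 640 // 4, 256) each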
examples/aishell/s0/wenet/squeezeformer/encoder_layer.py
0 → 100644
# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SqueezeformerEncoderLayer definition."""

import torch
import torch.nn as nn
from typing import Optional, Tuple


class SqueezeformerEncoderLayer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward1 (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        feed_forward2 (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
    """

    def __init__(
            self,
            size: int,
            self_attn: torch.nn.Module,
            feed_forward1: Optional[nn.Module] = None,
            conv_module: Optional[nn.Module] = None,
            feed_forward2: Optional[nn.Module] = None,
            normalize_before: bool = False,
            dropout_rate: float = 0.1,
            concat_after: bool = False,
    ):
        super(SqueezeformerEncoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.layer_norm1 = nn.LayerNorm(size)
        self.ffn1 = feed_forward1
        self.layer_norm2 = nn.LayerNorm(size)
        self.conv_module = conv_module
        self.layer_norm3 = nn.LayerNorm(size)
        self.ffn2 = feed_forward2
        self.layer_norm4 = nn.LayerNorm(size)
        self.normalize_before = normalize_before
        self.dropout = nn.Dropout(dropout_rate)
        self.concat_after = concat_after
        if concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        else:
            self.concat_linear = nn.Identity()

    def forward(
            self,
            x: torch.Tensor,
            mask: torch.Tensor,
            pos_emb: torch.Tensor,
            mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
            att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
            cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # self attention module
        residual = x
        if self.normalize_before:
            x = self.layer_norm1(x)
        x_att, new_att_cache = self.self_attn(
            x, x, x, mask, pos_emb, att_cache)
        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.layer_norm1(x)

        # ffn module
        residual = x
        if self.normalize_before:
            x = self.layer_norm2(x)
        x = self.ffn1(x)
        x = residual + self.dropout(x)
        if not self.normalize_before:
            x = self.layer_norm2(x)

        # conv module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        residual = x
        if self.normalize_before:
            x = self.layer_norm3(x)
        x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
        x = residual + self.dropout(x)
        if not self.normalize_before:
            x = self.layer_norm3(x)

        # ffn module
        residual = x
        if self.normalize_before:
            x = self.layer_norm4(x)
        x = self.ffn2(x)
        # we do not use dropout here since it is inside feed forward function
        x = residual + self.dropout(x)
        if not self.normalize_before:
            x = self.layer_norm4(x)

        return x, mask, new_att_cache, new_cnn_cache
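The layer applies its sub-blocks in the Squeezeformer order MHSA -> FFN -> Conv -> FFN. A small sketch of wiring one layer from the companion modules in this commit (illustrative only; it assumes those modules are importable and uses hypothetical sizes):

import torch
from wenet.squeezeformer.attention import RelPositionMultiHeadedAttention
from wenet.squeezeformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.squeezeformer.convolution import ConvolutionModule
from wenet.transformer.embedding import RelPositionalEncoding

size, heads = 256, 4
layer = SqueezeformerEncoderLayer(
    size,
    RelPositionMultiHeadedAttention(heads, size, 0.1),
    PositionwiseFeedForward(size, 4 * size, 0.1),
    ConvolutionModule(size),
    PositionwiseFeedForward(size, 4 * size, 0.1))

x = torch.rand(2, 50, size)                          # (batch, time, size)
mask = torch.ones(2, 50, 50, dtype=torch.bool)       # no padding in this toy batch
pos_emb = RelPositionalEncoding(size, 0.1)(x)[1]     # (1, 50, size)
y, _, _, _ = layer(x, mask, pos_emb)
print(y.shape)                                       # torch.Size([2, 50, 256])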
examples/aishell/s0/wenet/squeezeformer/positionwise_feed_forward.py
0 → 100644
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#               2022 Ximalaya Inc (Yuguang Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""

import torch


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    FeedForward is applied on each position of the sequence.
    The output dim is the same as the input dim.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function
    """

    def __init__(self,
                 idim: int,
                 hidden_units: int,
                 dropout_rate: float,
                 activation: torch.nn.Module = torch.nn.ReLU(),
                 adaptive_scale: bool = False,
                 init_weights: bool = False):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.idim = idim
        self.hidden_units = hidden_units
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.ada_scale = None
        self.ada_bias = None
        self.adaptive_scale = adaptive_scale
        self.ada_scale = torch.nn.Parameter(
            torch.ones([1, 1, idim]), requires_grad=adaptive_scale)
        self.ada_bias = torch.nn.Parameter(
            torch.zeros([1, 1, idim]), requires_grad=adaptive_scale)
        if init_weights:
            self.init_weights()

    def init_weights(self):
        ffn1_max = self.idim ** -0.5
        ffn2_max = self.hidden_units ** -0.5
        torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max)
        torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max)
        torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max)
        torch.nn.init.uniform_(self.w_2.bias.data, -ffn2_max, ffn2_max)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        if self.adaptive_scale:
            xs = self.ada_scale * xs + self.ada_bias
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
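A brief usage sketch of the feed-forward module with adaptive scaling enabled (illustrative only; the dimensions are hypothetical):

import torch

ffn = PositionwiseFeedForward(idim=256, hidden_units=1024, dropout_rate=0.1,
                              adaptive_scale=True, init_weights=True)
xs = torch.rand(4, 50, 256)         # (B, L, D)
ys = ffn(xs)
print(ys.shape)                     # torch.Size([4, 50, 256]), same dim in and out

# ada_scale / ada_bias are learnable (1, 1, D) tensors that rescale the input
# before the two linear layers; with adaptive_scale=False they stay frozen at 1 and 0.
print(ffn.ada_scale.requires_grad)  # True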