ModelZoo / Conformer_pytorch · Commit a7785cc6
Authored Mar 26, 2024 by Sugon_ldc
Parent: 9a2a05ca

    delete soft link

Changes: 162. Showing 20 changed files with 3250 additions and 0 deletions (+3250, -0).
examples/aishell/s0/wenet/transformer/__pycache__/encoder_layer.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/transformer/__pycache__/label_smoothing_loss.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/transformer/__pycache__/positionwise_feed_forward.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/transformer/__pycache__/subsampling.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/transformer/__pycache__/swish.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/transformer/asr_model.py (+904, -0)
examples/aishell/s0/wenet/transformer/attention.py (+312, -0)
examples/aishell/s0/wenet/transformer/cmvn.py (+46, -0)
examples/aishell/s0/wenet/transformer/convolution.py (+146, -0)
examples/aishell/s0/wenet/transformer/ctc.py (+84, -0)
examples/aishell/s0/wenet/transformer/decoder.py (+299, -0)
examples/aishell/s0/wenet/transformer/decoder_layer.py (+151, -0)
examples/aishell/s0/wenet/transformer/embedding.py (+162, -0)
examples/aishell/s0/wenet/transformer/encoder.py (+462, -0)
examples/aishell/s0/wenet/transformer/encoder_layer.py (+269, -0)
examples/aishell/s0/wenet/transformer/label_smoothing_loss.py (+96, -0)
examples/aishell/s0/wenet/transformer/positionwise_feed_forward.py (+53, -0)
examples/aishell/s0/wenet/transformer/subsampling.py (+240, -0)
examples/aishell/s0/wenet/transformer/swish.py (+26, -0)
examples/aishell/s0/wenet/utils/__pycache__/checkpoint.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/transformer/__pycache__/encoder_layer.cpython-38.pyc (new file, mode 100644): binary file added
examples/aishell/s0/wenet/transformer/__pycache__/label_smoothing_loss.cpython-38.pyc (new file, mode 100644): binary file added
examples/aishell/s0/wenet/transformer/__pycache__/positionwise_feed_forward.cpython-38.pyc (new file, mode 100644): binary file added
examples/aishell/s0/wenet/transformer/__pycache__/subsampling.cpython-38.pyc (new file, mode 100644): binary file added
examples/aishell/s0/wenet/transformer/__pycache__/swish.cpython-38.pyc (new file, mode 100644): binary file added
examples/aishell/s0/wenet/transformer/asr_model.py (new file, mode 100644):
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

try:
    import k2
    from icefall.utils import get_texts
    from icefall.decode import get_lattice, Nbest, one_best_decoding
except ImportError:
    print('Failed to import k2 and icefall. '
          'Notice that they are necessary for hlg_onebest and hlg_rescore')

from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import TransformerEncoder
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
from wenet.utils.common import (IGNORE_ID, add_sos_eos, log_add,
                                remove_duplicates_and_blank, th_accuracy,
                                reverse_pad_list)
from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
                              mask_finished_scores, subsequent_mask)
class ASRModel(torch.nn.Module):
    """CTC-attention hybrid Encoder-Decoder model"""
    def __init__(
        self,
        vocab_size: int,
        encoder: TransformerEncoder,
        decoder: TransformerDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.reverse_weight = reverse_weight

        self.encoder = encoder
        self.decoder = decoder
        self.ctc = ctc
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
        # 1. Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2a. Attention-decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
                                                    text, text_lengths)
        else:
            loss_att = None

        # 2b. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
                                text_lengths)
        else:
            loss_ctc = None

        if loss_ctc is None:
            loss = loss_att
        elif loss_att is None:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (
                1 - self.ctc_weight) * loss_att
        return {"loss": loss, "loss_att": loss_att, "loss_ctc": loss_ctc}
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, float]:
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos,
                                            self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # reverse the seq, used for right to left decoder
        r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
        r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos,
                                                self.ignore_id)
        # 1. Forward decoder
        decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask,
                                                     ys_in_pad, ys_in_lens,
                                                     r_ys_in_pad,
                                                     self.reverse_weight)
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        r_loss_att = torch.tensor(0.0)
        if self.reverse_weight > 0.0:
            r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
        loss_att = loss_att * (1 - self.reverse_weight
                               ) + r_loss_att * self.reverse_weight
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        return loss_att, acc_att
    def _forward_encoder(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Let's assume B = batch_size
        # 1. Encoder
        if simulate_streaming and decoding_chunk_size > 0:
            encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(
                speech,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        else:
            encoder_out, encoder_mask = self.encoder(
                speech,
                speech_lengths,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        return encoder_out, encoder_mask
    def recognize(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_size: int = 10,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> torch.Tensor:
        """ Apply beam search on attention decoder

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion

        Returns:
            torch.Tensor: decoding result, (batch, max_result_len)
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.device
        batch_size = speech.shape[0]

        # Let's assume B = batch_size and N = beam_size
        # 1. Encoder
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        encoder_dim = encoder_out.size(2)
        running_size = batch_size * beam_size
        encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
            running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
        encoder_mask = encoder_mask.unsqueeze(1).repeat(
            1, beam_size, 1, 1).view(running_size, 1,
                                     maxlen)  # (B*N, 1, max_len)

        hyps = torch.ones([running_size, 1], dtype=torch.long,
                          device=device).fill_(self.sos)  # (B*N, 1)
        scores = torch.tensor([0.0] + [-float('inf')] * (beam_size - 1),
                              dtype=torch.float)
        scores = scores.to(device).repeat([batch_size]).unsqueeze(1).to(
            device)  # (B*N, 1)
        end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
        cache: Optional[List[torch.Tensor]] = None
        # 2. Decoder forward step by step
        for i in range(1, maxlen + 1):
            # Stop if all batch and all beam produce eos
            if end_flag.sum() == running_size:
                break
            # 2.1 Forward decoder step
            hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
                running_size, 1, 1).to(device)  # (B*N, i, i)
            # logp: (B*N, vocab)
            logp, cache = self.decoder.forward_one_step(
                encoder_out, encoder_mask, hyps, hyps_mask, cache)
            # 2.2 First beam prune: select topk best prob at current time
            top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
            top_k_logp = mask_finished_scores(top_k_logp, end_flag)
            top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
            # 2.3 Second beam prune: select topk score with history
            scores = scores + top_k_logp  # (B*N, N), broadcast add
            scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
            scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
            # Update cache to be consistent with new topk scores / hyps
            cache_index = (offset_k_index // beam_size).view(-1)  # (B*N)
            base_cache_index = (torch.arange(batch_size, device=device).view(
                -1, 1).repeat([1, beam_size]) * beam_size).view(-1)  # (B*N)
            cache_index = base_cache_index + cache_index
            cache = [
                torch.index_select(c, dim=0, index=cache_index) for c in cache
            ]
            scores = scores.view(-1, 1)  # (B*N, 1)
            # 2.4. Compute base index in top_k_index,
            # regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
            # then find offset_k_index in top_k_index
            base_k_index = torch.arange(batch_size, device=device).view(
                -1, 1).repeat([1, beam_size])  # (B, N)
            base_k_index = base_k_index * beam_size * beam_size
            best_k_index = base_k_index.view(-1) + offset_k_index.view(
                -1)  # (B*N)

            # 2.5 Update best hyps
            best_k_pred = torch.index_select(top_k_index.view(-1),
                                             dim=-1,
                                             index=best_k_index)  # (B*N)
            best_hyps_index = best_k_index // beam_size
            last_best_k_hyps = torch.index_select(
                hyps, dim=0, index=best_hyps_index)  # (B*N, i)
            hyps = torch.cat((last_best_k_hyps, best_k_pred.view(-1, 1)),
                             dim=1)  # (B*N, i+1)

            # 2.6 Update end flag
            end_flag = torch.eq(hyps[:, -1], self.eos).view(-1, 1)

        # 3. Select best of best
        scores = scores.view(batch_size, beam_size)
        # TODO: length normalization
        best_scores, best_index = scores.max(dim=-1)
        best_hyps_index = best_index + torch.arange(
            batch_size, dtype=torch.long, device=device) * beam_size
        best_hyps = torch.index_select(hyps, dim=0, index=best_hyps_index)
        best_hyps = best_hyps[:, 1:]
        return best_hyps, best_scores
    def ctc_greedy_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> List[List[int]]:
        """ Apply CTC greedy search

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            List[List[int]]: best path result
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        batch_size = speech.shape[0]
        # Let's assume B = batch_size
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (B, maxlen, vocab_size)
        topk_prob, topk_index = ctc_probs.topk(1, dim=2)  # (B, maxlen, 1)
        topk_index = topk_index.view(batch_size, maxlen)  # (B, maxlen)
        mask = make_pad_mask(encoder_out_lens, maxlen)  # (B, maxlen)
        topk_index = topk_index.masked_fill_(mask, self.eos)  # (B, maxlen)
        hyps = [hyp.tolist() for hyp in topk_index]
        scores = topk_prob.max(1)
        hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
        return hyps, scores
    def _ctc_prefix_beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_size: int,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> Tuple[List[List[int]], torch.Tensor]:
        """ CTC prefix beam search inner implementation

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            List[List[int]]: nbest results
            torch.Tensor: encoder output, (1, max_len, encoder_dim),
                it will be used for rescoring in attention rescoring mode
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        batch_size = speech.shape[0]
        # For CTC prefix beam search, we only support batch_size=1
        assert batch_size == 1
        # Let's assume B = batch_size and N = beam_size
        # 1. Encoder forward and get CTC score
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)
        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
        cur_hyps = [(tuple(), (0.0, -float('inf')))]
        # 2. CTC beam search step by step
        for t in range(0, maxlen):
            logp = ctc_probs[t]  # (vocab_size,)
            # key: prefix, value (pb, pnb), default value(-inf, -inf)
            next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
            # 2.1 First beam prune: select topk best
            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
            for s in top_k_index:
                s = s.item()
                ps = logp[s].item()
                for prefix, (pb, pnb) in cur_hyps:
                    last = prefix[-1] if len(prefix) > 0 else None
                    if s == 0:  # blank
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pb = log_add([n_pb, pb + ps, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                    elif s == last:
                        # Update *ss -> *s;
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pnb = log_add([n_pnb, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                        # Update *s-s -> *ss, - is for blank
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)
                    else:
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)

            # 2.2 Second beam prune
            next_hyps = sorted(next_hyps.items(),
                               key=lambda x: log_add(list(x[1])),
                               reverse=True)
            cur_hyps = next_hyps[:beam_size]
        hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
        return hyps, encoder_out
    def ctc_prefix_beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_size: int,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
    ) -> List[int]:
        """ Apply CTC prefix beam search

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
        Returns:
            List[int]: CTC prefix beam search nbest results
        """
        hyps, _ = self._ctc_prefix_beam_search(speech, speech_lengths,
                                               beam_size, decoding_chunk_size,
                                               num_decoding_left_chunks,
                                               simulate_streaming)
        return hyps[0]
    def attention_rescoring(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        beam_size: int,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        ctc_weight: float = 0.0,
        simulate_streaming: bool = False,
        reverse_weight: float = 0.0,
    ) -> List[int]:
        """ Apply attention rescoring decoding. CTC prefix beam search
            is applied first to get nbest, then we rescore the nbest on
            attention decoder with corresponding encoder out

        Args:
            speech (torch.Tensor): (batch, max_len, feat_dim)
            speech_lengths (torch.Tensor): (batch, )
            beam_size (int): beam size for beam search
            decoding_chunk_size (int): decoding chunk for dynamic chunk
                trained model.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
                0: used for training, it's prohibited here
            simulate_streaming (bool): whether do encoder forward in a
                streaming fashion
            reverse_weight (float): right to left decoder weight
            ctc_weight (float): ctc score weight

        Returns:
            List[int]: Attention rescoring result
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        if reverse_weight > 0.0:
            # decoder should be a bitransformer decoder if reverse_weight > 0.0
            assert hasattr(self.decoder, 'right_decoder')
        device = speech.device
        batch_size = speech.shape[0]
        # For attention rescoring we only support batch_size=1
        assert batch_size == 1
        # encoder_out: (1, maxlen, encoder_dim), len(hyps) = beam_size
        hyps, encoder_out = self._ctc_prefix_beam_search(
            speech, speech_lengths, beam_size, decoding_chunk_size,
            num_decoding_left_chunks, simulate_streaming)

        assert len(hyps) == beam_size
        hyps_pad = pad_sequence([
            torch.tensor(hyp[0], device=device, dtype=torch.long)
            for hyp in hyps
        ], True, self.ignore_id)  # (beam_size, max_hyps_len)
        ori_hyps_pad = hyps_pad
        hyps_lens = torch.tensor([len(hyp[0]) for hyp in hyps],
                                 device=device,
                                 dtype=torch.long)  # (beam_size,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at beginning
        encoder_out = encoder_out.repeat(beam_size, 1, 1)
        encoder_mask = torch.ones(beam_size,
                                  1,
                                  encoder_out.size(1),
                                  dtype=torch.bool,
                                  device=device)
        # used for right to left decoder
        r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
        r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
                                    self.ignore_id)
        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
            reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        decoder_out = decoder_out.cpu().numpy()
        # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
        # conventional transformer decoder.
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        r_decoder_out = r_decoder_out.cpu().numpy()
        # Only use decoder score for rescoring
        best_score = -float('inf')
        best_index = 0
        for i, hyp in enumerate(hyps):
            score = 0.0
            for j, w in enumerate(hyp[0]):
                score += decoder_out[i][j][w]
            score += decoder_out[i][len(hyp[0])][self.eos]
            # add right to left decoder score
            if reverse_weight > 0:
                r_score = 0.0
                for j, w in enumerate(hyp[0]):
                    r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
                r_score += r_decoder_out[i][len(hyp[0])][self.eos]
                score = score * (1 - reverse_weight) + r_score * reverse_weight
            # add ctc score
            score += hyp[1] * ctc_weight
            if score > best_score:
                best_score = score
                best_index = i
        return hyps[best_index][0], best_score
    def load_hlg_resource_if_necessary(self, hlg, word):
        if not hasattr(self, 'hlg'):
            device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
            self.hlg = k2.Fsa.from_dict(torch.load(hlg, map_location=device))
        if not hasattr(self.hlg, "lm_scores"):
            self.hlg.lm_scores = self.hlg.scores.clone()
        if not hasattr(self, 'word_table'):
            self.word_table = {}
            with open(word, 'r') as fin:
                for line in fin:
                    arr = line.strip().split()
                    assert len(arr) == 2
                    self.word_table[int(arr[1])] = arr[0]
    @torch.no_grad()
    def hlg_onebest(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        hlg: str = '',
        word: str = '',
        symbol_table: Dict[str, int] = None,
    ) -> List[int]:
        self.load_hlg_resource_if_necessary(hlg, word)
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (1, maxlen, vocab_size)
        supervision_segments = torch.stack(
            (torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)),
             encoder_mask.squeeze(dim=1).sum(dim=1).cpu()),
            1,
        ).to(torch.int32)
        lattice = get_lattice(nnet_output=ctc_probs,
                              decoding_graph=self.hlg,
                              supervision_segments=supervision_segments,
                              search_beam=20,
                              output_beam=7,
                              min_active_states=30,
                              max_active_states=10000,
                              subsampling_factor=4)
        best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
        hyps = get_texts(best_path)
        hyps = [[symbol_table[k] for j in i for k in self.word_table[j]]
                for i in hyps]
        return hyps
    @torch.no_grad()
    def hlg_rescore(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        lm_scale: float = 0,
        decoder_scale: float = 0,
        r_decoder_scale: float = 0,
        hlg: str = '',
        word: str = '',
        symbol_table: Dict[str, int] = None,
    ) -> List[int]:
        self.load_hlg_resource_if_necessary(hlg, word)
        device = speech.device
        encoder_out, encoder_mask = self._forward_encoder(
            speech, speech_lengths, decoding_chunk_size,
            num_decoding_left_chunks,
            simulate_streaming)  # (B, maxlen, encoder_dim)
        ctc_probs = self.ctc.log_softmax(
            encoder_out)  # (1, maxlen, vocab_size)
        supervision_segments = torch.stack(
            (torch.arange(len(encoder_mask)), torch.zeros(len(encoder_mask)),
             encoder_mask.squeeze(dim=1).sum(dim=1).cpu()),
            1,
        ).to(torch.int32)
        lattice = get_lattice(nnet_output=ctc_probs,
                              decoding_graph=self.hlg,
                              supervision_segments=supervision_segments,
                              search_beam=20,
                              output_beam=7,
                              min_active_states=30,
                              max_active_states=10000,
                              subsampling_factor=4)
        nbest = Nbest.from_lattice(
            lattice=lattice,
            num_paths=100,
            use_double_scores=True,
            nbest_scale=0.5,
        )
        nbest = nbest.intersect(lattice)
        assert hasattr(nbest.fsa, "lm_scores")
        assert hasattr(nbest.fsa, "tokens")
        assert isinstance(nbest.fsa.tokens, torch.Tensor)
        tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
        tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
        tokens = tokens.remove_values_leq(0)
        hyps = tokens.tolist()

        # cal attention_score
        hyps_pad = pad_sequence([
            torch.tensor(hyp, device=device, dtype=torch.long) for hyp in hyps
        ], True, self.ignore_id)  # (beam_size, max_hyps_len)
        ori_hyps_pad = hyps_pad
        hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
                                 device=device,
                                 dtype=torch.long)  # (beam_size,)
        hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
        hyps_lens = hyps_lens + 1  # Add <sos> at beginning
        encoder_out_repeat = []
        tot_scores = nbest.tot_scores()
        repeats = [tot_scores[i].shape[0] for i in range(tot_scores.dim0)]
        for i in range(len(encoder_out)):
            encoder_out_repeat.append(encoder_out[i:i + 1].repeat(
                repeats[i], 1, 1))
        encoder_out = torch.concat(encoder_out_repeat, dim=0)
        encoder_mask = torch.ones(encoder_out.size(0),
                                  1,
                                  encoder_out.size(1),
                                  dtype=torch.bool,
                                  device=device)
        # used for right to left decoder
        r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
        r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
                                    self.ignore_id)
        reverse_weight = 0.5
        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
            reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
        # conventional transformer decoder.
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)

        decoder_scores = torch.tensor([
            sum([decoder_out[i, j, hyps[i][j]] for j in range(len(hyps[i]))])
            for i in range(len(hyps))
        ], device=device)
        r_decoder_scores = []
        for i in range(len(hyps)):
            score = 0
            for j in range(len(hyps[i])):
                score += r_decoder_out[i, len(hyps[i]) - j - 1, hyps[i][j]]
            score += r_decoder_out[i, len(hyps[i]), self.eos]
            r_decoder_scores.append(score)
        r_decoder_scores = torch.tensor(r_decoder_scores, device=device)

        am_scores = nbest.compute_am_scores()
        ngram_lm_scores = nbest.compute_lm_scores()
        tot_scores = am_scores.values + lm_scale * ngram_lm_scores.values + \
            decoder_scale * decoder_scores + r_decoder_scale * r_decoder_scores
        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
        max_indexes = ragged_tot_scores.argmax()
        best_path = k2.index_fsa(nbest.fsa, max_indexes)
        hyps = get_texts(best_path)
        hyps = [[symbol_table[k] for j in i for k in self.word_table[j]]
                for i in hyps]
        return hyps
    @torch.jit.export
    def subsampling_rate(self) -> int:
        """ Export interface for c++ call, return subsampling_rate of the
            model
        """
        return self.encoder.embed.subsampling_rate
    @torch.jit.export
    def right_context(self) -> int:
        """ Export interface for c++ call, return right_context of the model
        """
        return self.encoder.embed.right_context
    @torch.jit.export
    def sos_symbol(self) -> int:
        """ Export interface for c++ call, return sos symbol id of the model
        """
        return self.sos

    @torch.jit.export
    def eos_symbol(self) -> int:
        """ Export interface for c++ call, return eos symbol id of the model
        """
        return self.eos
    @torch.jit.export
    def forward_encoder_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """ Export interface for c++ call, give input chunk xs, and return
            output from time 0 to current chunk.

        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate +
                subsample.right_context + 1`
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`

        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                dynamic shape (elayers, head, ?, d_k * 2)
                depending on required_cache_size.
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
        """
        return self.encoder.forward_chunk(xs, offset, required_cache_size,
                                          att_cache, cnn_cache)
    @torch.jit.export
    def ctc_activation(self, xs: torch.Tensor) -> torch.Tensor:
        """ Export interface for c++ call, apply linear transform and log
            softmax before ctc

        Args:
            xs (torch.Tensor): encoder output
        Returns:
            torch.Tensor: activation before ctc
        """
        return self.ctc.log_softmax(xs)
    @torch.jit.export
    def is_bidirectional_decoder(self) -> bool:
        """
        Returns:
            bool: True if the decoder has a right-to-left branch,
                False otherwise
        """
        if hasattr(self.decoder, 'right_decoder'):
            return True
        else:
            return False
    @torch.jit.export
    def forward_attention_decoder(
        self,
        hyps: torch.Tensor,
        hyps_lens: torch.Tensor,
        encoder_out: torch.Tensor,
        reverse_weight: float = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Export interface for c++ call, forward decoder with multiple
            hypotheses from ctc prefix beam search and one encoder output

        Args:
            hyps (torch.Tensor): hyps from ctc prefix beam search, already
                padded with sos at the beginning
            hyps_lens (torch.Tensor): length of each hyp in hyps
            encoder_out (torch.Tensor): corresponding encoder output
            reverse_weight: used for verifying whether the right-to-left
                decoder is used, > 0 will use it.

        Returns:
            torch.Tensor: decoder output
        """
        assert encoder_out.size(0) == 1
        num_hyps = hyps.size(0)
        assert hyps_lens.size(0) == num_hyps
        encoder_out = encoder_out.repeat(num_hyps, 1, 1)
        encoder_mask = torch.ones(num_hyps,
                                  1,
                                  encoder_out.size(1),
                                  dtype=torch.bool,
                                  device=encoder_out.device)

        # input for right to left decoder
        # this hyps_lens has counted the <sos> token, we need minus it.
        r_hyps_lens = hyps_lens - 1
        # this hyps has included the <sos> token, so it should be
        # converted from the original hyps.
        r_hyps = hyps[:, 1:]
        # >>> r_hyps
        # >>> tensor([[ 1,  2,  3],
        # >>>         [ 9,  8,  4],
        # >>>         [ 2, -1, -1]])
        # >>> r_hyps_lens
        # >>> tensor([3, 3, 1])

        # NOTE(Mddct): `pad_sequence` is not supported by ONNX, it is used
        #   in `reverse_pad_list` thus we have to refine the below code.
        #   Issue: https://github.com/wenet-e2e/wenet/issues/1113
        # Equal to:
        #   >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
        #   >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
        max_len = torch.max(r_hyps_lens)
        index_range = torch.arange(0, max_len, 1).to(encoder_out.device)
        seq_len_expand = r_hyps_lens.unsqueeze(1)
        seq_mask = seq_len_expand > index_range  # (beam, max_len)
        # >>> seq_mask
        # >>> tensor([[ True,  True,  True],
        # >>>         [ True,  True,  True],
        # >>>         [ True, False, False]])
        index = (seq_len_expand - 1) - index_range  # (beam, max_len)
        # >>> index
        # >>> tensor([[ 2,  1,  0],
        # >>>         [ 2,  1,  0],
        # >>>         [ 0, -1, -2]])
        index = index * seq_mask
        # >>> index
        # >>> tensor([[2, 1, 0],
        # >>>         [2, 1, 0],
        # >>>         [0, 0, 0]])
        r_hyps = torch.gather(r_hyps, 1, index)
        # >>> r_hyps
        # >>> tensor([[3, 2, 1],
        # >>>         [4, 8, 9],
        # >>>         [2, 2, 2]])
        r_hyps = torch.where(seq_mask, r_hyps, self.eos)
        # >>> r_hyps
        # >>> tensor([[3, 2, 1],
        # >>>         [4, 8, 9],
        # >>>         [2, eos, eos]])
        r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1)
        # >>> r_hyps
        # >>> tensor([[sos, 3, 2, 1],
        # >>>         [sos, 4, 8, 9],
        # >>>         [sos, 2, eos, eos]])

        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps, hyps_lens, r_hyps,
            reverse_weight)  # (num_hyps, max_hyps_len, vocab_size)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)

        # right to left decoder may be not used during decoding process,
        # which depends on reverse_weight param.
        # r_decoder_out will be 0.0, if reverse_weight is 0.0
        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
        return decoder_out, r_decoder_out
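
For orientation, a minimal sketch of how the decoding entry points above are typically called, not part of the committed file. The names `model`, `feats`, and `feats_lengths` are illustrative assumptions (in the recipes the model comes from a checkpoint and the features from fbank extraction):

import torch

def decode_utterance(model, feats: torch.Tensor, feats_lengths: torch.Tensor):
    # `model` is an ASRModel built elsewhere; `feats` is (1, time, feat_dim).
    model.eval()
    with torch.no_grad():
        # Attention-decoder beam search: (batch, max_result_len) ids + scores.
        att_hyps, att_scores = model.recognize(feats, feats_lengths,
                                               beam_size=10)
        # CTC prefix beam search (batch must be 1): best (token_ids, score).
        ctc_hyp = model.ctc_prefix_beam_search(feats, feats_lengths,
                                               beam_size=10)
        # CTC nbest rescored by the attention decoder (batch must be 1).
        resc_hyp, resc_score = model.attention_rescoring(
            feats, feats_lengths, beam_size=10, ctc_weight=0.5)
    return att_hyps, ctc_hyp, resc_hyp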
examples/aishell/s0/wenet/transformer/attention.py (new file, mode 100644):
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""
import math
from typing import Tuple

import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """
    def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #      1st chunk to ease the onnx export.]
        #   2. pytorch training
        if mask.size(2) > 0:  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0)  # (batch, head, time1, time2)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
        #   1. onnx(16/-1, -1/-1, 16/0)
        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (x.transpose(1, 2).contiguous().view(
            n_batch, -1, self.h * self.d_k))  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1. When applying cross attention between decoder and encoder,
                   the batch padding mask for input is in (#batch, 1, T) shape.
                2. When applying self attention of encoder,
                   the mask is in (#batch, T, T) shape.
                3. When applying self attention of decoder,
                   the mask is in (#batch, L, L) shape.
                4. If the different position in decoder see different block
                   of the encoder, such as Mocha, the passed in mask could be
                   in (#batch, L, T) shape. But there is no such case in
                   current WeNet.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #     cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #     or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #     In all modes, `if cache.size(0) > 0` will always be `True`
        #     and we will always do splitting and
        #     concatenation (this will simplify onnx export). Note that
        #     it's OK to concat & split zero-shaped tensors (see code below).
        #   when export jit model, for 1st chunk, we always feed
        #     cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """
    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x, zero_triu: bool = False):
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, size).
            zero_triu (bool): If true, return the lower triangular part of
                the matrix.

        Returns:
            torch.Tensor: Output tensor.
        """
        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
                               device=x.device,
                               dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(x.size()[0],
                                 x.size()[1],
                                 x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(2), x.size(3)))
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #     cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #     or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #     In all modes, `if cache.size(0) > 0` will always be `True`
        #     and we will always do splitting and
        #     concatenation (this will simplify onnx export). Note that
        #     it's OK to concat & split zero-shaped tensors (see code below).
        #   when export jit model, for 1st chunk, we always feed
        #     cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        # >>> a = torch.ones((1, 2, 0, 4))
        # >>> b = torch.ones((1, 2, 3, 4))
        # >>> c = torch.cat((a, b), dim=2)
        # >>> torch.equal(b, c)        # True
        # >>> d = torch.split(a, 2, dim=-1)
        # >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch, head, time1, d_k)

        # (batch, head, time1, d_k)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        # (batch, head, time1, d_k)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # (batch, head, time1, time2)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix b and matrix d
        # (batch, head, time1, time2)
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        # Remove rel_shift since it is useless in speech recognition,
        #   and it requires special attention for streaming.
        # matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k)  # (batch, head, time1, time2)

        return self.forward_attention(v, scores, mask), new_cache
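
A small self-contained sketch of how the KV cache above threads across chunks, assuming only the `MultiHeadedAttention` class from this file (shapes follow the docstrings; this is not part of the committed file):

import torch
from wenet.transformer.attention import MultiHeadedAttention

# Run self-attention over two consecutive 16-frame chunks, feeding the
# (1, head, cache_t, d_k * 2) cache from the first call into the second,
# so the second chunk attends over all 32 frames seen so far.
mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
chunk1 = torch.randn(1, 16, 256)
chunk2 = torch.randn(1, 16, 256)

out1, cache = mha(chunk1, chunk1, chunk1)               # cache: (1, 4, 16, 128)
out2, cache = mha(chunk2, chunk2, chunk2, cache=cache)  # cache: (1, 4, 32, 128)
print(out2.shape, cache.shape)  # torch.Size([1, 16, 256]) torch.Size([1, 4, 32, 128])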
examples/aishell/s0/wenet/transformer/cmvn.py (new file, mode 100644):
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


class GlobalCMVN(torch.nn.Module):
    def __init__(self,
                 mean: torch.Tensor,
                 istd: torch.Tensor,
                 norm_var: bool = True):
        """
        Args:
            mean (torch.Tensor): mean stats
            istd (torch.Tensor): inverse std, i.e. 1.0 / std
        """
        super().__init__()
        assert mean.shape == istd.shape
        self.norm_var = norm_var
        # The buffer can be accessed from this module using self.mean
        self.register_buffer("mean", mean)
        self.register_buffer("istd", istd)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x (torch.Tensor): (batch, max_len, feat_dim)

        Returns:
            (torch.Tensor): normalized feature
        """
        x = x - self.mean
        if self.norm_var:
            x = x * self.istd
        return x
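
A short construction sketch for `GlobalCMVN`, not part of the committed file; in the recipes the statistics come from a global CMVN file, so the random features below are only a stand-in:

import torch
from wenet.transformer.cmvn import GlobalCMVN

# Build the layer from per-dimension statistics over some reference features.
feats = torch.randn(1000, 80)       # (frames, feat_dim), stand-in for fbank
mean = feats.mean(dim=0)
istd = 1.0 / feats.std(dim=0)       # inverse std, as the layer expects
cmvn = GlobalCMVN(mean, istd)

x = torch.randn(4, 50, 80)          # (batch, max_len, feat_dim)
y = cmvn(x)                         # normalized: (x - mean) * istd
print(y.shape)  # torch.Size([4, 50, 80])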
examples/aishell/s0/wenet/transformer/convolution.py (new file, mode 100644):
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""ConvolutionModule definition."""
from typing import Tuple

import torch
from torch import nn
from typeguard import check_argument_types


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model."""
    def __init__(self,
                 channels: int,
                 kernel_size: int = 15,
                 activation: nn.Module = nn.ReLU(),
                 norm: str = "batch_norm",
                 causal: bool = False,
                 bias: bool = True):
        """Construct a ConvolutionModule object.

        Args:
            channels (int): The number of channels of conv layers.
            kernel_size (int): Kernel size of conv layers.
            causal (bool): Whether to use causal convolution or not.
        """
        assert check_argument_types()
        super().__init__()

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        # self.lorder is used to distinguish if it's a causal convolution,
        # if self.lorder > 0: it's a causal convolution, the input will be
        #    padded with self.lorder frames on the left in forward.
        # else: it's a symmetrical convolution
        if causal:
            padding = 0
            self.lorder = kernel_size - 1
        else:
            # kernel_size should be an odd number for non-causal convolution
            assert (kernel_size - 1) % 2 == 0
            padding = (kernel_size - 1) // 2
            self.lorder = 0
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
            bias=bias,
        )

        assert norm in ['batch_norm', 'layer_norm']
        if norm == "batch_norm":
            self.use_layer_norm = False
            self.norm = nn.BatchNorm1d(channels)
        else:
            self.use_layer_norm = True
            self.norm = nn.LayerNorm(channels)

        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        cache: torch.Tensor = torch.zeros((0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
                (0, 0, 0) means fake mask.
            cache (torch.Tensor): left context cache, it is only
                used in causal convolution (#batch, channels, cache_t),
                (0, 0, 0) means fake cache.

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)  # (#batch, channels, time)

        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        if self.lorder > 0:
            if cache.size(2) == 0:  # cache_t == 0
                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
            else:
                assert cache.size(0) == x.size(0)  # equal batch
                assert cache.size(1) == x.size(1)  # equal channel
                x = torch.cat((cache, x), dim=2)
            assert (x.size(2) > self.lorder)
            new_cache = x[:, :, -self.lorder:]
        else:
            # It's better we just return None if no cache is required,
            # However, for JIT export, here we just fake one tensor instead of
            # None.
            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.activation(self.norm(x))
        if self.use_layer_norm:
            x = x.transpose(1, 2)
        x = self.pointwise_conv2(x)
        # mask batch padding
        if mask_pad.size(2) > 0:  # time > 0
            x.masked_fill_(~mask_pad, 0.0)

        return x.transpose(1, 2), new_cache
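
A sketch of the causal left-context cache above, not part of the committed file; with `kernel_size=15` the cache carries `lorder = 14` frames between chunks:

import torch
from wenet.transformer.convolution import ConvolutionModule

# Causal conv module over two chunks: chunk 2 sees the last 14 frames of
# chunk 1 via the cache instead of zero padding.
conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)
conv.eval()  # BatchNorm1d needs eval mode for a batch of 1

chunk1 = torch.randn(1, 16, 256)  # (batch, time, channels)
chunk2 = torch.randn(1, 16, 256)
with torch.no_grad():
    out1, cache = conv(chunk1)              # cache: (1, 256, 14)
    out2, cache = conv(chunk2, cache=cache)
print(out1.shape, cache.shape)  # torch.Size([1, 16, 256]) torch.Size([1, 256, 14])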
examples/aishell/s0/wenet/transformer/ctc.py (new file, mode 100644):
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
import torch
import torch.nn.functional as F
from typeguard import check_argument_types


class CTC(torch.nn.Module):
    """CTC module"""
    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        reduce: bool = True,
    ):
        """ Construct CTC module
        Args:
            odim: dimension of outputs
            encoder_output_size: number of encoder projection units
            dropout_rate: dropout rate (0.0 ~ 1.0)
            reduce: reduce the CTC loss into a scalar
        """
        assert check_argument_types()
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        self.ctc_lo = torch.nn.Linear(eprojs, odim)

        reduction_type = "sum" if reduce else "none"
        self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)

    def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
                ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> torch.Tensor:
        """Calculate CTC loss.

        Args:
            hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            hlens: batch of lengths of hidden state sequences (B)
            ys_pad: batch of padded character id sequence tensor (B, Lmax)
            ys_lens: batch of lengths of character sequence (B)
        """
        # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
        # ys_hat: (B, L, D) -> (L, B, D)
        ys_hat = ys_hat.transpose(0, 1)
        ys_hat = ys_hat.log_softmax(2)
        loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens)
        # Batch-size average
        loss = loss / ys_hat.size(1)
        return loss

    def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """log_softmax of frame activations

        Args:
            Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """argmax of frame activations

        Args:
            torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
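
A sketch of driving the CTC module above, not part of the committed file; the vocabulary size (odim) and encoder dimension below are illustrative, not taken from this commit's configs:

import torch
from wenet.transformer.ctc import CTC

# CTC loss over a batch of two encoder outputs with random targets.
# Blank id is 0 (the torch.nn.CTCLoss default used by the module).
ctc = CTC(odim=4233, encoder_output_size=256)

hs_pad = torch.randn(2, 100, 256)         # (B, Tmax, eprojs)
hlens = torch.tensor([100, 80])
ys_pad = torch.randint(1, 4233, (2, 20))  # (B, Lmax), labels avoid blank 0
ys_lens = torch.tensor([20, 15])

loss = ctc(hs_pad, hlens, ys_pad, ys_lens)  # batch-averaged CTC loss
probs = ctc.log_softmax(hs_pad)             # (B, Tmax, odim)
print(loss.item(), probs.shape)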
examples/aishell/s0/wenet/transformer/decoder.py (new file, mode 100644):
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Decoder definition."""
from typing import Tuple, List, Optional

import torch
from typeguard import check_argument_types

from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.decoder_layer import DecoderLayer
from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.mask import (subsequent_mask, make_pad_mask)


class TransformerDecoder(torch.nn.Module):
    """Base class of Transformer decoder module.
    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the hidden units number of position-wise feedforward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before:
            True: use layer_norm before each sub-block of a layer.
            False: use layer_norm after each sub-block of a layer.
        concat_after: whether to concat attention layer's input and output
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)
    """
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                PositionalEncoding(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' is supported: {input_layer}")

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
        self.use_output_layer = use_output_layer
        self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        self.num_blocks = num_blocks
        self.decoders = torch.nn.ModuleList([
            DecoderLayer(
                attention_dim,
                MultiHeadedAttention(attention_heads, attention_dim,
                                     self_attention_dropout_rate),
                MultiHeadedAttention(attention_heads, attention_dim,
                                     src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units,
                                        dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ) for _ in range(self.num_blocks)
        ])

    def forward(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        r_ys_in_pad: torch.Tensor = torch.empty(0),
        reverse_weight: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
            ys_in_lens: input lengths of this batch (batch)
            r_ys_in_pad: not used in transformer decoder, in order to unify
                api with bidirectional decoder
            reverse_weight: not used in transformer decoder, in order to unify
                api with bidirectional decoder
        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out,
                    vocab_size) if use_output_layer is True,
                torch.tensor(0.0), in order to unify api with bidirectional
                    decoder
                olens: (batch, )
        """
        tgt = ys_in_pad
        maxlen = tgt.size(1)
        # tgt_mask: (B, 1, L)
        tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
        tgt_mask = tgt_mask.to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1),
                            device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m
        x, _ = self.embed(tgt)
        for layer in self.decoders:
            x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
                                                     memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.use_output_layer:
            x = self.output_layer(x)
        olens = tgt_mask.sum(1)
        return x, torch.tensor(0.0), olens

    def forward_one_step(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.
        This is only used for decoding.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        x, _ = self.embed(tgt)
        new_cache = []
        for i, decoder in enumerate(self.decoders):
            if cache is None:
                c = None
            else:
                c = cache[i]
            x, tgt_mask, memory, memory_mask = decoder(x,
                                                       tgt_mask,
                                                       memory,
                                                       memory_mask,
                                                       cache=c)
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.use_output_layer:
            y = torch.log_softmax(self.output_layer(y), dim=-1)
        return y, new_cache


class BiTransformerDecoder(torch.nn.Module):
    """Base class of Transformer decoder module.
    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the hidden units number of position-wise feedforward
        num_blocks: the number of decoder blocks
        r_num_blocks: the number of right to left decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type
        use_output_layer: whether to use output layer
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before:
            True: use layer_norm before each sub-block of a layer.
            False: use layer_norm after each sub-block of a layer.
        concat_after: whether to concat attention layer's input and output
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)
    """
    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        r_num_blocks: int = 0,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        assert check_argument_types()
        super().__init__()
        self.left_decoder = TransformerDecoder(
            vocab_size, encoder_output_size, attention_heads, linear_units,
            num_blocks, dropout_rate, positional_dropout_rate,
            self_attention_dropout_rate, src_attention_dropout_rate,
            input_layer, use_output_layer, normalize_before, concat_after)

        self.right_decoder = TransformerDecoder(
            vocab_size, encoder_output_size, attention_heads, linear_units,
            r_num_blocks, dropout_rate, positional_dropout_rate,
            self_attention_dropout_rate, src_attention_dropout_rate,
            input_layer, use_output_layer, normalize_before, concat_after)

    def forward(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        r_ys_in_pad: torch.Tensor,
        reverse_weight: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
            ys_in_lens: input lengths of this batch (batch)
            r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
                used for right to left decoder
            reverse_weight: used for right to left decoder
        Returns:
            (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out,
                    vocab_size) if use_output_layer is True,
                r_x: decoded token score (right to left decoder)
                    before softmax (batch, maxlen_out, vocab_size)
                    if use_output_layer is True,
                olens: (batch, )
        """
        l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
                                          ys_in_lens)
        r_x = torch.tensor(0.0)
        if reverse_weight > 0.0:
            r_x, _, olens = self.right_decoder(memory, memory_mask,
                                               r_ys_in_pad, ys_in_lens)
        return l_x, r_x, olens

    def forward_one_step(
        self,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        cache: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.
        This is only used for decoding.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, maxlen_out, token)
        """
        return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
                                                  tgt_mask, cache)
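A hedged sketch of how `forward_one_step` supports incremental greedy decoding. Here `decoder`, `memory`, `memory_mask`, and the `sos`/`eos` token ids are assumed to come from the surrounding system; this is not the project's actual search code, just an illustration of the cache contract:

import torch
from wenet.utils.mask import subsequent_mask

sos, eos, max_steps = 4232, 4232, 50   # illustrative ids
hyp = torch.tensor([[sos]], dtype=torch.long)
cache = None
for _ in range(max_steps):
    tgt_mask = subsequent_mask(hyp.size(1)).unsqueeze(0)  # (1, L, L)
    y, cache = decoder.forward_one_step(memory, memory_mask, hyp,
                                        tgt_mask, cache)
    next_id = y.argmax(dim=-1, keepdim=True)  # greedy pick, shape (1, 1)
    hyp = torch.cat([hyp, next_id], dim=1)
    if next_id.item() == eos:
        break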
examples/aishell/s0/wenet/transformer/decoder_layer.py
0 → 100644
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder self-attention layer definition."""
from typing import Optional, Tuple

import torch
from torch import nn


class DecoderLayer(nn.Module):
    """Single decoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Inter-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
        concat_after (bool): Whether to concat attention layer's input
            and output.
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)
    """
    def __init__(
        self,
        size: int,
        self_attn: nn.Module,
        src_attn: nn.Module,
        feed_forward: nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Construct a DecoderLayer object."""
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-5)
        self.norm2 = nn.LayerNorm(size, eps=1e-5)
        self.norm3 = nn.LayerNorm(size, eps=1e-5)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)
        else:
            self.concat_linear1 = nn.Identity()
            self.concat_linear2 = nn.Identity()

    def forward(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
        cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute decoded features.
        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor
                (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory
                (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask
                (#batch, maxlen_in).
            cache (torch.Tensor): cached tensors.
                (#batch, maxlen_out - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = tgt_mask[:, -1:, :]

        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]),
                dim=-1)
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(
                self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)[0]), dim=-1)
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(
                self.src_attn(x, memory, memory, memory_mask)[0])
        if not self.normalize_before:
            x = self.norm2(x)

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
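To see the cache contract concretely: passing the layer's previous output (minus the last position) back as `cache` makes it recompute only the newest position and prepend the cache verbatim. A hedged sketch under assumed submodule signatures from the sibling files, with dropout disabled so the two paths match:

import torch
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward

size = 8
layer = DecoderLayer(size,
                     MultiHeadedAttention(2, size, 0.0),
                     MultiHeadedAttention(2, size, 0.0),
                     PositionwiseFeedForward(size, 16, 0.0),
                     dropout_rate=0.0)
layer.eval()
tgt = torch.randn(1, 5, size)
tgt_mask = torch.ones(1, 5, 5, dtype=torch.bool).tril()  # causal mask
memory = torch.randn(1, 12, size)
memory_mask = torch.ones(1, 1, 12, dtype=torch.bool)
full, _, _, _ = layer(tgt, tgt_mask, memory, memory_mask)
# Only position 5 is recomputed; positions 1..4 come from the cache.
step, _, _, _ = layer(tgt, tgt_mask, memory, memory_mask, cache=full[:, :-1])
assert torch.allclose(step, full, atol=1e-6)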
examples/aishell/s0/wenet/transformer/embedding.py
0 → 100644
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Positonal Encoding Module."""
import
math
from
typing
import
Tuple
,
Union
import
torch
import
torch.nn.functional
as
F
class
PositionalEncoding
(
torch
.
nn
.
Module
):
"""Positional encoding.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
"""
def
__init__
(
self
,
d_model
:
int
,
dropout_rate
:
float
,
max_len
:
int
=
5000
,
reverse
:
bool
=
False
):
"""Construct an PositionalEncoding object."""
super
().
__init__
()
self
.
d_model
=
d_model
self
.
xscale
=
math
.
sqrt
(
self
.
d_model
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
p
=
dropout_rate
)
self
.
max_len
=
max_len
self
.
pe
=
torch
.
zeros
(
self
.
max_len
,
self
.
d_model
)
position
=
torch
.
arange
(
0
,
self
.
max_len
,
dtype
=
torch
.
float32
).
unsqueeze
(
1
)
div_term
=
torch
.
exp
(
torch
.
arange
(
0
,
self
.
d_model
,
2
,
dtype
=
torch
.
float32
)
*
-
(
math
.
log
(
10000.0
)
/
self
.
d_model
))
self
.
pe
[:,
0
::
2
]
=
torch
.
sin
(
position
*
div_term
)
self
.
pe
[:,
1
::
2
]
=
torch
.
cos
(
position
*
div_term
)
self
.
pe
=
self
.
pe
.
unsqueeze
(
0
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
offset
:
Union
[
int
,
torch
.
Tensor
]
=
0
)
\
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
offset (int, torch.tensor): position offset
Returns:
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
torch.Tensor: for compatibility to RelPositionalEncoding
"""
self
.
pe
=
self
.
pe
.
to
(
x
.
device
)
pos_emb
=
self
.
position_encoding
(
offset
,
x
.
size
(
1
),
False
)
x
=
x
*
self
.
xscale
+
pos_emb
return
self
.
dropout
(
x
),
self
.
dropout
(
pos_emb
)
def
position_encoding
(
self
,
offset
:
Union
[
int
,
torch
.
Tensor
],
size
:
int
,
apply_dropout
:
bool
=
True
)
->
torch
.
Tensor
:
""" For getting encoding in a streaming fashion
Attention!!!!!
we apply dropout only once at the whole utterance level in a none
streaming way, but will call this function several times with
increasing input size in a streaming scenario, so the dropout will
be applied several times.
Args:
offset (int or torch.tensor): start offset
size (int): required size of position encoding
Returns:
torch.Tensor: Corresponding encoding
"""
# How to subscript a Union type:
# https://github.com/pytorch/pytorch/issues/69434
if
isinstance
(
offset
,
int
):
assert
offset
+
size
<
self
.
max_len
pos_emb
=
self
.
pe
[:,
offset
:
offset
+
size
]
elif
isinstance
(
offset
,
torch
.
Tensor
)
and
offset
.
dim
()
==
0
:
# scalar
assert
offset
+
size
<
self
.
max_len
pos_emb
=
self
.
pe
[:,
offset
:
offset
+
size
]
else
:
# for batched streaming decoding on GPU
assert
torch
.
max
(
offset
)
+
size
<
self
.
max_len
index
=
offset
.
unsqueeze
(
1
)
+
\
torch
.
arange
(
0
,
size
).
to
(
offset
.
device
)
# B X T
flag
=
index
>
0
# remove negative offset
index
=
index
*
flag
pos_emb
=
F
.
embedding
(
index
,
self
.
pe
[
0
])
# B X T X d_model
if
apply_dropout
:
pos_emb
=
self
.
dropout
(
pos_emb
)
return
pos_emb
class
RelPositionalEncoding
(
PositionalEncoding
):
"""Relative positional encoding module.
See : Appendix B in https://arxiv.org/abs/1901.02860
Args:
d_model (int): Embedding dimension.
dropout_rate (float): Dropout rate.
max_len (int): Maximum input length.
"""
def
__init__
(
self
,
d_model
:
int
,
dropout_rate
:
float
,
max_len
:
int
=
5000
):
"""Initialize class."""
super
().
__init__
(
d_model
,
dropout_rate
,
max_len
,
reverse
=
True
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
offset
:
Union
[
int
,
torch
.
Tensor
]
=
0
)
\
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Compute positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Positional embedding tensor (1, time, `*`).
"""
self
.
pe
=
self
.
pe
.
to
(
x
.
device
)
x
=
x
*
self
.
xscale
pos_emb
=
self
.
position_encoding
(
offset
,
x
.
size
(
1
),
False
)
return
self
.
dropout
(
x
),
self
.
dropout
(
pos_emb
)
class
NoPositionalEncoding
(
torch
.
nn
.
Module
):
""" No position encoding
"""
def
__init__
(
self
,
d_model
:
int
,
dropout_rate
:
float
):
super
().
__init__
()
self
.
d_model
=
d_model
self
.
dropout
=
torch
.
nn
.
Dropout
(
p
=
dropout_rate
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
offset
:
Union
[
int
,
torch
.
Tensor
]
=
0
)
\
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
""" Just return zero vector for interface compatibility
"""
pos_emb
=
torch
.
zeros
(
1
,
x
.
size
(
1
),
self
.
d_model
).
to
(
x
.
device
)
return
self
.
dropout
(
x
),
pos_emb
def
position_encoding
(
self
,
offset
:
Union
[
int
,
torch
.
Tensor
],
size
:
int
)
->
torch
.
Tensor
:
return
torch
.
zeros
(
1
,
size
,
self
.
d_model
)
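As a quick sanity check of the streaming `position_encoding` contract above, a small sketch (dropout disabled so outputs are deterministic):

import torch

pe = PositionalEncoding(d_model=256, dropout_rate=0.0)
x = torch.zeros(2, 10, 256)
y, pos_emb = pe(x)   # y == x * sqrt(256) + pe.pe[:, 0:10]
chunk = pe.position_encoding(offset=5, size=10, apply_dropout=False)
assert torch.allclose(chunk, pe.pe[:, 5:15])  # a slice of the same table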
examples/aishell/s0/wenet/transformer/encoder.py
0 → 100644
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple

import torch
from typeguard import check_argument_types

from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.attention import RelPositionMultiHeadedAttention
from wenet.transformer.convolution import ConvolutionModule
from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.embedding import NoPositionalEncoding
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.transformer.subsampling import Conv2dSubsampling4
from wenet.transformer.subsampling import Conv2dSubsampling6
from wenet.transformer.subsampling import Conv2dSubsampling8
from wenet.transformer.subsampling import LinearNoSubsampling
from wenet.utils.common import get_activation
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask


class BaseEncoder(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        concat_after: bool = False,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of encoder blocks
            dropout_rate (float): dropout rate
            attention_dropout_rate (float): dropout rate in attention
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            input_layer (str): input layer type.
                optional [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                optional [abs_pos, scaled_abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            concat_after (bool): whether to concat attention layer's input
                and output.
                True: x -> x + linear(concat(x, att(x)))
                False: x -> x + att(x)
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether to use dynamic chunk size for
                training or not. You can only use fixed chunk
                (chunk_size > 0) or dynamic chunk size
                (use_dynamic_chunk = True)
            global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
            use_dynamic_left_chunk (bool): whether to use dynamic left chunk
                in dynamic chunk training
        """
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "no_pos":
            pos_enc_class = NoPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            subsampling_class = LinearNoSubsampling
        elif input_layer == "conv2d":
            subsampling_class = Conv2dSubsampling4
        elif input_layer == "conv2d6":
            subsampling_class = Conv2dSubsampling6
        elif input_layer == "conv2d8":
            subsampling_class = Conv2dSubsampling8
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.global_cmvn = global_cmvn
        self.embed = subsampling_class(
            input_size,
            output_size,
            dropout_rate,
            pos_enc_class(output_size, positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.
        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for
                decoding, the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        """
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        for layer in self.encoders:
            xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """ Forward just one chunk
        Args:
            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (elayers, b=1, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
        Returns:
            torch.Tensor: output of current input xs,
                with shape (b=1, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                dynamic shape (elayers, head, ?, d_k * 2)
                depending on required_cache_size.
            torch.Tensor: new conformer cnn cache required for next chunk,
                with same shape as the original cnn_cache.
        """
        assert xs.size(0) == 1
        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1,
                               xs.size(1),
                               device=xs.device,
                               dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
        # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
        elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
        chunk_size = xs.size(1)
        attention_key_size = cache_t1 + chunk_size
        pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
                                               size=attention_key_size)
        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = attention_key_size
        else:
            next_cache_start = max(attention_key_size - required_cache_size,
                                   0)
        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoders):
            # NOTE(xcsong): Before layer.forward
            #   shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
            #   shape(cnn_cache[i])       is (b=1, hidden-dim, cache_t2)
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs,
                att_mask,
                pos_emb,
                att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
            # NOTE(xcsong): After layer.forward
            #   shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
            #   shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
            r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
            r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
        if self.normalize_before:
            xs = self.after_norm(xs)

        # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
        #   ? may be larger than cache_t1, it depends on required_cache_size
        r_att_cache = torch.cat(r_att_cache, dim=0)
        # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
        r_cnn_cache = torch.cat(r_cnn_cache, dim=0)

        return (xs, r_att_cache, r_cnn_cache)

    def forward_chunk_by_chunk(
        self,
        xs: torch.Tensor,
        decoding_chunk_size: int,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Forward input chunk by chunk with chunk_size like a streaming
            fashion

        Here we should pay special attention to the computation cache in the
        streaming style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling

        However, we don't implement subsampling cache because:
            1. We can make the subsampling module output the right result by
               overlapping the input instead of caching left context. This
               wastes some computation, but subsampling only takes a very
               small fraction of the computation in the whole model.
            2. Typically, there are several convolution layers with
               subsampling in the subsampling module, so it is tricky and
               complicated to cache across convolution layers with different
               subsampling rates.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling; we would need to rewrite it to make it
               work with cache, which is not preferred.
        Args:
            xs (torch.Tensor): (1, max_len, dim)
            decoding_chunk_size (int): decoding chunk size
        """
        assert decoding_chunk_size > 0
        # The model is trained by static or dynamic chunk
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.size(1)
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
        outputs = []
        offset = 0
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        # Feed forward overlapped input step by step
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, att_cache, cnn_cache) = self.forward_chunk(
                chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
            outputs.append(y)
            offset += y.size(1)
        ys = torch.cat(outputs, 1)
        masks = torch.ones((1, 1, ys.size(1)),
                           device=ys.device,
                           dtype=torch.bool)
        return ys, masks


class TransformerEncoder(BaseEncoder):
    """Transformer encoder module."""
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        concat_after: bool = False,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
    ):
        """ Construct TransformerEncoder

        See Encoder for the meaning of each parameter.
        """
        assert check_argument_types()
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         concat_after, static_chunk_size, use_dynamic_chunk,
                         global_cmvn, use_dynamic_left_chunk)
        self.encoders = torch.nn.ModuleList([
            TransformerEncoderLayer(
                output_size,
                MultiHeadedAttention(attention_heads, output_size,
                                     attention_dropout_rate),
                PositionwiseFeedForward(output_size, linear_units,
                                        dropout_rate), dropout_rate,
                normalize_before, concat_after) for _ in range(num_blocks)
        ])


class ConformerEncoder(BaseEncoder):
    """Conformer encoder module."""
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        concat_after: bool = False,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
    ):
        """Construct ConformerEncoder
        Args:
            input_size to use_dynamic_chunk, see in BaseEncoder
            positionwise_conv_kernel_size (int): Kernel size of positionwise
                conv1d layer.
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            selfattention_layer_type (str): Encoder attention layer type,
                the parameter has no effect now, it's just for configuration
                compatibility.
            activation_type (str): Encoder activation function type.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
            causal (bool): whether to use causal convolution or not.
        """
        assert check_argument_types()
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         concat_after, static_chunk_size, use_dynamic_chunk,
                         global_cmvn, use_dynamic_left_chunk)
        activation = get_activation(activation_type)

        # self-attention module definition
        if pos_enc_layer_type != "rel_pos":
            encoder_selfattn_layer = MultiHeadedAttention
        else:
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
        )
        # feed-forward module definition
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)

        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(
                    *positionwise_layer_args) if macaron_style else None,
                convolution_layer(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
            ) for _ in range(num_blocks)
        ])
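To make the offline-versus-streaming interface concrete, a hedged sketch of driving the encoder both ways. The sizes are illustrative; a real run restores a trained checkpoint and reads these values from its YAML config, and with a non-causal convolution module the chunked output is not expected to match the full-context output exactly:

import torch

encoder = ConformerEncoder(input_size=80, output_size=256, num_blocks=2,
                           static_chunk_size=16)
encoder.eval()
feats = torch.randn(1, 203, 80)  # (1, frames, mel-dim)
with torch.no_grad():
    # offline: the whole utterance at once
    full, full_mask = encoder(feats, torch.tensor([203]))
    # simulated streaming: 16-frame output chunks, unlimited left context
    stream, stream_mask = encoder.forward_chunk_by_chunk(
        feats, decoding_chunk_size=16, num_decoding_left_chunks=-1)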
examples/aishell/s0/wenet/transformer/encoder_layer.py
0 → 100644
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder self-attention layer definition."""
from typing import Optional, Tuple

import torch
from torch import nn


class TransformerEncoderLayer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
        concat_after (bool): Whether to concat attention layer's input and
            output.
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)
    """
    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: torch.nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-5)
        self.norm2 = nn.LayerNorm(size, eps=1e-5)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        else:
            self.concat_linear = nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.
        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input
                (#batch, time, time), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): just for interface compatibility
                to ConformerEncoderLayer
            mask_pad (torch.Tensor): not used in the transformer layer,
                just for unified api with conformer.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2), not used here, it's for interface
                compatibility to ConformerEncoderLayer.
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch=1, size, cache_t2).
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, cache=att_cache)
        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        return x, mask, new_att_cache, fake_cnn_cache


class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
        concat_after (bool): Whether to concat attention layer's input and
            output.
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)
    """
    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-5)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-5)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-5)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-5)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-5)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        else:
            self.concat_linear = nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.
        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input
                (#batch, time, time), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for conv module.
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
        """
        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                              att_cache)
        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        return x, mask, new_att_cache, new_cnn_cache
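A minimal shape-check sketch for the plain transformer layer above, assuming the attention and feed-forward classes from the sibling modules:

import torch
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward

size = 256
layer = TransformerEncoderLayer(
    size,
    MultiHeadedAttention(4, size, 0.0),
    PositionwiseFeedForward(size, 2048, 0.1),
    dropout_rate=0.1)
x = torch.randn(2, 50, size)
mask = torch.ones(2, 1, 50, dtype=torch.bool)   # attend everywhere
pos_emb = torch.zeros(1, 50, size)              # ignored by this layer type
out, out_mask, att_cache, fake_cnn_cache = layer(x, mask, pos_emb)
assert out.shape == (2, 50, size)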
examples/aishell/s0/wenet/transformer/label_smoothing_loss.py
0 → 100644
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Label smoothing module."""
import torch
from torch import nn


class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.
    In a standard CE loss, the label's data distribution is:
        [0,1,2] ->
        [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]
    In the smoothed version of the CE loss, some probability mass is taken
    from the true label prob (1.0) and divided among the other labels,
    e.g. with smoothing=0.1:
        [0,1,2] ->
        [
            [0.9, 0.05, 0.05],
            [0.05, 0.9, 0.05],
            [0.05, 0.05, 0.9],
        ]
    Args:
        size (int): the number of classes
        padding_idx (int): padding class id which will be ignored for loss
        smoothing (float): smoothing rate (0.0 means the conventional CE)
        normalize_length (bool):
            normalize loss by sequence length if True
            normalize loss by batch size if False
    """
    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool = False):
        """Construct a LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="none")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute loss between x and target.
        The model output and data label tensors are flattened to
        (batch*seqlen, class) shape and a mask is applied to the
        padding part which should not be calculated for loss.
        Args:
            x (torch.Tensor): prediction (batch, seqlen, class)
            target (torch.Tensor):
                target signal masked with self.padding_id (batch, seqlen)
        Returns:
            loss (torch.Tensor) : The KL loss, scalar float value
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        # use zeros_like instead of torch.no_grad() for true_dist,
        # since no_grad() can not be exported by JIT
        true_dist = torch.zeros_like(x)
        true_dist.fill_(self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (B,)
        total = len(target) - ignore.sum().item()
        target = target.masked_fill(ignore, 0)  # avoid -1 index
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
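A small worked example of the distribution the loss builds internally: with size=5 and smoothing=0.1, each target row becomes 0.9 at the label index and 0.1/(5-1)=0.025 everywhere else, while padded positions are masked out of the sum. A minimal sketch with made-up tensors:

import torch

criterion = LabelSmoothingLoss(size=5, padding_idx=-1, smoothing=0.1)
logits = torch.randn(2, 4, 5)           # (batch, seqlen, class)
target = torch.tensor([[1, 2, 3, -1],   # -1 marks padding positions
                       [0, 4, -1, -1]])
loss = criterion(logits, target)        # scalar, normalized by batch size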
examples/aishell/s0/wenet/transformer/positionwise_feed_forward.py
0 → 100644
# Copyright (c) 2019 Shigeki Karita
# 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""
import torch


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.
    The feed-forward transform is applied at each position of the sequence.
    The output dim is the same as the input dim.
    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function
    """
    def __init__(self,
                 idim: int,
                 hidden_units: int,
                 dropout_rate: float,
                 activation: torch.nn.Module = torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.
        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
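A one-line shape check of the layer, as a sketch with illustrative sizes:

import torch

ffn = PositionwiseFeedForward(idim=256, hidden_units=2048, dropout_rate=0.1)
assert ffn(torch.randn(4, 50, 256)).shape == (4, 50, 256)  # (B, L, D) in and out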
examples/aishell/s0/wenet/transformer/subsampling.py
0 → 100644
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""
from typing import Tuple, Union

import torch


class BaseSubsampling(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.right_context = 0
        self.subsampling_rate = 1

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        return self.pos_enc.position_encoding(offset, size)


class LinearNoSubsampling(BaseSubsampling):
    """Linear transform of the input without subsampling
    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
    """
    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a LinearNoSubsampling object."""
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
        )
        self.pos_enc = pos_enc_class
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Input x.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time .
            torch.Tensor: linear input mask (#batch, 1, time'),
                where time' = time .
        """
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask


class Conv2dSubsampling4(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/4 length).
    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
    """
    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling4 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
        self.pos_enc = pos_enc_class
        # The right context for every conv layer is computed by:
        # (kernel_size - 1) * frame_rate_of_this_layer
        self.subsampling_rate = 4
        # 6 = (3 - 1) * 1 + (3 - 1) * 2
        self.right_context = 6

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 4.
            torch.Tensor: positional encoding
        """
        x = x.unsqueeze(1)  # (b, c=1, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]


class Conv2dSubsampling6(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/6 length).
    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """
    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling6 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
                                      odim)
        self.pos_enc = pos_enc_class
        # 10 = (3 - 1) * 1 + (5 - 1) * 2
        self.subsampling_rate = 6
        self.right_context = 10

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 6.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 6.
            torch.Tensor: positional encoding
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]


class Conv2dSubsampling8(BaseSubsampling):
    """Convolutional 2D subsampling (to 1/8 length).
    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
    """
    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a Conv2dSubsampling8 object."""
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        self.linear = torch.nn.Linear(
            odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
        self.pos_enc = pos_enc_class
        self.subsampling_rate = 8
        # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
        self.right_context = 14

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).
        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 8.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 8.
            torch.Tensor: positional encoding
        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
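To make the 1/4-rate geometry concrete, a hedged sketch with illustrative sizes: two kernel-3, stride-2 convolutions turn 100 input frames into (100-3)//2+1 = 49, then (49-3)//2+1 = 24 frames, and the mask is thinned by the same `[:, :, 2::2]` slicing:

import torch
from wenet.transformer.embedding import PositionalEncoding

sub = Conv2dSubsampling4(idim=80, odim=256, dropout_rate=0.1,
                         pos_enc_class=PositionalEncoding(256, 0.1))
x = torch.randn(2, 100, 80)
x_mask = torch.ones(2, 1, 100, dtype=torch.bool)
y, pos_emb, y_mask = sub(x, x_mask)
assert y.shape == (2, 24, 256) and y_mask.shape == (2, 1, 24)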
examples/aishell/s0/wenet/transformer/swish.py
0 → 100644
# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
# 2020 Northwestern Polytechnical University (Pengcheng Guo)
# 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Swish() activation function for Conformer."""
import torch


class Swish(torch.nn.Module):
    """Construct a Swish activation object."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Swish activation function."""
        return x * torch.sigmoid(x)
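A quick numeric check of the activation:

import torch

swish = Swish()
print(swish(torch.tensor([-1.0, 0.0, 1.0])))
# tensor([-0.2689, 0.0000, 0.7311]) == x * sigmoid(x)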
examples/aishell/s0/wenet/utils/__pycache__/checkpoint.cpython-38.pyc
0 → 100644
File added