Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
sunzhq2
yidong-infer
Commits
60a2c57a
Commit
60a2c57a
authored
Jan 27, 2026
by
sunzhq2
Committed by
xuxo
Jan 27, 2026
Browse files
update conformer
parent
4a699441
Changes
216
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4642 additions
and
0 deletions
+4642
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/streaming/__init__.py
...ild/lib/espnet/nets/pytorch_backend/streaming/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/streaming/segment.py
...uild/lib/espnet/nets/pytorch_backend/streaming/segment.py
+129
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/streaming/window.py
...build/lib/espnet/nets/pytorch_backend/streaming/window.py
+81
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/tacotron2/__init__.py
...ild/lib/espnet/nets/pytorch_backend/tacotron2/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/tacotron2/cbhg.py
...1/build/lib/espnet/nets/pytorch_backend/tacotron2/cbhg.py
+274
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/tacotron2/decoder.py
...uild/lib/espnet/nets/pytorch_backend/tacotron2/decoder.py
+675
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/tacotron2/encoder.py
...uild/lib/espnet/nets/pytorch_backend/tacotron2/encoder.py
+172
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/__init__.py
...ld/lib/espnet/nets/pytorch_backend/transducer/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/arguments.py
...d/lib/espnet/nets/pytorch_backend/transducer/arguments.py
+386
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/blocks.py
...uild/lib/espnet/nets/pytorch_backend/transducer/blocks.py
+536
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/conv1d_nets.py
...lib/espnet/nets/pytorch_backend/transducer/conv1d_nets.py
+252
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/custom_decoder.py
.../espnet/nets/pytorch_backend/transducer/custom_decoder.py
+293
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/custom_encoder.py
.../espnet/nets/pytorch_backend/transducer/custom_encoder.py
+129
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/error_calculator.py
...spnet/nets/pytorch_backend/transducer/error_calculator.py
+168
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/initializer.py
...lib/espnet/nets/pytorch_backend/transducer/initializer.py
+42
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/joint_network.py
...b/espnet/nets/pytorch_backend/transducer/joint_network.py
+73
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/rnn_decoder.py
...lib/espnet/nets/pytorch_backend/transducer/rnn_decoder.py
+295
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/rnn_encoder.py
...lib/espnet/nets/pytorch_backend/transducer/rnn_encoder.py
+572
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/transducer_tasks.py
...spnet/nets/pytorch_backend/transducer/transducer_tasks.py
+466
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/transformer_decoder_layer.py
...s/pytorch_backend/transducer/transformer_decoder_layer.py
+96
-0
No files found.
Too many changes to show.
To preserve performance only
216 of 216+
files are displayed.
Plain diff
Email patch
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/streaming/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/streaming/segment.py
0 → 100644
View file @
60a2c57a
import
numpy
as
np
import
torch
class SegmentStreamingE2E(object):
    """SegmentStreamingE2E constructor.

    Segment-based streaming recognizer: buffers incoming frames, uses the
    CTC blank symbol to detect segment onsets/offsets, and runs the decoder
    on each detected segment.

    :param E2E e2e: E2E ASR object
    :param recog_args: arguments for "recognize" method of E2E
    """

    def __init__(self, e2e, recog_args, rnnlm=None):
        self._e2e = e2e
        self._recog_args = recog_args
        self._char_list = e2e.char_list
        self._rnnlm = rnnlm

        # inference only; disable dropout/batch-norm updates
        self._e2e.eval()

        # locate the CTC blank symbol index in the character list (-1 if absent)
        self._blank_idx_in_char_list = -1
        for idx in range(len(self._char_list)):
            if self._char_list[idx] == self._e2e.blank:
                self._blank_idx_in_char_list = idx
                break

        # total frame subsampling applied by the front-end
        self._subsampling_factor = np.prod(e2e.subsample)

        # segment-detection state: 1 while inside a speech segment
        self._activates = 0
        # number of consecutive blank-dominated encoder steps
        self._blank_dur = 0

        self._previous_input = []
        self._previous_encoder_recurrent_state = None
        self._encoder_states = []
        self._ctc_posteriors = []

        assert (
            self._recog_args.batchsize <= 1
        ), "SegmentStreamingE2E works only with batch size <= 1"
        assert (
            "b" not in self._e2e.etype
        ), "SegmentStreamingE2E works only with uni-directional encoders"

    def accept_input(self, x):
        """Call this method each time a new batch of input is available.

        Returns the decoded hypothesis when a complete segment was detected
        and decoded within this call, otherwise ``None``.
        """
        self._previous_input.extend(x)
        h, ilen = self._e2e.subsample_frames(x)

        # Run encoder and apply greedy search on CTC softmax output
        h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
            h.unsqueeze(0), ilen, self._previous_encoder_recurrent_state
        )
        z = self._e2e.ctc.argmax(h).squeeze(0)

        # onset: first non-blank CTC label while idle
        if self._activates == 0 and z[0] != self._blank_idx_in_char_list:
            self._activates = 1

            # Rerun encoder with zero state at onset of detection
            tail_len = self._subsampling_factor * (
                self._recog_args.streaming_onset_margin + 1
            )
            h, ilen = self._e2e.subsample_frames(
                np.reshape(
                    self._previous_input[-tail_len:],
                    [-1, len(self._previous_input[0])],
                )
            )
            h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
                h.unsqueeze(0), ilen, None
            )

        hyp = None
        if self._activates == 1:
            # accumulate encoder frames and CTC log-posteriors for the segment
            self._encoder_states.extend(h.squeeze(0))
            self._ctc_posteriors.extend(self._e2e.ctc.log_softmax(h).squeeze(0))

            if z[0] == self._blank_idx_in_char_list:
                self._blank_dur += 1
            else:
                self._blank_dur = 0

            # offset: enough consecutive blanks mark the segment end
            if self._blank_dur >= self._recog_args.streaming_min_blank_dur:
                seg_len = (
                    len(self._encoder_states)
                    - self._blank_dur
                    + self._recog_args.streaming_offset_margin
                )
                if seg_len > 0:
                    # Run decoder with a detected segment
                    h = torch.cat(self._encoder_states[:seg_len], dim=0).view(
                        -1, self._encoder_states[0].size(0)
                    )
                    if self._recog_args.ctc_weight > 0.0:
                        lpz = torch.cat(self._ctc_posteriors[:seg_len], dim=0).view(
                            -1, self._ctc_posteriors[0].size(0)
                        )
                        if self._recog_args.batchsize > 0:
                            lpz = lpz.unsqueeze(0)
                        normalize_score = False
                    else:
                        lpz = None
                        normalize_score = True

                    if self._recog_args.batchsize == 0:
                        hyp = self._e2e.dec.recognize_beam(
                            h, lpz, self._recog_args, self._char_list, self._rnnlm
                        )
                    else:
                        hlens = torch.tensor([h.shape[0]])
                        hyp = self._e2e.dec.recognize_beam_batch(
                            h.unsqueeze(0),
                            hlens,
                            lpz,
                            self._recog_args,
                            self._char_list,
                            self._rnnlm,
                            normalize_score=normalize_score,
                        )[0]

                    # reset segment state, keeping a margin of raw input frames
                    self._activates = 0
                    self._blank_dur = 0

                    tail_len = (
                        self._subsampling_factor
                        * self._recog_args.streaming_onset_margin
                    )
                    self._previous_input = self._previous_input[-tail_len:]
                    self._encoder_states = []
                    self._ctc_posteriors = []

        return hyp
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/streaming/window.py
0 → 100644
View file @
60a2c57a
import
torch
# TODO(pzelasko): Currently allows half-streaming only;
# needs streaming attention decoder implementation
class WindowStreamingE2E(object):
    """WindowStreamingE2E constructor.

    Runs the encoder and CTC in a streaming (window-by-window) fashion; the
    attention decoder is only run offline over the accumulated states.

    :param E2E e2e: E2E ASR object
    :param recog_args: arguments for "recognize" method of E2E
    """

    def __init__(self, e2e, recog_args, rnnlm=None):
        self._e2e = e2e
        self._recog_args = recog_args
        self._char_list = e2e.char_list
        self._rnnlm = rnnlm

        # inference only; disable dropout/batch-norm updates
        self._e2e.eval()

        self._offset = 0
        self._previous_encoder_recurrent_state = None
        self._encoder_states = []
        self._ctc_posteriors = []
        self._last_recognition = None

        assert (
            self._recog_args.ctc_weight > 0.0
        ), "WindowStreamingE2E works only with combined CTC and attention decoders."

    def accept_input(self, x):
        """Call this method each time a new batch of input is available."""
        h, ilen = self._e2e.subsample_frames(x)

        # Streaming encoder
        h, _, self._previous_encoder_recurrent_state = self._e2e.enc(
            h.unsqueeze(0), ilen, self._previous_encoder_recurrent_state
        )
        self._encoder_states.append(h.squeeze(0))

        # CTC posteriors for the incoming audio
        self._ctc_posteriors.append(self._e2e.ctc.log_softmax(h).squeeze(0))

    def _input_window_for_decoder(self, use_all=False):
        # With use_all=True: concatenate every accumulated window.
        if use_all:
            return (
                torch.cat(self._encoder_states, dim=0),
                torch.cat(self._ctc_posteriors, dim=0),
            )

        def select_unprocessed_windows(window_tensors):
            # Skip windows already consumed (tracked via self._offset), then
            # return the remaining ones concatenated.
            # NOTE(review): each stored tensor is (T, D) after squeeze(0), so
            # es.size(1) counts the feature dim, not frames — looks like
            # size(0) was intended; confirm before relying on this path.
            last_offset = self._offset
            offset_traversed = 0
            selected_windows = []
            for es in window_tensors:
                if offset_traversed > last_offset:
                    selected_windows.append(es)
                    continue
                offset_traversed += es.size(1)
            return torch.cat(selected_windows, dim=0)

        return (
            select_unprocessed_windows(self._encoder_states),
            select_unprocessed_windows(self._ctc_posteriors),
        )

    def decode_with_attention_offline(self):
        """Run the attention decoder offline.

        Works even if the previous layers (encoder and CTC decoder) were
        being run in the online mode.
        This method should be run after all the audio has been consumed.
        This is used mostly to compare the results between offline
        and online implementation of the previous layers.
        """
        h, lpz = self._input_window_for_decoder(use_all=True)

        return self._e2e.dec.recognize_beam(
            h, lpz, self._recog_args, self._char_list, self._rnnlm
        )
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/tacotron2/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/tacotron2/cbhg.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""CBHG related modules."""
import
torch
import
torch.nn.functional
as
F
from
torch.nn.utils.rnn
import
pack_padded_sequence
,
pad_packed_sequence
from
espnet.nets.pytorch_backend.nets_utils
import
make_non_pad_mask
class CBHGLoss(torch.nn.Module):
    """Loss function module for CBHG."""

    def __init__(self, use_masking=True):
        """Initialize CBHG loss module.

        Args:
            use_masking (bool): Whether to mask padded part in loss calculation.
        """
        super().__init__()
        self.use_masking = use_masking

    def forward(self, cbhg_outs, spcs, olens):
        """Compute L1 and MSE losses between predictions and targets.

        Args:
            cbhg_outs (Tensor): Batch of CBHG outputs (B, Lmax, spc_dim).
            spcs (Tensor): Batch of groundtruth of spectrogram (B, Lmax, spc_dim).
            olens (LongTensor): Batch of the lengths of each sequence (B,).

        Returns:
            Tensor: L1 loss value
            Tensor: Mean square error loss value.
        """
        if self.use_masking:
            # Keep only the non-padded positions in both tensors.
            pad_mask = make_non_pad_mask(olens).unsqueeze(-1).to(spcs.device)
            spcs = spcs.masked_select(pad_mask)
            cbhg_outs = cbhg_outs.masked_select(pad_mask)

        l1 = F.l1_loss(cbhg_outs, spcs)
        mse = F.mse_loss(cbhg_outs, spcs)
        return l1, mse
class CBHG(torch.nn.Module):
    """CBHG module to convert log Mel-filterbanks to linear spectrogram.

    This is a module of CBHG introduced
    in `Tacotron: Towards End-to-End Speech Synthesis`_.
    The CBHG converts the sequence of log Mel-filterbanks into linear spectrogram.

    .. _`Tacotron: Towards End-to-End Speech Synthesis`:
        https://arxiv.org/abs/1703.10135
    """

    def __init__(
        self,
        idim,
        odim,
        conv_bank_layers=8,
        conv_bank_chans=128,
        conv_proj_filts=3,
        conv_proj_chans=256,
        highway_layers=4,
        highway_units=128,
        gru_units=256,
    ):
        """Initialize CBHG module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            conv_bank_layers (int, optional): The number of convolution bank layers.
            conv_bank_chans (int, optional): The number of channels in convolution bank.
            conv_proj_filts (int, optional):
                Kernel size of convolutional projection layer.
            conv_proj_chans (int, optional):
                The number of channels in convolutional projection layer.
            highway_layers (int, optional): The number of highway network layers.
            highway_units (int, optional): The number of highway network units.
            gru_units (int, optional): The number of GRU units (for both directions).
        """
        super(CBHG, self).__init__()
        self.idim = idim
        self.odim = odim
        self.conv_bank_layers = conv_bank_layers
        self.conv_bank_chans = conv_bank_chans
        self.conv_proj_filts = conv_proj_filts
        self.conv_proj_chans = conv_proj_chans
        self.highway_layers = highway_layers
        self.highway_units = highway_units
        self.gru_units = gru_units

        # define 1d convolution bank
        self.conv_bank = torch.nn.ModuleList()
        for k in range(1, self.conv_bank_layers + 1):
            # "same"-length padding: symmetric for odd k, asymmetric for even k
            if k % 2 != 0:
                padding = (k - 1) // 2
            else:
                padding = ((k - 1) // 2, (k - 1) // 2 + 1)
            self.conv_bank += [
                torch.nn.Sequential(
                    torch.nn.ConstantPad1d(padding, 0.0),
                    torch.nn.Conv1d(
                        idim, self.conv_bank_chans, k, stride=1, padding=0, bias=True
                    ),
                    torch.nn.BatchNorm1d(self.conv_bank_chans),
                    torch.nn.ReLU(),
                )
            ]

        # define max pooling (need padding for one-side to keep same length)
        self.max_pool = torch.nn.Sequential(
            torch.nn.ConstantPad1d((0, 1), 0.0), torch.nn.MaxPool1d(2, stride=1)
        )

        # define 1d convolution projection
        self.projections = torch.nn.Sequential(
            torch.nn.Conv1d(
                self.conv_bank_chans * self.conv_bank_layers,
                self.conv_proj_chans,
                self.conv_proj_filts,
                stride=1,
                padding=(self.conv_proj_filts - 1) // 2,
                bias=True,
            ),
            torch.nn.BatchNorm1d(self.conv_proj_chans),
            torch.nn.ReLU(),
            torch.nn.Conv1d(
                self.conv_proj_chans,
                self.idim,
                self.conv_proj_filts,
                stride=1,
                padding=(self.conv_proj_filts - 1) // 2,
                bias=True,
            ),
            torch.nn.BatchNorm1d(self.idim),
        )

        # define highway network (first Linear adjusts idim -> highway_units)
        self.highways = torch.nn.ModuleList()
        self.highways += [torch.nn.Linear(idim, self.highway_units)]
        for _ in range(self.highway_layers):
            self.highways += [HighwayNet(self.highway_units)]

        # define bidirectional GRU
        self.gru = torch.nn.GRU(
            self.highway_units,
            gru_units // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        # define final projection
        self.output = torch.nn.Linear(gru_units, odim, bias=True)

    def forward(self, xs, ilens):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the padded sequences of inputs (B, Tmax, idim).
            ilens (LongTensor): Batch of lengths of each input sequence (B,).

        Return:
            Tensor: Batch of the padded sequence of outputs (B, Tmax, odim).
            LongTensor: Batch of lengths of each output sequence (B,).
        """
        xs = xs.transpose(1, 2)  # (B, idim, Tmax)
        convs = []
        for k in range(self.conv_bank_layers):
            convs += [self.conv_bank[k](xs)]
        convs = torch.cat(convs, dim=1)  # (B, #CH * #BANK, Tmax)
        convs = self.max_pool(convs)
        convs = self.projections(convs).transpose(1, 2)  # (B, Tmax, idim)
        # residual connection around the conv stack
        xs = xs.transpose(1, 2) + convs
        # + 1 for dimension adjustment layer
        for i in range(self.highway_layers + 1):
            xs = self.highways[i](xs)

        # sort by length
        xs, ilens, sort_idx = self._sort_by_length(xs, ilens)

        # total_length needs for DataParallel
        # (see https://github.com/pytorch/pytorch/pull/6327)
        total_length = xs.size(1)
        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)
        xs = pack_padded_sequence(xs, ilens.cpu(), batch_first=True)
        self.gru.flatten_parameters()
        xs, _ = self.gru(xs)
        xs, ilens = pad_packed_sequence(xs, batch_first=True, total_length=total_length)

        # revert sorting by length
        xs, ilens = self._revert_sort_by_length(xs, ilens, sort_idx)

        xs = self.output(xs)  # (B, Tmax, odim)

        return xs, ilens

    def inference(self, x):
        """Inference.

        Args:
            x (Tensor): The sequences of inputs (T, idim).

        Return:
            Tensor: The sequence of outputs (T, odim).
        """
        assert len(x.size()) == 2
        xs = x.unsqueeze(0)
        ilens = x.new([x.size(0)]).long()

        return self.forward(xs, ilens)[0][0]

    def _sort_by_length(self, xs, ilens):
        # Sort the batch by decreasing length (required by pack_padded_sequence).
        sort_ilens, sort_idx = ilens.sort(0, descending=True)
        return xs[sort_idx], ilens[sort_idx], sort_idx

    def _revert_sort_by_length(self, xs, ilens, sort_idx):
        # Undo _sort_by_length so outputs line up with the original batch order.
        _, revert_idx = sort_idx.sort(0)
        return xs[revert_idx], ilens[revert_idx]
class HighwayNet(torch.nn.Module):
    """Highway Network module.

    This is a module of Highway Network introduced in `Highway Networks`_.

    .. _`Highway Networks`: https://arxiv.org/abs/1505.00387
    """

    def __init__(self, idim):
        """Initialize Highway Network module.

        Args:
            idim (int): Dimension of the inputs.
        """
        super().__init__()
        self.idim = idim
        # Nonlinear transform branch H(x) = relu(W_h x + b_h).
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(idim, idim),
            torch.nn.ReLU(),
        )
        # Transform gate branch T(x) = sigmoid(W_t x + b_t).
        self.gate = torch.nn.Sequential(
            torch.nn.Linear(idim, idim),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """Calculate forward propagation.

        Computes y = H(x) * T(x) + x * (1 - T(x)).

        Args:
            x (Tensor): Batch of inputs (B, ..., idim).

        Returns:
            Tensor: Batch of outputs, which are the same shape as inputs (B, ..., idim).
        """
        transformed = self.projection(x)
        transform_gate = self.gate(x)
        carried = x * (1.0 - transform_gate)
        return transformed * transform_gate + carried
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/tacotron2/decoder.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Tacotron2 decoder related modules."""
import
torch
import
torch.nn.functional
as
F
from
espnet.nets.pytorch_backend.rnn.attentions
import
AttForwardTA
def decoder_init(m):
    """Initialize decoder parameters.

    Applies Xavier-uniform initialization (with tanh gain) to the weights of
    every ``torch.nn.Conv1d`` submodule; all other modules are left untouched.
    Intended for use with ``torch.nn.Module.apply``.
    """
    if not isinstance(m, torch.nn.Conv1d):
        return
    tanh_gain = torch.nn.init.calculate_gain("tanh")
    torch.nn.init.xavier_uniform_(m.weight, tanh_gain)
class ZoneOutCell(torch.nn.Module):
    """ZoneOut Cell module.

    This is a module of zoneout described in
    `Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`_.
    This code is modified from `eladhoffer/seq2seq.pytorch`_.

    Examples:
        >>> lstm = torch.nn.LSTMCell(16, 32)
        >>> lstm = ZoneOutCell(lstm, 0.5)

    .. _`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`:
        https://arxiv.org/abs/1606.01305

    .. _`eladhoffer/seq2seq.pytorch`:
        https://github.com/eladhoffer/seq2seq.pytorch
    """

    def __init__(self, cell, zoneout_rate=0.1):
        """Initialize zone out cell module.

        Args:
            cell (torch.nn.Module): Pytorch recurrent cell module
                e.g. `torch.nn.Module.LSTMCell`.
            zoneout_rate (float, optional): Probability of zoneout from 0.0 to 1.0.
        """
        super().__init__()
        self.cell = cell
        self.hidden_size = cell.hidden_size
        self.zoneout_rate = zoneout_rate
        if zoneout_rate > 1.0 or zoneout_rate < 0.0:
            raise ValueError(
                "zoneout probability must be in the range from 0.0 to 1.0."
            )

    def forward(self, inputs, hidden):
        """Calculate forward propagation.

        Args:
            inputs (Tensor): Batch of input tensor (B, input_size).
            hidden (tuple):
                - Tensor: Batch of initial hidden states (B, hidden_size).
                - Tensor: Batch of initial cell states (B, hidden_size).

        Returns:
            tuple:
                - Tensor: Batch of next hidden states (B, hidden_size).
                - Tensor: Batch of next cell states (B, hidden_size).
        """
        updated = self.cell(inputs, hidden)
        return self._zoneout(hidden, updated, self.zoneout_rate)

    def _zoneout(self, h, next_h, prob):
        # Tuples of states (e.g. LSTM's (h, c)) are zoned out element-wise.
        if isinstance(h, tuple):
            if not isinstance(prob, tuple):
                prob = (prob,) * len(h)
            return tuple(
                self._zoneout(old, new, p)
                for old, new, p in zip(h, next_h, prob)
            )
        if self.training:
            # Sample a per-element binary keep-mask: 1 keeps the old state.
            mask = h.new(*h.size()).bernoulli_(prob)
            return mask * h + (1 - mask) * next_h
        # At evaluation, use the expected value of the stochastic update.
        return prob * h + (1 - prob) * next_h
class Prenet(torch.nn.Module):
    """Prenet module for decoder of Spectrogram prediction network.

    This is a module of Prenet in the decoder of Spectrogram prediction network,
    which described in `Natural TTS
    Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The Prenet preforms nonlinear conversion
    of inputs before input to auto-regressive lstm,
    which helps to learn diagonal attentions.

    Note:
        This module alway applies dropout even in evaluation.
        See the detail in `Natural TTS Synthesis by
        Conditioning WaveNet on Mel Spectrogram Predictions`_.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
        https://arxiv.org/abs/1712.05884
    """

    def __init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5):
        """Initialize prenet module.

        Args:
            idim (int): Dimension of the inputs.
            n_layers (int, optional): The number of prenet layers.
            n_units (int, optional): The number of prenet units.
            dropout_rate (float, optional): Dropout rate applied after every layer.
        """
        super().__init__()
        self.dropout_rate = dropout_rate
        self.prenet = torch.nn.ModuleList()
        # First layer maps idim -> n_units; subsequent layers n_units -> n_units.
        in_dims = [idim if i == 0 else n_units for i in range(n_layers)]
        for in_dim in in_dims:
            self.prenet.append(
                torch.nn.Sequential(
                    torch.nn.Linear(in_dim, n_units), torch.nn.ReLU()
                )
            )

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Batch of input tensors (B, ..., idim).

        Returns:
            Tensor: Batch of output tensors (B, ..., odim).
        """
        for fc in self.prenet:
            # Dropout stays active even at eval time (see the class Note).
            x = F.dropout(fc(x), self.dropout_rate)
        return x
class Postnet(torch.nn.Module):
    """Postnet module for Spectrogram prediction network.

    This is a module of Postnet in Spectrogram prediction network,
    which described in `Natural TTS Synthesis by
    Conditioning WaveNet on Mel Spectrogram Predictions`_.
    The Postnet refines the predicted Mel-filterbank of the decoder,
    which helps to compensate the detail structure of spectrogram.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
        https://arxiv.org/abs/1712.05884
    """

    def __init__(
        self,
        idim,
        odim,
        n_layers=5,
        n_chans=512,
        n_filts=5,
        dropout_rate=0.5,
        use_batch_norm=True,
    ):
        """Initialize postnet module.

        Args:
            idim (int): Dimension of the inputs (unused here; kept for
                interface compatibility).
            odim (int): Dimension of the outputs.
            n_layers (int, optional): The number of layers.
            n_chans (int, optional): The number of filter channels.
            n_filts (int, optional): The number of filter size.
            dropout_rate (float, optional): Dropout rate.
            use_batch_norm (bool, optional): Whether to use batch normalization.
        """
        super(Postnet, self).__init__()
        self.postnet = torch.nn.ModuleList()
        # Hidden layers: Conv1d [+ BatchNorm] + Tanh + Dropout.
        # The original code also computed
        # `ochans = odim if layer == n_layers - 1 else n_chans`, but inside
        # `range(n_layers - 1)` that condition can never fire, so hidden
        # layers always output n_chans.
        for layer in range(n_layers - 1):
            ichans = odim if layer == 0 else n_chans
            self.postnet += [
                self._make_layer(
                    ichans, n_chans, n_filts, dropout_rate, use_batch_norm,
                    with_tanh=True,
                )
            ]
        # Final layer: Conv1d [+ BatchNorm] + Dropout (no Tanh).
        ichans = n_chans if n_layers != 1 else odim
        self.postnet += [
            self._make_layer(
                ichans, odim, n_filts, dropout_rate, use_batch_norm,
                with_tanh=False,
            )
        ]

    @staticmethod
    def _make_layer(ichans, ochans, n_filts, dropout_rate, use_batch_norm, with_tanh):
        # Build one postnet conv block; submodule order (Conv, BN, Tanh,
        # Dropout) matches the original for identical parameter creation order.
        mods = [
            torch.nn.Conv1d(
                ichans,
                ochans,
                n_filts,
                stride=1,
                padding=(n_filts - 1) // 2,
                bias=False,
            )
        ]
        if use_batch_norm:
            mods += [torch.nn.BatchNorm1d(ochans)]
        if with_tanh:
            mods += [torch.nn.Tanh()]
        mods += [torch.nn.Dropout(dropout_rate)]
        return torch.nn.Sequential(*mods)

    def forward(self, xs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the sequences of padded input tensors (B, idim, Tmax).

        Returns:
            Tensor: Batch of padded output tensor. (B, odim, Tmax).
        """
        for postnet_layer in self.postnet:
            xs = postnet_layer(xs)
        return xs
class
Decoder
(
torch
.
nn
.
Module
):
"""Decoder module of Spectrogram prediction network.
This is a module of decoder of Spectrogram prediction network in Tacotron2,
which described in `Natural TTS
Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_.
The decoder generates the sequence of
features from the sequence of the hidden states.
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
"""
    def __init__(
        self,
        idim,
        odim,
        att,
        dlayers=2,
        dunits=1024,
        prenet_layers=2,
        prenet_units=256,
        postnet_layers=5,
        postnet_chans=512,
        postnet_filts=5,
        output_activation_fn=None,
        cumulate_att_w=True,
        use_batch_norm=True,
        use_concate=True,
        dropout_rate=0.5,
        zoneout_rate=0.1,
        reduction_factor=1,
    ):
        """Initialize Tacotron2 decoder module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            att (torch.nn.Module): Instance of attention class.
            dlayers (int, optional): The number of decoder lstm layers.
            dunits (int, optional): The number of decoder lstm units.
            prenet_layers (int, optional): The number of prenet layers.
            prenet_units (int, optional): The number of prenet units.
            postnet_layers (int, optional): The number of postnet layers.
            postnet_chans (int, optional): The number of postnet filter channels.
            postnet_filts (int, optional): The number of postnet filter size.
            output_activation_fn (torch.nn.Module, optional):
                Activation function for outputs.
            cumulate_att_w (bool, optional):
                Whether to cumulate previous attention weight.
            use_batch_norm (bool, optional): Whether to use batch normalization.
            use_concate (bool, optional): Whether to concatenate encoder embedding
                with decoder lstm outputs.
            dropout_rate (float, optional): Dropout rate.
            zoneout_rate (float, optional): Zoneout rate.
            reduction_factor (int, optional): Reduction factor.
        """
        super(Decoder, self).__init__()

        # store the hyperparameters
        self.idim = idim
        self.odim = odim
        self.att = att
        self.output_activation_fn = output_activation_fn
        self.cumulate_att_w = cumulate_att_w
        self.use_concate = use_concate
        self.reduction_factor = reduction_factor

        # check attention type; AttForwardTA additionally consumes the
        # previous output frame as an extra input
        if isinstance(self.att, AttForwardTA):
            self.use_att_extra_inputs = True
        else:
            self.use_att_extra_inputs = False

        # define lstm network
        # with no prenet, the raw odim-sized previous frame feeds the LSTM
        prenet_units = prenet_units if prenet_layers != 0 else odim
        self.lstm = torch.nn.ModuleList()
        for layer in range(dlayers):
            iunits = idim + prenet_units if layer == 0 else dunits
            lstm = torch.nn.LSTMCell(iunits, dunits)
            if zoneout_rate > 0.0:
                lstm = ZoneOutCell(lstm, zoneout_rate)
            self.lstm += [lstm]

        # define prenet
        if prenet_layers > 0:
            self.prenet = Prenet(
                idim=odim,
                n_layers=prenet_layers,
                n_units=prenet_units,
                dropout_rate=dropout_rate,
            )
        else:
            self.prenet = None

        # define postnet
        if postnet_layers > 0:
            self.postnet = Postnet(
                idim=idim,
                odim=odim,
                n_layers=postnet_layers,
                n_chans=postnet_chans,
                n_filts=postnet_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=dropout_rate,
            )
        else:
            self.postnet = None

        # define projection layers
        iunits = idim + dunits if use_concate else dunits
        self.feat_out = torch.nn.Linear(iunits, odim * reduction_factor, bias=False)
        self.prob_out = torch.nn.Linear(iunits, reduction_factor)

        # initialize
        self.apply(decoder_init)
def
_zero_state
(
self
,
hs
):
init_hs
=
hs
.
new_zeros
(
hs
.
size
(
0
),
self
.
lstm
[
0
].
hidden_size
)
return
init_hs
    def forward(self, hs, hlens, ys):
        """Calculate forward propagation.

        Args:
            hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
            hlens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor):
                Batch of the sequences of padded target features (B, Lmax, odim).

        Returns:
            Tensor: Batch of output tensors after postnet (B, Lmax, odim).
            Tensor: Batch of output tensors before postnet (B, Lmax, odim).
            Tensor: Batch of logits of stop prediction (B, Lmax).
            Tensor: Batch of attention weights (B, Lmax, Tmax).

        Note:
            This computation is performed in teacher-forcing manner.
        """
        # thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim)
        if self.reduction_factor > 1:
            ys = ys[:, self.reduction_factor - 1 :: self.reduction_factor]

        # length list should be list of int
        hlens = list(map(int, hlens))

        # initialize hidden states of decoder (one (c, z) pair per LSTM layer)
        c_list = [self._zero_state(hs)]
        z_list = [self._zero_state(hs)]
        for _ in range(1, len(self.lstm)):
            c_list += [self._zero_state(hs)]
            z_list += [self._zero_state(hs)]
        # "go" frame: all-zero previous output for the first step
        prev_out = hs.new_zeros(hs.size(0), self.odim)

        # initialize attention
        prev_att_w = None
        self.att.reset()

        # loop for an output sequence, one decoder step per target frame
        outs, logits, att_ws = [], [], []
        for y in ys.transpose(0, 1):
            if self.use_att_extra_inputs:
                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
            else:
                att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
            prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
            xs = torch.cat([att_c, prenet_out], dim=1)
            z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
            for i in range(1, len(self.lstm)):
                z_list[i], c_list[i] = self.lstm[i](
                    z_list[i - 1], (z_list[i], c_list[i])
                )
            # optionally concatenate attention context with top LSTM output
            zcs = (
                torch.cat([z_list[-1], att_c], dim=1)
                if self.use_concate
                else z_list[-1]
            )
            outs += [self.feat_out(zcs).view(hs.size(0), self.odim, -1)]
            logits += [self.prob_out(zcs)]
            att_ws += [att_w]
            prev_out = y  # teacher forcing
            if self.cumulate_att_w and prev_att_w is not None:
                prev_att_w = prev_att_w + att_w  # Note: error when use +=
            else:
                prev_att_w = att_w

        logits = torch.cat(logits, dim=1)  # (B, Lmax)
        before_outs = torch.cat(outs, dim=2)  # (B, odim, Lmax)
        att_ws = torch.stack(att_ws, dim=1)  # (B, Lmax, Tmax)

        if self.reduction_factor > 1:
            # flatten the r frames emitted per step back onto the time axis
            before_outs = before_outs.view(
                before_outs.size(0), self.odim, -1
            )  # (B, odim, Lmax)

        if self.postnet is not None:
            # postnet output is a residual refinement of the decoder output
            after_outs = before_outs + self.postnet(before_outs)  # (B, odim, Lmax)
        else:
            after_outs = before_outs
        before_outs = before_outs.transpose(2, 1)  # (B, Lmax, odim)
        after_outs = after_outs.transpose(2, 1)  # (B, Lmax, odim)
        logits = logits  # no-op; kept as in the original

        # apply activation function for scaling
        if self.output_activation_fn is not None:
            before_outs = self.output_activation_fn(before_outs)
            after_outs = self.output_activation_fn(after_outs)

        return after_outs, before_outs, logits, att_ws
def inference(
    self,
    h,
    threshold=0.5,
    minlenratio=0.0,
    maxlenratio=10.0,
    use_att_constraint=False,
    backward_window=None,
    forward_window=None,
):
    """Generate the sequence of features given the sequences of characters.

    Args:
        h (Tensor): Input sequence of encoder hidden states (T, C).
        threshold (float, optional): Threshold to stop generation.
        minlenratio (float, optional): Minimum length ratio.
            If set to 1.0 and the length of input is 10,
            the minimum length of outputs will be 10 * 1 = 10.
        maxlenratio (float, optional): Maximum length ratio.
            If set to 10 and the length of input is 10,
            the maximum length of outputs will be 10 * 10 = 100.
        use_att_constraint (bool):
            Whether to apply attention constraint introduced in `Deep Voice 3`_.
        backward_window (int): Backward window size in attention constraint.
        forward_window (int): Forward window size in attention constraint.

    Returns:
        Tensor: Output sequence of features (L, odim).
        Tensor: Output sequence of stop probabilities (L,).
        Tensor: Attention weights (L, T).

    Note:
        This computation is performed in auto-regressive manner.

    .. _`Deep Voice 3`: https://arxiv.org/abs/1710.07654
    """
    # setup: add a batch dimension of 1 and derive min/max output lengths
    # from the input length and the given ratios
    assert len(h.size()) == 2
    hs = h.unsqueeze(0)
    ilens = [h.size(0)]
    maxlen = int(h.size(0) * maxlenratio)
    minlen = int(h.size(0) * minlenratio)

    # initialize hidden states of decoder (one (z, c) pair per LSTM layer)
    c_list = [self._zero_state(hs)]
    z_list = [self._zero_state(hs)]
    for _ in range(1, len(self.lstm)):
        c_list += [self._zero_state(hs)]
        z_list += [self._zero_state(hs)]
    # previous output frame fed back into the decoder (all-zero "go" frame)
    prev_out = hs.new_zeros(1, self.odim)

    # initialize attention
    prev_att_w = None
    self.att.reset()

    # setup for attention constraint (monotonic window around the last
    # attended encoder index); None disables the constraint
    if use_att_constraint:
        last_attended_idx = 0
    else:
        last_attended_idx = None

    # loop for an output sequence
    idx = 0
    outs, att_ws, probs = [], [], []
    while True:
        # updated index: each decoder step emits reduction_factor frames
        idx += self.reduction_factor

        # decoder calculation; use_att_extra_inputs additionally feeds the
        # previous output frame into the attention module
        if self.use_att_extra_inputs:
            att_c, att_w = self.att(
                hs,
                ilens,
                z_list[0],
                prev_att_w,
                prev_out,
                last_attended_idx=last_attended_idx,
                backward_window=backward_window,
                forward_window=forward_window,
            )
        else:
            att_c, att_w = self.att(
                hs,
                ilens,
                z_list[0],
                prev_att_w,
                last_attended_idx=last_attended_idx,
                backward_window=backward_window,
                forward_window=forward_window,
            )

        att_ws += [att_w]
        prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out
        xs = torch.cat([att_c, prenet_out], dim=1)
        # run the LSTM stack: layer 0 consumes [context; prenet], deeper
        # layers consume the hidden state of the layer below
        z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0]))
        for i in range(1, len(self.lstm)):
            z_list[i], c_list[i] = self.lstm[i](z_list[i - 1], (z_list[i], c_list[i]))
        # optionally concatenate attention context to the top hidden state
        zcs = (
            torch.cat([z_list[-1], att_c], dim=1)
            if self.use_concate
            else z_list[-1]
        )
        outs += [self.feat_out(zcs).view(1, self.odim, -1)]  # [(1, odim, r), ...]
        probs += [torch.sigmoid(self.prob_out(zcs))[0]]  # [(r), ...]
        # feed back the last generated frame (after activation, if any)
        if self.output_activation_fn is not None:
            prev_out = self.output_activation_fn(outs[-1][:, :, -1])  # (1, odim)
        else:
            prev_out = outs[-1][:, :, -1]  # (1, odim)
        if self.cumulate_att_w and prev_att_w is not None:
            # Note: out-of-place add on purpose; += raises an autograd error
            prev_att_w = prev_att_w + att_w
        else:
            prev_att_w = att_w
        if use_att_constraint:
            last_attended_idx = int(att_w.argmax())

        # check whether to finish generation: stop token fired or max length
        if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
            # check mininum length; NOTE(review): if the stop token fires
            # before minlen, generation simply continues until minlen/maxlen
            if idx < minlen:
                continue
            # rebind the python lists to tensors before leaving the loop
            outs = torch.cat(outs, dim=2)  # (1, odim, L)
            if self.postnet is not None:
                outs = outs + self.postnet(outs)  # (1, odim, L)
            outs = outs.transpose(2, 1).squeeze(0)  # (L, odim)
            probs = torch.cat(probs, dim=0)
            att_ws = torch.cat(att_ws, dim=0)
            break

    if self.output_activation_fn is not None:
        outs = self.output_activation_fn(outs)

    return outs, probs, att_ws
def calculate_all_attentions(self, hs, hlens, ys):
    """Calculate all of the attention weights.

    Args:
        hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim).
        hlens (LongTensor): Batch of lengths of each input batch (B,).
        ys (Tensor):
            Batch of the sequences of padded target features (B, Lmax, odim).

    Returns:
        numpy.ndarray: Batch of attention weights (B, Lmax, Tmax).

    Note:
        This computation is performed in teacher-forcing manner.
    """
    # thin out frames when several are emitted per decoder step:
    # (B, Lmax, odim) -> (B, Lmax/r, odim)
    if self.reduction_factor > 1:
        ys = ys[:, self.reduction_factor - 1 :: self.reduction_factor]

    # the attention module expects lengths as a list of plain ints
    hlens = list(map(int, hlens))

    # zero-initialize one (hidden, cell) state pair per decoder LSTM layer
    num_layers = len(self.lstm)
    c_list = [self._zero_state(hs) for _ in range(num_layers)]
    z_list = [self._zero_state(hs) for _ in range(num_layers)]
    prev_out = hs.new_zeros(hs.size(0), self.odim)

    # reset attention state
    prev_att_w = None
    self.att.reset()

    # walk over target frames, collecting one attention map per step
    att_ws = []
    for target_frame in ys.transpose(0, 1):
        if self.use_att_extra_inputs:
            att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out)
        else:
            att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w)
        att_ws.append(att_w)

        if self.prenet is not None:
            prenet_out = self.prenet(prev_out)
        else:
            prenet_out = prev_out
        lstm_in = torch.cat([att_c, prenet_out], dim=1)

        # layer 0 consumes [context; prenet]; deeper layers consume the
        # hidden state of the layer below
        z_list[0], c_list[0] = self.lstm[0](lstm_in, (z_list[0], c_list[0]))
        for layer in range(1, num_layers):
            z_list[layer], c_list[layer] = self.lstm[layer](
                z_list[layer - 1], (z_list[layer], c_list[layer])
            )

        prev_out = target_frame  # teacher forcing
        if self.cumulate_att_w and prev_att_w is not None:
            # out-of-place add on purpose; += raises an autograd error
            prev_att_w = prev_att_w + att_w
        else:
            prev_att_w = att_w

    return torch.stack(att_ws, dim=1)  # (B, Lmax, Tmax)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/tacotron2/encoder.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Tacotron2 encoder related modules."""
import
torch
from
torch.nn.utils.rnn
import
pack_padded_sequence
,
pad_packed_sequence
def encoder_init(m):
    """Initialize encoder parameters.

    Applies Xavier-uniform initialization (with a ReLU gain) to the weight
    of every ``torch.nn.Conv1d`` module; other module types are untouched.
    Intended to be passed to ``torch.nn.Module.apply``.
    """
    if not isinstance(m, torch.nn.Conv1d):
        return
    gain = torch.nn.init.calculate_gain("relu")
    torch.nn.init.xavier_uniform_(m.weight, gain)
class Encoder(torch.nn.Module):
    """Encoder module of Spectrogram prediction network.

    This is a module of encoder of Spectrogram prediction network in Tacotron2,
    which described in `Natural TTS Synthesis by Conditioning WaveNet on Mel
    Spectrogram Predictions`_. This is the encoder which converts either a sequence
    of characters or acoustic features into the sequence of hidden states.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884
    """

    def __init__(
        self,
        idim,
        input_layer="embed",
        embed_dim=512,
        elayers=1,
        eunits=512,
        econv_layers=3,
        econv_chans=512,
        econv_filts=5,
        use_batch_norm=True,
        use_residual=False,
        dropout_rate=0.5,
        padding_idx=0,
    ):
        """Initialize Tacotron2 encoder module.

        Args:
            idim (int): Dimension of the inputs.
            input_layer (str): Input layer type ("linear" or "embed").
            embed_dim (int, optional): Dimension of character embedding.
            elayers (int, optional): The number of encoder blstm layers.
            eunits (int, optional): The number of encoder blstm units.
            econv_layers (int, optional): The number of encoder conv layers.
            econv_filts (int, optional): The number of encoder conv filter size.
            econv_chans (int, optional): The number of encoder conv filter channels.
            use_batch_norm (bool, optional): Whether to use batch normalization.
            use_residual (bool, optional): Whether to use residual connection.
            dropout_rate (float, optional): Dropout rate.
            padding_idx (int, optional): Padding symbol ID for the embedding layer.

        Raises:
            ValueError: If ``input_layer`` is neither "linear" nor "embed".
        """
        super(Encoder, self).__init__()
        # store the hyperparameters
        self.idim = idim
        self.use_residual = use_residual

        # define network layer modules
        if input_layer == "linear":
            # NOTE(review): the linear front-end projects to econv_chans (not
            # embed_dim); with econv_layers == 0 the blstm below still sizes
            # its input from embed_dim — confirm this combination is unused
            self.embed = torch.nn.Linear(idim, econv_chans)
        elif input_layer == "embed":
            self.embed = torch.nn.Embedding(idim, embed_dim, padding_idx=padding_idx)
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        if econv_layers > 0:
            # stack of Conv1d (+ optional BatchNorm) + ReLU + Dropout blocks;
            # padding keeps the time dimension unchanged
            self.convs = torch.nn.ModuleList()
            for layer in range(econv_layers):
                # the first conv sees the embedding width only when the
                # embedding front-end is used
                ichans = (
                    embed_dim if layer == 0 and input_layer == "embed" else econv_chans
                )
                if use_batch_norm:
                    self.convs += [
                        torch.nn.Sequential(
                            torch.nn.Conv1d(
                                ichans,
                                econv_chans,
                                econv_filts,
                                stride=1,
                                padding=(econv_filts - 1) // 2,
                                bias=False,
                            ),
                            torch.nn.BatchNorm1d(econv_chans),
                            torch.nn.ReLU(),
                            torch.nn.Dropout(dropout_rate),
                        )
                    ]
                else:
                    self.convs += [
                        torch.nn.Sequential(
                            torch.nn.Conv1d(
                                ichans,
                                econv_chans,
                                econv_filts,
                                stride=1,
                                padding=(econv_filts - 1) // 2,
                                bias=False,
                            ),
                            torch.nn.ReLU(),
                            torch.nn.Dropout(dropout_rate),
                        )
                    ]
        else:
            self.convs = None
        if elayers > 0:
            # bidirectional LSTM; eunits is split across the two directions
            iunits = econv_chans if econv_layers != 0 else embed_dim
            self.blstm = torch.nn.LSTM(
                iunits, eunits // 2, elayers, batch_first=True, bidirectional=True
            )
        else:
            self.blstm = None

        # initialize conv weights (Xavier-uniform with ReLU gain)
        self.apply(encoder_init)

    def forward(self, xs, ilens=None):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of the padded sequence. Either character ids (B, Tmax)
                or acoustic feature (B, Tmax, idim * encoder_reduction_factor). Padded
                value should be 0.
            ilens (LongTensor): Batch of lengths of each input batch (B,).
                Required whenever the blstm layer exists.

        Returns:
            Tensor: Batch of the sequences of encoder states (B, Tmax, eunits).
            LongTensor: Batch of lengths of each sequence (B,).
                NOTE(review): when ``self.blstm`` is None only the state tensor
                is returned (no lengths) — callers must handle both arities.
        """
        # (B, Tmax, C) -> (B, C, Tmax) for the Conv1d stack
        xs = self.embed(xs).transpose(1, 2)
        if self.convs is not None:
            for i in range(len(self.convs)):
                if self.use_residual:
                    xs = xs + self.convs[i](xs)
                else:
                    xs = self.convs[i](xs)
        if self.blstm is None:
            return xs.transpose(1, 2)
        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)
        # pack to skip padded frames inside the LSTM; lengths must be on CPU
        xs = pack_padded_sequence(xs.transpose(1, 2), ilens.cpu(), batch_first=True)
        self.blstm.flatten_parameters()
        xs, _ = self.blstm(xs)  # (B, Tmax, C)
        xs, hlens = pad_packed_sequence(xs, batch_first=True)

        return xs, hlens

    def inference(self, x):
        """Inference.

        Args:
            x (Tensor): The sequeunce of character ids (T,)
                or acoustic feature (T, idim * encoder_reduction_factor).

        Returns:
            Tensor: The sequences of encoder states (T, eunits).
        """
        # add batch dimension, run forward, then strip the batch dimension
        xs = x.unsqueeze(0)
        ilens = torch.tensor([x.size(0)])

        return self.forward(xs, ilens)[0][0]
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/arguments.py
0 → 100644
View file @
60a2c57a
"""Transducer model arguments."""
import
ast
from
argparse
import
_ArgumentGroup
from
distutils.util
import
strtobool
def add_encoder_general_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
    """Define general arguments for encoder.

    Args:
        group: Argument group to extend.

    Returns:
        The same argument group with the general encoder options attached.
    """
    etype_choices = [
        "custom",
        "lstm",
        "blstm",
        "lstmp",
        "blstmp",
        "vgglstmp",
        "vggblstmp",
        "vgglstm",
        "vggblstm",
        "gru",
        "bgru",
        "grup",
        "bgrup",
        "vgggrup",
        "vggbgrup",
        "vgggru",
        "vggbgru",
    ]
    group.add_argument(
        "--etype",
        type=str,
        default="blstmp",
        choices=etype_choices,
        help="Type of encoder network architecture",
    )
    group.add_argument(
        "--dropout-rate",
        type=float,
        default=0.0,
        help="Dropout rate for the encoder",
    )

    return group
def add_rnn_encoder_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
    """Define arguments for RNN encoder.

    Args:
        group: Argument group to extend.

    Returns:
        The same argument group with the RNN encoder options attached.
    """
    group.add_argument(
        "--elayers",
        type=int,
        default=4,
        help="Number of encoder layers (for shared recognition part "
        "in multi-speaker asr mode)",
    )
    group.add_argument(
        "--eunits",
        "-u",
        type=int,
        default=300,
        help="Number of encoder hidden units",
    )
    group.add_argument(
        "--eprojs",
        type=int,
        default=320,
        help="Number of encoder projection units",
    )
    group.add_argument(
        "--subsample",
        type=str,
        default="1",
        help="Subsample input frames x_y_z means subsample every x frame "
        "at 1st layer, every y frame at 2nd layer etc.",
    )

    return group
def add_custom_encoder_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
    """Define arguments for Custom encoder.

    Args:
        group: Argument group to extend.

    Returns:
        The same argument group with the custom encoder options attached.
    """
    activation_choices = ["relu", "hardtanh", "selu", "swish"]

    # Architecture definition.
    # NOTE(review): type=eval executes arbitrary CLI input; consider
    # ast.literal_eval if plain dict/list literals are sufficient.
    group.add_argument(
        "--enc-block-arch",
        type=eval,
        action="append",
        default=None,
        help="Encoder architecture definition by blocks",
    )
    group.add_argument(
        "--enc-block-repeat",
        type=int,
        default=1,
        help="Repeat N times the provided encoder blocks if N > 1",
    )

    # Input layer.
    group.add_argument(
        "--custom-enc-input-layer",
        type=str,
        default="conv2d",
        choices=["conv2d", "vgg2l", "linear", "embed"],
        help="Custom encoder input layer type",
    )
    group.add_argument(
        "--custom-enc-input-dropout-rate",
        type=float,
        default=0.0,
        help="Dropout rate of custom encoder input layer",
    )
    group.add_argument(
        "--custom-enc-input-pos-enc-dropout-rate",
        type=float,
        default=0.0,
        help="Dropout rate of positional encoding in custom encoder input layer",
    )

    # Positional encoding / attention.
    group.add_argument(
        "--custom-enc-positional-encoding-type",
        type=str,
        default="abs_pos",
        choices=["abs_pos", "scaled_abs_pos", "rel_pos"],
        help="Custom encoder positional encoding layer type",
    )
    group.add_argument(
        "--custom-enc-self-attn-type",
        type=str,
        default="self_attn",
        choices=["self_attn", "rel_self_attn"],
        help="Custom encoder self-attention type",
    )

    # Activations.
    group.add_argument(
        "--custom-enc-pw-activation-type",
        type=str,
        default="relu",
        choices=activation_choices,
        help="Custom encoder pointwise activation type",
    )
    group.add_argument(
        "--custom-enc-conv-mod-activation-type",
        type=str,
        default="swish",
        choices=activation_choices,
        help="Custom encoder convolutional module activation type",
    )

    return group
def add_decoder_general_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
    """Define general arguments for decoder.

    Note: the original docstring said "encoder", which was a copy-paste
    error — every option here configures the decoder.

    Args:
        group: Argument group to extend.

    Returns:
        The same argument group with the general decoder options attached.
    """
    group.add_argument(
        "--dtype",
        default="lstm",
        type=str,
        choices=["lstm", "gru", "custom"],
        help="Type of decoder to use",
    )
    group.add_argument(
        "--dropout-rate-decoder",
        default=0.0,
        type=float,
        help="Dropout rate for the decoder",
    )
    group.add_argument(
        "--dropout-rate-embed-decoder",
        default=0.0,
        type=float,
        help="Dropout rate for the decoder embedding layer",
    )

    return group
def add_rnn_decoder_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
    """Define arguments for RNN decoder.

    Args:
        group: Argument group to extend.

    Returns:
        The same argument group with the RNN decoder options attached.
    """
    group.add_argument(
        "--dec-embed-dim",
        type=int,
        default=320,
        help="Number of decoder embeddings dimensions",
    )
    group.add_argument(
        "--dlayers",
        type=int,
        default=1,
        help="Number of decoder layers",
    )
    group.add_argument(
        "--dunits",
        type=int,
        default=320,
        help="Number of decoder hidden units",
    )

    return group
def add_custom_decoder_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
    """Define arguments for Custom decoder.

    Args:
        group: Argument group to extend.

    Returns:
        The same argument group with the custom decoder options attached.
    """
    # NOTE(review): type=eval executes arbitrary CLI input; consider
    # ast.literal_eval if plain dict/list literals are sufficient.
    group.add_argument(
        "--dec-block-arch",
        type=eval,
        action="append",
        default=None,
        help="Custom decoder blocks definition",
    )
    group.add_argument(
        "--dec-block-repeat",
        type=int,
        default=1,
        help="Repeat N times the provided decoder blocks if N > 1",
    )
    group.add_argument(
        "--custom-dec-input-layer",
        type=str,
        default="embed",
        choices=["linear", "embed"],
        help="Custom decoder input layer type",
    )
    group.add_argument(
        "--custom-dec-pw-activation-type",
        type=str,
        default="relu",
        choices=["relu", "hardtanh", "selu", "swish"],
        help="Custom decoder pointwise activation type",
    )

    return group
def add_custom_training_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
    """Define arguments for training with Custom architecture.

    Args:
        group: Argument group to extend.

    Returns:
        The same argument group with the custom training options attached.
    """
    # Current option names.
    group.add_argument(
        "--optimizer-warmup-steps",
        type=int,
        default=25000,
        help="Optimizer warmup steps",
    )
    group.add_argument(
        "--noam-lr",
        type=float,
        default=10.0,
        help="Initial value of learning rate",
    )
    group.add_argument(
        "--noam-adim",
        type=int,
        default=0,
        help="Most dominant attention dimension for scheduler.",
    )

    # Deprecated aliases kept for backward compatibility; they share the
    # dest of the options above (the defaults above win when unset).
    group.add_argument(
        "--transformer-warmup-steps",
        type=int,
        help="Optimizer warmup steps. The parameter is deprecated, "
        "please use --optimizer-warmup-steps instead.",
        dest="optimizer_warmup_steps",
    )
    group.add_argument(
        "--transformer-lr",
        type=float,
        help="Initial value of learning rate. The parameter is deprecated, "
        "please use --noam-lr instead.",
        dest="noam_lr",
    )
    group.add_argument(
        "--adim",
        type=int,
        help="Most dominant attention dimension for scheduler. "
        "The parameter is deprecated, please use --noam-adim instead.",
        dest="noam_adim",
    )

    return group
def add_transducer_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
    """Define general arguments for Transducer model.

    Args:
        group: Argument group to extend.

    Returns:
        The same argument group with the Transducer options attached.
    """
    group.add_argument(
        "--transducer-weight",
        type=float,
        default=1.0,
        help="Weight of main Transducer loss.",
    )
    group.add_argument(
        "--joint-dim",
        type=int,
        default=320,
        help="Number of dimensions in joint space",
    )
    group.add_argument(
        "--joint-activation-type",
        type=str,
        default="tanh",
        choices=["relu", "tanh", "swish"],
        help="Joint network activation type",
    )
    group.add_argument(
        "--score-norm",
        type=strtobool,
        nargs="?",
        default=True,
        help="Normalize Transducer scores by length",
    )
    group.add_argument(
        "--fastemit-lambda",
        type=float,
        default=0.0,
        help="Regularization parameter for FastEmit (https://arxiv.org/abs/2010.11148)",
    )

    return group
def add_auxiliary_task_arguments(group: _ArgumentGroup) -> _ArgumentGroup:
    """Add arguments for auxiliary task.

    Args:
        group: Argument group to extend.

    Returns:
        The same argument group with the auxiliary-task options attached.
    """
    # Auxiliary CTC loss.
    group.add_argument(
        "--use-ctc-loss",
        type=strtobool,
        nargs="?",
        default=False,
        help="Whether to compute auxiliary CTC loss.",
    )
    group.add_argument(
        "--ctc-loss-weight",
        type=float,
        default=0.5,
        help="Weight of auxiliary CTC loss.",
    )
    group.add_argument(
        "--ctc-loss-dropout-rate",
        type=float,
        default=0.0,
        help="Dropout rate for auxiliary CTC.",
    )

    # Auxiliary LM loss.
    group.add_argument(
        "--use-lm-loss",
        type=strtobool,
        nargs="?",
        default=False,
        help="Whether to compute auxiliary LM loss (label smoothing).",
    )
    group.add_argument(
        "--lm-loss-weight",
        type=float,
        default=0.5,
        help="Weight of auxiliary LM loss.",
    )
    group.add_argument(
        "--lm-loss-smoothing-rate",
        type=float,
        default=0.0,
        help="Smoothing rate for LM loss. If > 0, label smoothing is enabled.",
    )

    # Auxiliary Transducer loss.
    group.add_argument(
        "--use-aux-transducer-loss",
        type=strtobool,
        nargs="?",
        default=False,
        help="Whether to compute auxiliary Transducer loss.",
    )
    group.add_argument(
        "--aux-transducer-loss-weight",
        type=float,
        default=0.2,
        help="Weight of auxiliary Transducer loss.",
    )
    group.add_argument(
        "--aux-transducer-loss-enc-output-layers",
        type=ast.literal_eval,
        default=None,
        help="List of intermediate encoder layers for auxiliary "
        "transducer loss computation.",
    )
    group.add_argument(
        "--aux-transducer-loss-mlp-dim",
        type=int,
        default=320,
        help="Multilayer perceptron hidden dimension for auxiliary Transducer loss.",
    )
    group.add_argument(
        "--aux-transducer-loss-mlp-dropout-rate",
        type=float,
        default=0.0,
        help="Multilayer perceptron dropout rate for auxiliary Transducer loss.",
    )

    # Symmetric KL divergence loss.
    group.add_argument(
        "--use-symm-kl-div-loss",
        type=strtobool,
        nargs="?",
        default=False,
        help="Whether to compute symmetric KL divergence loss.",
    )
    group.add_argument(
        "--symm-kl-div-loss-weight",
        type=float,
        default=0.2,
        help="Weight of symmetric KL divergence loss.",
    )

    return group
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/blocks.py
0 → 100644
View file @
60a2c57a
"""Set of methods to create custom architecture."""
from
typing
import
Any
,
Dict
,
List
,
Tuple
,
Union
import
torch
from
espnet.nets.pytorch_backend.conformer.convolution
import
ConvolutionModule
from
espnet.nets.pytorch_backend.conformer.encoder_layer
import
(
EncoderLayer
as
ConformerEncoderLayer
,
)
from
espnet.nets.pytorch_backend.nets_utils
import
get_activation
from
espnet.nets.pytorch_backend.transducer.conv1d_nets
import
CausalConv1d
,
Conv1d
from
espnet.nets.pytorch_backend.transducer.transformer_decoder_layer
import
(
TransformerDecoderLayer
,
)
from
espnet.nets.pytorch_backend.transducer.vgg2l
import
VGG2L
from
espnet.nets.pytorch_backend.transformer.attention
import
(
MultiHeadedAttention
,
RelPositionMultiHeadedAttention
,
)
from
espnet.nets.pytorch_backend.transformer.embedding
import
(
PositionalEncoding
,
RelPositionalEncoding
,
ScaledPositionalEncoding
,
)
from
espnet.nets.pytorch_backend.transformer.encoder_layer
import
EncoderLayer
from
espnet.nets.pytorch_backend.transformer.positionwise_feed_forward
import
(
PositionwiseFeedForward
,
)
from
espnet.nets.pytorch_backend.transformer.repeat
import
MultiSequential
from
espnet.nets.pytorch_backend.transformer.subsampling
import
Conv2dSubsampling
def verify_block_arguments(
    net_part: str,
    block: Dict[str, Any],
    num_block: int,
) -> Tuple[int, int]:
    """Verify block arguments are valid.

    Args:
        net_part: Network part, either 'encoder' or 'decoder'.
        block: Block parameters.
        num_block: Block ID (1-based, used in error messages).

    Return:
        block_io: Input and output dimension of the block.

    Raises:
        ValueError: If the block has no type, is not allowed in ``net_part``,
            or is missing required arguments.
        NotImplementedError: If the block type is unknown.
    """
    block_type = block.get("type")

    if block_type is None:
        # Fix: the original passed the format args as a second exception
        # argument instead of %-formatting the message.
        raise ValueError(
            "Block %d in %s doesn't have a type assigned." % (num_block, net_part)
        )

    if block_type == "transformer":
        arguments = {"d_hidden", "d_ff", "heads"}
    elif block_type == "conformer":
        arguments = {
            "d_hidden",
            "d_ff",
            "heads",
            "macaron_style",
            "use_conv_mod",
        }

        if net_part == "decoder":
            raise ValueError("Decoder does not support 'conformer'.")

        if block.get("use_conv_mod", None) is True and "conv_mod_kernel" not in block:
            raise ValueError(
                "Block %d: 'use_conv_mod' is True but "
                " 'conv_mod_kernel' is not specified" % num_block
            )
    elif block_type == "causal-conv1d":
        arguments = {"idim", "odim", "kernel_size"}

        if net_part == "encoder":
            raise ValueError("Encoder does not support 'causal-conv1d'.")
    elif block_type == "conv1d":
        arguments = {"idim", "odim", "kernel_size"}

        if net_part == "decoder":
            raise ValueError("Decoder does not support 'conv1d.'")
    else:
        raise NotImplementedError(
            "Wrong type. Currently supported: "
            "causal-conv1d, conformer, conv-nd or transformer."
        )

    if not arguments.issubset(block):
        raise ValueError(
            "%s in %s in position %d: Expected block arguments : %s."
            " See tutorial page for more information."
            % (block_type, net_part, num_block, arguments)
        )

    # Attention-based blocks keep a constant width; conv blocks map idim->odim.
    if block_type in ("transformer", "conformer"):
        block_io = (block["d_hidden"], block["d_hidden"])
    else:
        block_io = (block["idim"], block["odim"])

    return block_io
def prepare_input_layer(
    input_layer_type: str,
    feats_dim: int,
    blocks: List[Dict[str, Any]],
    dropout_rate: float,
    pos_enc_dropout_rate: float,
) -> Dict[str, Any]:
    """Prepare input layer arguments.

    Args:
        input_layer_type: Input layer type.
        feats_dim: Dimension of input features.
        blocks: Blocks parameters for network part.
        dropout_rate: Dropout rate for input layer.
        pos_enc_dropout_rate: Dropout rate for input layer pos. enc.

    Return:
        input_block: Input block parameters.
    """
    leading_type = blocks[0].get("type", None)

    # A leading causal-conv1d block forces a plain (dropout-only) embedding.
    if leading_type == "causal-conv1d":
        layer_type = "c-embed"
    else:
        layer_type = input_layer_type

    # Attention-based blocks expose their width as 'd_hidden';
    # conv blocks expose it as 'idim'.
    if leading_type in ("transformer", "conformer"):
        out_dim = blocks[0].get("d_hidden", 0)
    else:
        out_dim = blocks[0].get("idim", 0)

    return {
        "type": layer_type,
        "dropout-rate": dropout_rate,
        "pos-dropout-rate": pos_enc_dropout_rate,
        "idim": feats_dim,
        "odim": out_dim,
    }
def prepare_body_model(
    net_part: str,
    blocks: List[Dict[str, Any]],
) -> Tuple[int]:
    """Prepare model body blocks.

    Args:
        net_part: Network part, either 'encoder' or 'decoder'.
        blocks: Blocks parameters for network part.

    Return:
        : Network output dimension.
    """
    # Validate every block and collect its (input, output) dimensions.
    io_dims = [
        verify_block_arguments(net_part, blk, idx + 1)
        for idx, blk in enumerate(blocks)
    ]

    if {"transformer", "conformer"} <= {blk["type"] for blk in blocks}:
        raise NotImplementedError(
            net_part + ": transformer and conformer blocks "
            "can't be used together in the same net part."
        )

    # Adjacent blocks must agree: output dim of one == input dim of the next.
    for pos, (prev_io, cur_io) in enumerate(zip(io_dims, io_dims[1:]), start=1):
        if prev_io[1] != cur_io[0]:
            raise ValueError(
                "Output/Input mismatch between blocks %d and %d in %s"
                % (pos, pos + 1, net_part)
            )

    return io_dims[-1][1]
def get_pos_enc_and_att_class(
    net_part: str, pos_enc_type: str, self_attn_type: str
) -> Tuple[
    Union[PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding],
    Union[MultiHeadedAttention, RelPositionMultiHeadedAttention],
]:
    """Get positional encoding and self attention module class.

    Args:
        net_part: Network part, either 'encoder' or 'decoder'.
        pos_enc_type: Positional encoding type.
        self_attn_type: Self-attention type.

    Return:
        pos_enc_class: Positional encoding class.
        self_attn_class: Self-attention class.
    """
    pos_enc_classes = {
        "abs_pos": PositionalEncoding,
        "scaled_abs_pos": ScaledPositionalEncoding,
        "rel_pos": RelPositionalEncoding,
    }

    if pos_enc_type not in pos_enc_classes:
        raise NotImplementedError(
            "pos_enc_type should be either 'abs_pos', 'scaled_abs_pos' or 'rel_pos'"
        )

    # Relative positional encoding only pairs with relative self-attention
    # (the constraint is enforced for the encoder part only).
    if (
        pos_enc_type == "rel_pos"
        and net_part == "encoder"
        and self_attn_type != "rel_self_attn"
    ):
        raise ValueError("'rel_pos' is only compatible with 'rel_self_attn'")

    pos_enc_class = pos_enc_classes[pos_enc_type]

    if self_attn_type == "rel_self_attn":
        self_attn_class = RelPositionMultiHeadedAttention
    else:
        self_attn_class = MultiHeadedAttention

    return pos_enc_class, self_attn_class
def build_input_layer(
    block: Dict[str, Any],
    pos_enc_class: torch.nn.Module,
    padding_idx: int,
) -> Tuple[Union[Conv2dSubsampling, VGG2L, torch.nn.Sequential], int]:
    """Build input layer.

    Args:
        block: Architecture definition of input layer.
        pos_enc_class: Positional encoding class.
        padding_idx: Padding symbol ID for embedding layer (if provided).

    Returns:
        : Input layer module.
        subsampling_factor: Subsampling factor.
    """
    input_type = block["type"]
    idim = block["idim"]
    odim = block["odim"]
    dropout_rate = block["dropout-rate"]
    pos_dropout_rate = block["pos-dropout-rate"]

    # Subsampling front-ends receive a pre-built positional encoding module,
    # but only for the relative-position variant.
    if pos_enc_class.__name__ == "RelPositionalEncoding":
        pos_enc_class_subsampling = pos_enc_class(odim, pos_dropout_rate)
    else:
        pos_enc_class_subsampling = None

    if input_type == "linear":
        module = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim),
            torch.nn.Dropout(dropout_rate),
            torch.nn.ReLU(),
            pos_enc_class(odim, pos_dropout_rate),
        )
        return module, 1
    if input_type == "conv2d":
        # Conv2d subsampling reduces the time dimension by a factor of 4.
        return Conv2dSubsampling(idim, odim, dropout_rate, pos_enc_class_subsampling), 4
    if input_type == "vgg2l":
        return VGG2L(idim, odim, pos_enc_class_subsampling), 4
    if input_type == "embed":
        module = torch.nn.Sequential(
            torch.nn.Embedding(idim, odim, padding_idx=padding_idx),
            pos_enc_class(odim, pos_dropout_rate),
        )
        return module, 1
    if input_type == "c-embed":
        # Embedding for causal-conv1d decoders: dropout, no pos. encoding.
        module = torch.nn.Sequential(
            torch.nn.Embedding(idim, odim, padding_idx=padding_idx),
            torch.nn.Dropout(dropout_rate),
        )
        return module, 1

    raise NotImplementedError(
        "Invalid input layer: %s. Supported: linear, conv2d, vgg2l and embed"
        % input_type
    )
def build_transformer_block(
    net_part: str,
    block: Dict[str, Any],
    pw_layer_type: str,
    pw_activation_type: str,
) -> Union[EncoderLayer, TransformerDecoderLayer]:
    """Build function for transformer block.

    Args:
        net_part: Network part, either 'encoder' or 'decoder'.
        block: Transformer block parameters.
        pw_layer_type: Positionwise layer type.
        pw_activation_type: Positionwise activation type.

    Returns:
        : Function to create transformer (encoder or decoder) block.

    Raises:
        NotImplementedError: If pw_layer_type is not 'linear'.
        ValueError: If net_part is neither 'encoder' nor 'decoder'.
    """
    d_hidden = block["d_hidden"]

    dropout_rate = block.get("dropout-rate", 0.0)
    pos_dropout_rate = block.get("pos-dropout-rate", 0.0)
    att_dropout_rate = block.get("att-dropout-rate", 0.0)

    if pw_layer_type != "linear":
        raise NotImplementedError(
            "Transformer block only supports linear pointwise layer."
        )

    if net_part == "encoder":
        transformer_layer_class = EncoderLayer
    elif net_part == "decoder":
        transformer_layer_class = TransformerDecoderLayer
    else:
        # Fix: an unknown net_part previously crashed later with a confusing
        # UnboundLocalError; fail fast with a clear message instead.
        raise ValueError(
            "Invalid net_part: %s. Supported: encoder, decoder." % net_part
        )

    return lambda: transformer_layer_class(
        d_hidden,
        MultiHeadedAttention(block["heads"], d_hidden, att_dropout_rate),
        PositionwiseFeedForward(
            d_hidden,
            block["d_ff"],
            # NOTE(review): pos-dropout-rate is used as the feed-forward
            # dropout here — confirm this is intended.
            pos_dropout_rate,
            get_activation(pw_activation_type),
        ),
        dropout_rate,
    )
def build_conformer_block(
    block: Dict[str, Any],
    self_attn_class: str,
    pw_layer_type: str,
    pw_activation_type: str,
    conv_mod_activation_type: str,
) -> ConformerEncoderLayer:
    """Build function for conformer block.

    Args:
        block: Conformer block parameters.
        self_attn_class: Self-attention module class.
        pw_layer_type: Positionwise layer type.
        pw_activation_type: Positionwise activation type.
        conv_mod_activation_type: Convolutional module activation type.

    Returns:
        : Function to create conformer (encoder) block.
    """
    hidden_dim = block["d_hidden"]
    ff_dim = block["d_ff"]

    dropout_rate = block.get("dropout-rate", 0.0)
    pos_dropout_rate = block.get("pos-dropout-rate", 0.0)
    att_dropout_rate = block.get("att-dropout-rate", 0.0)

    macaron_style = block["macaron_style"]
    use_conv_mod = block["use_conv_mod"]

    if pw_layer_type != "linear":
        raise NotImplementedError("Conformer block only supports linear yet.")

    ff_args = (
        hidden_dim,
        ff_dim,
        pos_dropout_rate,
        get_activation(pw_activation_type),
    )

    # Macaron style adds a second feed-forward module (own activation module).
    if macaron_style:
        macaron_args = (
            hidden_dim,
            ff_dim,
            pos_dropout_rate,
            get_activation(pw_activation_type),
        )

    if use_conv_mod:
        conv_args = (
            hidden_dim,
            block["conv_mod_kernel"],
            get_activation(conv_mod_activation_type),
        )

    return lambda: ConformerEncoderLayer(
        hidden_dim,
        self_attn_class(block["heads"], hidden_dim, att_dropout_rate),
        PositionwiseFeedForward(*ff_args),
        PositionwiseFeedForward(*macaron_args) if macaron_style else None,
        ConvolutionModule(*conv_args) if use_conv_mod else None,
        dropout_rate,
    )
def build_conv1d_block(block: Dict[str, Any], block_type: str) -> CausalConv1d:
    """Build function for causal conv1d block.

    Args:
        block: CausalConv1d or Conv1D block parameters.
        block_type: Either 'conv1d' (encoder) or 'causal-conv1d' (decoder).

    Returns:
        : Function to create conv1d (encoder) or causal conv1d (decoder) block.
    """
    conv_class = Conv1d if block_type == "conv1d" else CausalConv1d

    # Optional parameters fall back to the conv defaults.
    options = {
        "stride": block.get("stride", 1),
        "dilation": block.get("dilation", 1),
        "groups": block.get("groups", 1),
        "bias": block.get("bias", True),
        "relu": block.get("use-relu", False),
        "batch_norm": block.get("use-batch-norm", False),
        "dropout_rate": block.get("dropout-rate", 0.0),
    }

    return lambda: conv_class(
        block["idim"],
        block["odim"],
        block["kernel_size"],
        **options,
    )
def build_blocks(
    net_part: str,
    idim: int,
    input_layer_type: str,
    blocks: List[Dict[str, Any]],
    repeat_block: int = 0,
    self_attn_type: str = "self_attn",
    positional_encoding_type: str = "abs_pos",
    positionwise_layer_type: str = "linear",
    positionwise_activation_type: str = "relu",
    conv_mod_activation_type: str = "relu",
    input_layer_dropout_rate: float = 0.0,
    input_layer_pos_enc_dropout_rate: float = 0.0,
    padding_idx: int = -1,
) -> Tuple[
    Union[Conv2dSubsampling, VGG2L, torch.nn.Sequential],
    MultiSequential,
    int,
    int,
]:
    """Build custom model blocks.

    Args:
        net_part: Network part, either 'encoder' or 'decoder'.
        idim: Input dimension.
        input_layer_type: Input layer type.
        blocks: Blocks parameters for network part.
        repeat_block: Number of times provided blocks are repeated.
        self_attn_type: Self-attention module type.
        positional_encoding_type: Positional encoding layer type.
        positionwise_layer_type: Positionwise layer type.
        positionwise_activation_type: Positionwise activation type.
        conv_mod_activation_type: Convolutional module activation type.
        input_layer_dropout_rate: Dropout rate for input layer.
        input_layer_pos_enc_dropout_rate: Dropout rate for input layer pos. enc.
        padding_idx: Padding symbol ID for embedding layer.

    Returns:
        in_layer: Input layer
        all_blocks: Encoder/Decoder network.
        out_dim: Network output dimension.
        conv_subsampling_factor: Subsampling factor in frontend CNN.

    Raises:
        NotImplementedError: If a block configuration has an unsupported type.

    """
    fn_modules = []

    pos_enc_class, self_attn_class = get_pos_enc_and_att_class(
        net_part, positional_encoding_type, self_attn_type
    )

    input_block = prepare_input_layer(
        input_layer_type,
        idim,
        blocks,
        input_layer_dropout_rate,
        input_layer_pos_enc_dropout_rate,
    )

    out_dim = prepare_body_model(net_part, blocks)

    input_layer, conv_subsampling_factor = build_input_layer(
        input_block,
        pos_enc_class,
        padding_idx,
    )

    for block in blocks:
        block_type = block["type"]

        if block_type in ("causal-conv1d", "conv1d"):
            module = build_conv1d_block(block, block_type)
        elif block_type == "conformer":
            module = build_conformer_block(
                block,
                self_attn_class,
                positionwise_layer_type,
                positionwise_activation_type,
                conv_mod_activation_type,
            )
        elif block_type == "transformer":
            module = build_transformer_block(
                net_part,
                block,
                positionwise_layer_type,
                positionwise_activation_type,
            )
        else:
            # Previously an unknown type fell through and surfaced as a
            # confusing NameError on the append below; fail explicitly.
            raise NotImplementedError(
                "Block type '%s' is not supported. Supported types: "
                "causal-conv1d, conv1d, conformer, transformer." % block_type
            )

        fn_modules.append(module)

    # Repeating duplicates the factories; each fn() call below still builds
    # an independent module instance.
    if repeat_block > 1:
        fn_modules = fn_modules * repeat_block

    return (
        input_layer,
        MultiSequential(*[fn() for fn in fn_modules]),
        out_dim,
        conv_subsampling_factor,
    )
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/conv1d_nets.py
0 → 100644
View file @
60a2c57a
"""Convolution networks definition for custom archictecture."""
from
typing
import
Optional
,
Tuple
,
Union
import
torch
class
Conv1d
(
torch
.
nn
.
Module
):
"""1D convolution module for custom encoder.
Args:
idim: Input dimension.
odim: Output dimension.
kernel_size: Size of the convolving kernel.
stride: Stride of the convolution.
dilation: Spacing between the kernel points.
groups: Number of blocked connections from input channels to output channels.
bias: Whether to add a learnable bias to the output.
batch_norm: Whether to use batch normalization after convolution.
relu: Whether to use a ReLU activation after convolution.
dropout_rate: Dropout rate.
"""
def
__init__
(
self
,
idim
:
int
,
odim
:
int
,
kernel_size
:
Union
[
int
,
Tuple
],
stride
:
Union
[
int
,
Tuple
]
=
1
,
dilation
:
Union
[
int
,
Tuple
]
=
1
,
groups
:
Union
[
int
,
Tuple
]
=
1
,
bias
:
bool
=
True
,
batch_norm
:
bool
=
False
,
relu
:
bool
=
True
,
dropout_rate
:
float
=
0.0
,
):
"""Construct a Conv1d module object."""
super
().
__init__
()
self
.
conv
=
torch
.
nn
.
Conv1d
(
idim
,
odim
,
kernel_size
,
stride
=
stride
,
dilation
=
dilation
,
groups
=
groups
,
bias
=
bias
,
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
p
=
dropout_rate
)
if
relu
:
self
.
relu_func
=
torch
.
nn
.
ReLU
()
if
batch_norm
:
self
.
bn
=
torch
.
nn
.
BatchNorm1d
(
odim
)
self
.
relu
=
relu
self
.
batch_norm
=
batch_norm
self
.
padding
=
dilation
*
(
kernel_size
-
1
)
self
.
stride
=
stride
self
.
out_pos
=
torch
.
nn
.
Linear
(
idim
,
odim
)
def
forward
(
self
,
sequence
:
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
mask
:
torch
.
Tensor
,
)
->
Tuple
[
Union
[
torch
.
Tensor
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
torch
.
Tensor
]:
"""Forward ConvEncoderLayer module object.
Args:
sequence: Input sequences.
(B, T, D_in)
or (B, T, D_in), (B, 2 * (T - 1), D_att)
mask: Mask of input sequences. (B, 1, T)
Returns:
sequence: Output sequences.
(B, sub(T), D_out)
or (B, sub(T), D_out), (B, 2 * (sub(T) - 1), D_att)
mask: Mask of output sequences. (B, 1, sub(T))
"""
if
isinstance
(
sequence
,
tuple
):
sequence
,
pos_embed
=
sequence
[
0
],
sequence
[
1
]
else
:
sequence
,
pos_embed
=
sequence
,
None
sequence
=
sequence
.
transpose
(
1
,
2
)
sequence
=
self
.
conv
(
sequence
)
if
self
.
batch_norm
:
sequence
=
self
.
bn
(
sequence
)
sequence
=
self
.
dropout
(
sequence
)
if
self
.
relu
:
sequence
=
self
.
relu_func
(
sequence
)
sequence
=
sequence
.
transpose
(
1
,
2
)
mask
=
self
.
create_new_mask
(
mask
)
if
pos_embed
is
not
None
:
pos_embed
=
self
.
create_new_pos_embed
(
pos_embed
)
return
(
sequence
,
pos_embed
),
mask
return
sequence
,
mask
def
create_new_mask
(
self
,
mask
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Create new mask.
Args:
mask: Mask of input sequences. (B, 1, T)
Returns:
mask: Mask of output sequences. (B, 1, sub(T))
"""
if
mask
is
None
:
return
mask
if
self
.
padding
!=
0
:
mask
=
mask
[:,
:,
:
-
self
.
padding
]
mask
=
mask
[:,
:,
::
self
.
stride
]
return
mask
def
create_new_pos_embed
(
self
,
pos_embed
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Create new positional embedding vector.
Args:
pos_embed: Input sequences positional embedding.
(B, 2 * (T - 1), D_att)
Return:
pos_embed: Output sequences positional embedding.
(B, 2 * (sub(T) - 1), D_att)
"""
pos_embed_positive
=
pos_embed
[:,
:
pos_embed
.
size
(
1
)
//
2
+
1
,
:]
pos_embed_negative
=
pos_embed
[:,
pos_embed
.
size
(
1
)
//
2
:,
:]
if
self
.
padding
!=
0
:
pos_embed_positive
=
pos_embed_positive
[:,
:
-
self
.
padding
,
:]
pos_embed_negative
=
pos_embed_negative
[:,
:
-
self
.
padding
,
:]
pos_embed_positive
=
pos_embed_positive
[:,
::
self
.
stride
,
:]
pos_embed_negative
=
pos_embed_negative
[:,
::
self
.
stride
,
:]
pos_embed
=
torch
.
cat
([
pos_embed_positive
,
pos_embed_negative
[:,
1
:,
:]],
dim
=
1
)
return
self
.
out_pos
(
pos_embed
)
class
CausalConv1d
(
torch
.
nn
.
Module
):
"""1D causal convolution module for custom decoder.
Args:
idim: Input dimension.
odim: Output dimension.
kernel_size: Size of the convolving kernel.
stride: Stride of the convolution.
dilation: Spacing between the kernel points.
groups: Number of blocked connections from input channels to output channels.
bias: Whether to add a learnable bias to the output.
batch_norm: Whether to apply batch normalization.
relu: Whether to pass final output through ReLU activation.
dropout_rate: Dropout rate.
"""
def
__init__
(
self
,
idim
:
int
,
odim
:
int
,
kernel_size
:
int
,
stride
:
int
=
1
,
dilation
:
int
=
1
,
groups
:
int
=
1
,
bias
:
bool
=
True
,
batch_norm
:
bool
=
False
,
relu
:
bool
=
True
,
dropout_rate
:
float
=
0.0
,
):
"""Construct a CausalConv1d object."""
super
().
__init__
()
self
.
padding
=
(
kernel_size
-
1
)
*
dilation
self
.
causal_conv1d
=
torch
.
nn
.
Conv1d
(
idim
,
odim
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
self
.
padding
,
dilation
=
dilation
,
groups
=
groups
,
bias
=
bias
,
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
p
=
dropout_rate
)
if
batch_norm
:
self
.
bn
=
torch
.
nn
.
BatchNorm1d
(
odim
)
if
relu
:
self
.
relu_func
=
torch
.
nn
.
ReLU
()
self
.
batch_norm
=
batch_norm
self
.
relu
=
relu
def
forward
(
self
,
sequence
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
cache
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Forward CausalConv1d for custom decoder.
Args:
sequence: CausalConv1d input sequences. (B, U, D_in)
mask: Mask of CausalConv1d input sequences. (B, 1, U)
Returns:
sequence: CausalConv1d output sequences. (B, sub(U), D_out)
mask: Mask of CausalConv1d output sequences. (B, 1, sub(U))
"""
sequence
=
sequence
.
transpose
(
1
,
2
)
sequence
=
self
.
causal_conv1d
(
sequence
)
if
self
.
padding
!=
0
:
sequence
=
sequence
[:,
:,
:
-
self
.
padding
]
if
self
.
batch_norm
:
sequence
=
self
.
bn
(
sequence
)
sequence
=
self
.
dropout
(
sequence
)
if
self
.
relu
:
sequence
=
self
.
relu_func
(
sequence
)
sequence
=
sequence
.
transpose
(
1
,
2
)
return
sequence
,
mask
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/custom_decoder.py
0 → 100644
View file @
60a2c57a
"""Custom decoder definition for Transducer model."""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
espnet.nets.pytorch_backend.transducer.blocks
import
build_blocks
from
espnet.nets.pytorch_backend.transducer.utils
import
(
check_batch_states
,
check_state
,
pad_sequence
,
)
from
espnet.nets.pytorch_backend.transformer.layer_norm
import
LayerNorm
from
espnet.nets.pytorch_backend.transformer.mask
import
subsequent_mask
from
espnet.nets.transducer_decoder_interface
import
(
ExtendedHypothesis
,
Hypothesis
,
TransducerDecoderInterface
,
)
class CustomDecoder(TransducerDecoderInterface, torch.nn.Module):
    """Custom decoder module for Transducer model.

    Args:
        odim: Output dimension.
        dec_arch: Decoder block architecture (type and parameters).
        input_layer: Input layer type.
        repeat_block: Number of times dec_arch is repeated.
        joint_activation_type: Type of activation for joint network.
        positional_encoding_type: Positional encoding type.
        positionwise_layer_type: Positionwise layer type.
        positionwise_activation_type: Positionwise activation type.
        input_layer_dropout_rate: Dropout rate for input layer.
        blank_id: Blank symbol ID.

    """

    def __init__(
        self,
        odim: int,
        dec_arch: List,
        input_layer: str = "embed",
        repeat_block: int = 0,
        joint_activation_type: str = "tanh",
        positional_encoding_type: str = "abs_pos",
        positionwise_layer_type: str = "linear",
        positionwise_activation_type: str = "relu",
        input_layer_dropout_rate: float = 0.0,
        blank_id: int = 0,
    ):
        """Construct a CustomDecoder object."""
        # Explicit Module init (rather than super().__init__()) because of
        # the TransducerDecoderInterface co-parent in the MRO.
        torch.nn.Module.__init__(self)

        self.embed, self.decoders, ddim, _ = build_blocks(
            "decoder",
            odim,
            input_layer,
            dec_arch,
            repeat_block=repeat_block,
            positional_encoding_type=positional_encoding_type,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_activation_type=positionwise_activation_type,
            input_layer_dropout_rate=input_layer_dropout_rate,
            padding_idx=blank_id,
        )

        self.after_norm = LayerNorm(ddim)

        # Cached sizes used by the scoring and state helpers below.
        self.dlayers = len(self.decoders)
        self.dunits = ddim
        self.odim = odim

        self.blank_id = blank_id

    def set_device(self, device: torch.device):
        """Set GPU device to use.

        Args:
            device: Device ID.

        """
        self.device = device

    def init_state(
        self,
        batch_size: Optional[int] = None,
    ) -> List[Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            state: Initial decoder hidden states. [N x None]

        """
        # Attention-based blocks have no recurrent state; one None per layer
        # acts as an "empty cache" placeholder.
        state = [None] * self.dlayers

        return state

    def forward(
        self, dec_input: torch.Tensor, dec_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode label ID sequences.

        Args:
            dec_input: Label ID sequences. (B, U)
            dec_mask: Label mask sequences. (B, U)

        Return:
            dec_output: Decoder output sequences. (B, U, D_dec)
            dec_output_mask: Mask of decoder output sequences. (B, U)

        """
        dec_input = self.embed(dec_input)

        dec_output, dec_mask = self.decoders(dec_input, dec_mask)
        dec_output = self.after_norm(dec_output)

        return dec_output, dec_mask

    def score(
        self, hyp: Hypothesis, cache: Dict[str, Any]
    ) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
        """One-step forward hypothesis.

        Args:
            hyp: Hypothesis.
            cache: Pairs of (dec_out, dec_state) for each label sequence. (key)

        Returns:
            dec_out: Decoder output sequence. (1, D_dec)
            dec_state: Decoder hidden states. [N x (1, U, D_dec)]
            lm_label: Label ID for LM. (1,)

        """
        labels = torch.tensor([hyp.yseq], device=self.device)
        lm_label = labels[:, -1]

        # Label sequences are memoized under a "_"-joined string key.
        str_labels = "_".join(list(map(str, hyp.yseq)))

        if str_labels in cache:
            dec_out, dec_state = cache[str_labels]
        else:
            dec_out_mask = subsequent_mask(len(hyp.yseq)).unsqueeze_(0)

            new_state = check_state(
                hyp.dec_state, (labels.size(1) - 1), self.blank_id
            )

            dec_out = self.embed(labels)

            dec_state = []
            for s, decoder in zip(new_state, self.decoders):
                dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
                # Each layer's full output doubles as its cache/state.
                dec_state.append(dec_out)

            # Only the last position is needed for the next joint step.
            dec_out = self.after_norm(dec_out[:, -1])

            cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state, lm_label

    def batch_score(
        self,
        hyps: Union[List[Hypothesis], List[ExtendedHypothesis]],
        dec_states: List[Optional[torch.Tensor]],
        cache: Dict[str, Any],
        use_lm: bool,
    ) -> Tuple[torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.
            dec_states: Decoder hidden states. [N x (B, U, D_dec)]
            cache: Pairs of (h_dec, dec_states) for each label sequences. (keys)
            use_lm: Whether to compute label ID sequences for LM.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            dec_states: Decoder hidden states. [N x (B, U, D_dec)]
            lm_labels: Label ID sequences for LM. (B,)

        """
        final_batch = len(hyps)

        # Split hypotheses into cached ones (done) and ones to compute.
        process = []
        done = [None] * final_batch

        for i, hyp in enumerate(hyps):
            str_labels = "_".join(list(map(str, hyp.yseq)))

            if str_labels in cache:
                done[i] = cache[str_labels]
            else:
                process.append((str_labels, hyp.yseq, hyp.dec_state))

        if process:
            labels = pad_sequence([p[1] for p in process], self.blank_id)
            # NOTE(review): the legacy torch.LongTensor constructor only
            # accepts a CPU device keyword — confirm against torch version
            # in use (torch.tensor(..., dtype=torch.long) is the modern form).
            labels = torch.LongTensor(labels, device=self.device)

            p_dec_states = self.create_batch_states(
                self.init_state(),
                [p[2] for p in process],
                labels,
            )

            dec_out = self.embed(labels)

            dec_out_mask = (
                subsequent_mask(labels.size(-1))
                .unsqueeze_(0)
                .expand(len(process), -1, -1)
            )

            new_states = []
            for s, decoder in zip(p_dec_states, self.decoders):
                dec_out, dec_out_mask = decoder(dec_out, dec_out_mask, cache=s)
                new_states.append(dec_out)

            dec_out = self.after_norm(dec_out[:, -1])

        # Merge freshly computed outputs back into batch order and cache them.
        j = 0
        for i in range(final_batch):
            if done[i] is None:
                state = self.select_state(new_states, j)

                done[i] = (dec_out[j], state)
                cache[process[j][0]] = (dec_out[j], state)

                j += 1

        dec_out = torch.stack([d[0] for d in done])
        dec_states = self.create_batch_states(
            dec_states,
            [d[1] for d in done],
            [[0] + h.yseq for h in hyps]
        )

        if use_lm:
            lm_labels = torch.LongTensor(
                [hyp.yseq[-1] for hyp in hyps], device=self.device
            )

            return dec_out, dec_states, lm_labels

        return dec_out, dec_states, None

    def select_state(
        self, states: List[Optional[torch.Tensor]], idx: int
    ) -> List[Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. [N x (B, U, D_dec)]
            idx: State ID to extract.

        Returns:
            state_idx: Decoder hidden state for given ID. [N x (1, U, D_dec)]

        """
        if states[0] is None:
            return states

        state_idx = [states[layer][idx] for layer in range(self.dlayers)]

        return state_idx

    def create_batch_states(
        self,
        states: List[Optional[torch.Tensor]],
        new_states: List[Optional[torch.Tensor]],
        check_list: List[List[int]],
    ) -> List[Optional[torch.Tensor]]:
        """Create decoder hidden states sequences.

        Args:
            states: Decoder hidden states. [N x (B, U, D_dec)]
            new_states: Decoder hidden states. [B x [N x (1, U, D_dec)]]
            check_list: Label ID sequences.

        Returns:
            states: New decoder hidden states. [N x (B, U, D_dec)]

        """
        if new_states[0][0] is None:
            return states

        # Longest label prefix (minus current symbol) defines the padded size.
        max_len = max(len(elem) for elem in check_list) - 1

        for layer in range(self.dlayers):
            states[layer] = check_batch_states(
                [s[layer] for s in new_states], max_len, self.blank_id
            )

        return states
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/custom_encoder.py
0 → 100644
View file @
60a2c57a
"""Cutom encoder definition for transducer models."""
from
typing
import
List
,
Tuple
,
Union
import
torch
from
espnet.nets.pytorch_backend.transducer.blocks
import
build_blocks
from
espnet.nets.pytorch_backend.transducer.vgg2l
import
VGG2L
from
espnet.nets.pytorch_backend.transformer.layer_norm
import
LayerNorm
from
espnet.nets.pytorch_backend.transformer.subsampling
import
Conv2dSubsampling
class CustomEncoder(torch.nn.Module):
    """Custom encoder module for transducer models.

    Args:
        idim: Input dimension.
        enc_arch: Encoder block architecture (type and parameters).
        input_layer: Input layer type.
        repeat_block: Number of times blocks_arch is repeated.
        self_attn_type: Self-attention type.
        positional_encoding_type: Positional encoding type.
        positionwise_layer_type: Positionwise layer type.
        positionwise_activation_type: Positionwise activation type.
        conv_mod_activation_type: Convolutional module activation type.
        aux_enc_output_layers: Layer IDs for auxiliary encoder output
            sequences. None (the default) means no auxiliary outputs.
        input_layer_dropout_rate: Dropout rate for input layer.
        input_layer_pos_enc_dropout_rate: Dropout rate for input layer pos. enc.
        padding_idx: Padding symbol ID for embedding layer.

    """

    def __init__(
        self,
        idim: int,
        enc_arch: List,
        input_layer: str = "linear",
        repeat_block: int = 1,
        self_attn_type: str = "selfattn",
        positional_encoding_type: str = "abs_pos",
        positionwise_layer_type: str = "linear",
        positionwise_activation_type: str = "relu",
        conv_mod_activation_type: str = "relu",
        aux_enc_output_layers: Union[List, None] = None,
        input_layer_dropout_rate: float = 0.0,
        input_layer_pos_enc_dropout_rate: float = 0.0,
        padding_idx: int = -1,
    ):
        """Construct an CustomEncoder object."""
        super().__init__()

        (
            self.embed,
            self.encoders,
            self.enc_out,
            self.conv_subsampling_factor,
        ) = build_blocks(
            "encoder",
            idim,
            input_layer,
            enc_arch,
            repeat_block=repeat_block,
            self_attn_type=self_attn_type,
            positional_encoding_type=positional_encoding_type,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_activation_type=positionwise_activation_type,
            conv_mod_activation_type=conv_mod_activation_type,
            input_layer_dropout_rate=input_layer_dropout_rate,
            input_layer_pos_enc_dropout_rate=input_layer_pos_enc_dropout_rate,
            padding_idx=padding_idx,
        )

        self.after_norm = LayerNorm(self.enc_out)

        self.n_blocks = len(enc_arch) * repeat_block

        # Fix: the previous signature used a mutable default ([]) which is
        # shared across instances; a None sentinel is equivalent and safe.
        self.aux_enc_output_layers = (
            aux_enc_output_layers if aux_enc_output_layers is not None else []
        )

    def forward(
        self,
        feats: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor]]:
        """Encode feature sequences.

        Args:
            feats: Feature sequences. (B, F, D_feats)
            mask: Feature mask sequences. (B, 1, F)

        Returns:
            enc_out: Encoder output sequences. (B, T, D_enc) with/without
                     Auxiliary encoder output sequences. (B, T, D_enc_aux)
            enc_out_mask: Mask for encoder output sequences. (B, 1, T) with/without
                     Mask for auxiliary encoder output sequences. (B, T, D_enc_aux)

        """
        # Subsampling frontends consume the mask too; plain input layers don't.
        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
            enc_out, mask = self.embed(feats, mask)
        else:
            enc_out = self.embed(feats)

        if self.aux_enc_output_layers:
            aux_custom_outputs = []
            aux_custom_lens = []

            # Run block-by-block so intermediate outputs can be tapped.
            for b in range(self.n_blocks):
                enc_out, mask = self.encoders[b](enc_out, mask)

                if b in self.aux_enc_output_layers:
                    # Blocks with positional embeddings return a tuple.
                    if isinstance(enc_out, tuple):
                        aux_custom_output = enc_out[0]
                    else:
                        aux_custom_output = enc_out

                    aux_custom_outputs.append(self.after_norm(aux_custom_output))
                    aux_custom_lens.append(mask)
        else:
            enc_out, mask = self.encoders(enc_out, mask)

        if isinstance(enc_out, tuple):
            enc_out = enc_out[0]

        enc_out = self.after_norm(enc_out)

        if self.aux_enc_output_layers:
            return (enc_out, aux_custom_outputs), (mask, aux_custom_lens)

        return enc_out, mask
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/error_calculator.py
0 → 100644
View file @
60a2c57a
"""CER/WER computation for Transducer model."""
from
typing
import
List
,
Tuple
,
Union
import
torch
from
espnet.nets.beam_search_transducer
import
BeamSearchTransducer
from
espnet.nets.pytorch_backend.transducer.custom_decoder
import
CustomDecoder
from
espnet.nets.pytorch_backend.transducer.joint_network
import
JointNetwork
from
espnet.nets.pytorch_backend.transducer.rnn_decoder
import
RNNDecoder
class ErrorCalculator(object):
    """CER and WER computation for Transducer model.

    Args:
        decoder: Decoder module.
        joint_network: Joint network module.
        token_list: Set of unique labels.
        sym_space: Space symbol.
        sym_blank: Blank symbol.
        report_cer: Whether to compute CER.
        report_wer: Whether to compute WER.

    """

    def __init__(
        self,
        decoder: Union[RNNDecoder, CustomDecoder],
        joint_network: JointNetwork,
        token_list: List[int],
        sym_space: str,
        sym_blank: str,
        report_cer: bool = False,
        report_wer: bool = False,
    ):
        """Construct an ErrorCalculator object for Transducer model."""
        super().__init__()

        # Small beam keeps validation-time decoding cheap.
        self.beam_search = BeamSearchTransducer(
            decoder=decoder,
            joint_network=joint_network,
            beam_size=2,
            search_type="default",
        )

        self.decoder = decoder

        self.token_list = token_list
        self.space = sym_space
        self.blank = sym_blank

        self.report_cer = report_cer
        self.report_wer = report_wer

    def __call__(
        self, enc_out: torch.Tensor, target: torch.Tensor
    ) -> Tuple[float, float]:
        """Calculate sentence-level CER/WER score for hypotheses sequences.

        Args:
            enc_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)

        Returns:
            cer: Sentence-level CER score.
            wer: Sentence-level WER score.

        """
        cer, wer = None, None

        batchsize = int(enc_out.size(0))
        batch_nbest = []

        # Decode on whichever device holds the decoder parameters.
        enc_out = enc_out.to(next(self.decoder.parameters()).device)

        for b in range(batchsize):
            nbest_hyps = self.beam_search(enc_out[b])
            batch_nbest.append(nbest_hyps[-1])

        # Strip the leading blank/start symbol from each best hypothesis.
        batch_nbest = [nbest_hyp.yseq[1:] for nbest_hyp in batch_nbest]

        hyps, refs = self.convert_to_char(batch_nbest, target.cpu())

        if self.report_cer:
            cer = self.calculate_cer(hyps, refs)

        if self.report_wer:
            wer = self.calculate_wer(hyps, refs)

        return cer, wer

    def convert_to_char(
        self, hyps: torch.Tensor, refs: torch.Tensor
    ) -> Tuple[List, List]:
        """Convert label ID sequences to character.

        Args:
            hyps: Hypotheses sequences. (B, L)
            refs: References sequences. (B, L)

        Returns:
            char_hyps: Character list of hypotheses.
            char_refs: Character list of references.

        """
        char_hyps, char_refs = [], []

        for i, hyp in enumerate(hyps):
            hyp_i = [self.token_list[int(h)] for h in hyp]
            ref_i = [self.token_list[int(r)] for r in refs[i]]

            char_hyp = "".join(hyp_i).replace(self.space, " ")
            char_hyp = char_hyp.replace(self.blank, "")
            # NOTE(review): references keep the blank symbol (only the space
            # symbol is replaced) — confirm this asymmetry is intended.
            char_ref = "".join(ref_i).replace(self.space, " ")

            char_hyps.append(char_hyp)
            char_refs.append(char_ref)

        return char_hyps, char_refs

    def calculate_cer(self, hyps: torch.Tensor, refs: torch.Tensor) -> float:
        """Calculate sentence-level CER score.

        Args:
            hyps: Hypotheses sequences. (B, L)
            refs: References sequences. (B, L)

        Returns:
            : Average sentence-level CER score.

        """
        # Imported lazily so the dependency is only needed when CER is on.
        import editdistance

        distances, lens = [], []

        for i, hyp in enumerate(hyps):
            # Spaces are ignored for character error rate.
            char_hyp = hyp.replace(" ", "")
            char_ref = refs[i].replace(" ", "")

            distances.append(editdistance.eval(char_hyp, char_ref))
            lens.append(len(char_ref))

        return float(sum(distances)) / sum(lens)

    def calculate_wer(self, hyps: torch.Tensor, refs: torch.Tensor) -> float:
        """Calculate sentence-level WER score.

        Args:
            hyps: Hypotheses sequences. (B, L)
            refs: References sequences. (B, L)

        Returns:
            : Average sentence-level WER score.

        """
        # Imported lazily so the dependency is only needed when WER is on.
        import editdistance

        distances, lens = [], []

        for i, hyp in enumerate(hyps):
            word_hyp = hyp.split()
            word_ref = refs[i].split()

            distances.append(editdistance.eval(word_hyp, word_ref))
            lens.append(len(word_ref))

        return float(sum(distances)) / sum(lens)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/initializer.py
0 → 100644
View file @
60a2c57a
"""Parameter initialization for Transducer model."""
import
math
from
argparse
import
Namespace
import
torch
from
espnet.nets.pytorch_backend.initialization
import
set_forget_bias_to_one
def initializer(model: torch.nn.Module, args: Namespace):
    """Initialize Transducer model.

    Args:
        model: Transducer model.
        args: Namespace containing model options.

    """
    target_prefixes = ("enc.", "dec.", "transducer_tasks.")

    for param_name, param in model.named_parameters():
        if not any(prefix in param_name for prefix in target_prefixes):
            continue

        ndim = param.dim()

        if ndim == 1:
            # Bias vectors start at zero.
            param.data.zero_()
        elif ndim == 2:
            # Linear weights: N(0, 1/sqrt(fan_in)).
            fan_in = param.size(1)
            param.data.normal_(0, 1.0 / math.sqrt(fan_in))
        elif ndim in (3, 4):
            # Conv weights: fan-in also includes the kernel dimensions.
            fan_in = param.size(1)
            for kernel_dim in param.size()[2:]:
                fan_in *= kernel_dim
            param.data.normal_(0, 1.0 / math.sqrt(fan_in))

    if args.dtype != "custom":
        # RNN decoder: re-init the embedding and push LSTM forget-gate
        # biases to one for each decoder layer.
        model.dec.embed.weight.data.normal_(0, 1)

        for layer in range(model.dec.dlayers):
            set_forget_bias_to_one(getattr(model.dec.decoder[layer], "bias_ih_l0"))
            set_forget_bias_to_one(getattr(model.dec.decoder[layer], "bias_hh_l0"))
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/joint_network.py
0 → 100644
View file @
60a2c57a
"""Transducer joint network implementation."""
import
torch
from
espnet.nets.pytorch_backend.nets_utils
import
get_activation
class JointNetwork(torch.nn.Module):
    """Transducer joint network module.

    Args:
        joint_output_size: Joint network output dimension
        encoder_output_size: Encoder output dimension.
        decoder_output_size: Decoder output dimension.
        joint_space_size: Dimension of joint space.
        joint_activation_type: Type of activation for joint network.

    """

    def __init__(
        self,
        joint_output_size: int,
        encoder_output_size: int,
        decoder_output_size: int,
        joint_space_size: int,
        joint_activation_type: str,  # annotation fixed: was int, value is a type name string
    ):
        """Joint network initializer."""
        super().__init__()

        self.lin_enc = torch.nn.Linear(encoder_output_size, joint_space_size)
        # No bias on the decoder projection: the encoder projection's bias
        # already covers the summed pre-activation.
        self.lin_dec = torch.nn.Linear(
            decoder_output_size, joint_space_size, bias=False
        )

        self.lin_out = torch.nn.Linear(joint_space_size, joint_output_size)

        self.joint_activation = get_activation(joint_activation_type)

    def forward(
        self,
        enc_out: torch.Tensor,
        dec_out: torch.Tensor,
        is_aux: bool = False,
        quantization: bool = False,
    ) -> torch.Tensor:
        """Joint computation of encoder and decoder hidden state sequences.

        Args:
            enc_out: Expanded encoder output state sequences (B, T, 1, D_enc)
            dec_out: Expanded decoder output state sequences (B, 1, U, D_dec)
            is_aux: Whether auxiliary tasks in used.
            quantization: Whether dynamic quantization is used.

        Returns:
            joint_out: Joint output state sequences. (B, T, U, D_out)

        """
        if is_aux:
            # Auxiliary path: enc_out is already projected to joint space.
            joint_out = self.joint_activation(enc_out + self.lin_dec(dec_out))
        elif quantization:
            # Quantized linear layers here expect batched 2D input, hence
            # the unsqueeze/[0] round trip.
            joint_out = self.joint_activation(
                self.lin_enc(enc_out.unsqueeze(0)) + self.lin_dec(dec_out.unsqueeze(0))
            )

            return self.lin_out(joint_out)[0]
        else:
            joint_out = self.joint_activation(
                self.lin_enc(enc_out) + self.lin_dec(dec_out)
            )
        joint_out = self.lin_out(joint_out)

        return joint_out
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/rnn_decoder.py
0 → 100644
View file @
60a2c57a
"""RNN decoder definition for Transducer model."""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
espnet.nets.transducer_decoder_interface
import
(
ExtendedHypothesis
,
Hypothesis
,
TransducerDecoderInterface
,
)
class
RNNDecoder
(
TransducerDecoderInterface
,
torch
.
nn
.
Module
):
"""RNN decoder module for Transducer model.
Args:
odim: Output dimension.
dtype: Decoder units type.
dlayers: Number of decoder layers.
dunits: Number of decoder units per layer..
embed_dim: Embedding layer dimension.
dropout_rate: Dropout rate for decoder layers.
dropout_rate_embed: Dropout rate for embedding layer.
blank_id: Blank symbol ID.
"""
def __init__(
    self,
    odim: int,
    dtype: str,
    dlayers: int,
    dunits: int,
    embed_dim: int,
    dropout_rate: float = 0.0,
    dropout_rate_embed: float = 0.0,
    blank_id: int = 0,
):
    """Transducer initializer."""
    super().__init__()

    self.embed = torch.nn.Embedding(odim, embed_dim, padding_idx=blank_id)
    self.dropout_embed = torch.nn.Dropout(p=dropout_rate_embed)

    # Single-layer RNN cells are stacked manually so dropout can be
    # applied between layers in rnn_forward().
    if dtype == "lstm":
        rnn_class = torch.nn.LSTM
    else:
        rnn_class = torch.nn.GRU

    self.decoder = torch.nn.ModuleList(
        [rnn_class(embed_dim, dunits, 1, batch_first=True)]
    )
    self.dropout_dec = torch.nn.Dropout(p=dropout_rate)

    for _ in range(1, dlayers):
        self.decoder.append(rnn_class(dunits, dunits, 1, batch_first=True))

    self.dlayers = dlayers
    self.dunits = dunits
    self.dtype = dtype
    self.odim = odim

    self.ignore_id = -1
    self.blank_id = blank_id

    self.multi_gpus = torch.cuda.device_count() > 1
def set_device(self, device: torch.device):
    """Set GPU device to use.

    Args:
        device: Device ID.

    """
    # Stored for tensor construction in init_state()/score()/batch_score().
    self.device = device
def init_state(
    self, batch_size: int
) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
    """Initialize decoder states.

    Args:
        batch_size: Batch size.

    Returns:
        : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec))

    """
    shape = (self.dlayers, batch_size, self.dunits)

    hidden = torch.zeros(*shape, device=self.device)

    # Only LSTM carries a cell state; GRU gets None in its place.
    cell = (
        torch.zeros(*shape, device=self.device)
        if self.dtype == "lstm"
        else None
    )

    return (hidden, cell)
def rnn_forward(
    self,
    sequence: torch.Tensor,
    state: Tuple[torch.Tensor, Optional[torch.Tensor]],
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
    """Encode source label sequences.

    Args:
        sequence: RNN input sequences. (B, D_emb)
        state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))

    Returns:
        sequence: RNN output sequences. (B, D_dec)
        (h_next, c_next): Decoder hidden states. (N, B, D_dec), (N, B, D_dec))

    """
    h_prev, c_prev = state
    # Fresh zero tensors; per-layer slices are overwritten in place below.
    h_next, c_next = self.init_state(sequence.size(0))

    for layer in range(self.dlayers):
        if self.dtype == "lstm":
            # layer:layer+1 slicing keeps the leading num-layers dim that
            # torch RNN modules expect for hx.
            (
                sequence,
                (
                    h_next[layer : layer + 1],
                    c_next[layer : layer + 1],
                ),
            ) = self.decoder[layer](
                sequence, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1])
            )
        else:
            sequence, h_next[layer : layer + 1] = self.decoder[layer](
                sequence, hx=h_prev[layer : layer + 1]
            )

        # Inter-layer dropout (also applied after the last layer).
        sequence = self.dropout_dec(sequence)

    return sequence, (h_next, c_next)
def forward(self, labels: torch.Tensor) -> torch.Tensor:
    """Encode source label sequences.

    Args:
        labels: Label ID sequences. (B, L)

    Returns:
        dec_out: Decoder output sequences. (B, L, D_dec)

    """
    batch = labels.size(0)

    embedded = self.dropout_embed(self.embed(labels))
    dec_out, _ = self.rnn_forward(embedded, self.init_state(batch))

    return dec_out
def score(
    self, hyp: Hypothesis, cache: Dict[str, Any]
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
    """One-step forward hypothesis.

    Args:
        hyp: Hypothesis.
        cache: Pairs of (dec_out, state) for each label sequence. (key)

    Returns:
        dec_out: Decoder output sequence. (1, D_dec)
        new_state: Decoder hidden states. ((N, 1, D_dec), (N, 1, D_dec))
        label: Label ID for LM. (1,)

    """
    # Only the last emitted label is fed back into the RNN.
    label = torch.full((1, 1), hyp.yseq[-1], dtype=torch.long, device=self.device)

    # Full label history is memoized under a "_"-joined string key.
    str_labels = "_".join(list(map(str, hyp.yseq)))

    if str_labels in cache:
        dec_out, dec_state = cache[str_labels]
    else:
        dec_emb = self.embed(label)

        dec_out, dec_state = self.rnn_forward(dec_emb, hyp.dec_state)
        cache[str_labels] = (dec_out, dec_state)

    return dec_out[0][0], dec_state, label[0]
def batch_score(
    self,
    hyps: Union[List[Hypothesis], List[ExtendedHypothesis]],
    dec_states: Tuple[torch.Tensor, Optional[torch.Tensor]],
    cache: Dict[str, Any],
    use_lm: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
    """One-step forward hypotheses.

    Args:
        hyps: Hypotheses.
        dec_states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
        cache: Pairs of (dec_out, dec_states) for each label sequences. (keys)
        use_lm: Whether to compute label ID sequences for LM.

    Returns:
        dec_out: Decoder output sequences. (B, D_dec)
        dec_states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
        lm_labels: Label ID sequences for LM. (B,) or None if use_lm is False.
    """
    final_batch = len(hyps)

    process = []
    done = [None] * final_batch

    # Partition hypotheses into cached results and those needing a decoder step.
    for i, hyp in enumerate(hyps):
        str_labels = "_".join(list(map(str, hyp.yseq)))

        if str_labels in cache:
            done[i] = cache[str_labels]
        else:
            process.append((str_labels, hyp.yseq[-1], hyp.dec_state))

    if process:
        # BUGFIX: the legacy ``torch.LongTensor(data, device=...)`` constructor
        # only accepts CPU devices and raises on CUDA; use the ``torch.tensor``
        # factory, which handles any device.
        labels = torch.tensor(
            [[p[1]] for p in process], dtype=torch.long, device=self.device
        )
        p_dec_states = self.create_batch_states(
            self.init_state(labels.size(0)), [p[2] for p in process]
        )

        dec_emb = self.embed(labels)
        dec_out, new_states = self.rnn_forward(dec_emb, p_dec_states)

    # Scatter the freshly computed outputs back into batch order and cache them.
    j = 0
    for i in range(final_batch):
        if done[i] is None:
            state = self.select_state(new_states, j)

            done[i] = (dec_out[j], state)
            cache[process[j][0]] = (dec_out[j], state)

            j += 1

    dec_out = torch.cat([d[0] for d in done], dim=0)
    dec_states = self.create_batch_states(dec_states, [d[1] for d in done])

    if use_lm:
        # Same legacy-constructor fix as above.
        lm_labels = torch.tensor(
            [h.yseq[-1] for h in hyps], dtype=torch.long, device=self.device
        )

        return dec_out, dec_states, lm_labels

    return dec_out, dec_states, None
def select_state(
    self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Get specified ID state from decoder hidden states.

    Args:
        states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
        idx: State ID to extract.

    Returns:
        : Decoder hidden state for given ID.
          ((N, 1, D_dec), (N, 1, D_dec))
    """
    # Keep the batch dimension by slicing, not indexing.
    hidden = states[0][:, idx : idx + 1, :]

    if self.dtype == "lstm":
        memory = states[1][:, idx : idx + 1, :]
    else:
        # GRU decoders have no cell state.
        memory = None

    return (hidden, memory)
def create_batch_states(
    self,
    states: Tuple[torch.Tensor, Optional[torch.Tensor]],
    new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    check_list: Optional[List] = None,
) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
    """Create decoder hidden states.

    Args:
        states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
        new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec))]

    Returns:
        states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec))
    """
    # Re-assemble the batch along dim 1 (the batch axis of RNN states).
    hidden = torch.cat([s[0] for s in new_states], dim=1)

    if self.dtype == "lstm":
        memory = torch.cat([s[1] for s in new_states], dim=1)
    else:
        # GRU decoders carry no cell state.
        memory = None

    return (hidden, memory)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/rnn_encoder.py
0 → 100644
View file @
60a2c57a
"""RNN encoder implementation for Transducer model.
These classes are based on the ones in espnet.nets.pytorch_backend.rnn.encoders,
and modified to output intermediate representation based given list of layers as input.
To do so, RNN class rely on a stack of 1-layer LSTM instead of a multi-layer LSTM.
The additional outputs are intended to be used with Transducer auxiliary tasks.
"""
from
argparse
import
Namespace
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
torch.nn.utils.rnn
import
pack_padded_sequence
,
pad_packed_sequence
from
espnet.nets.e2e_asr_common
import
get_vgg2l_odim
from
espnet.nets.pytorch_backend.nets_utils
import
make_pad_mask
,
to_device
class RNNP(torch.nn.Module):
    """RNN with projection layer module.

    Args:
        idim: Input dimension.
        rnn_type: RNNP units type.
        elayers: Number of RNNP layers.
        eunits: Number of units ((2 * eunits) if bidirectional).
        eprojs: Number of projection units.
        subsample: Subsampling rate per layer.
        dropout_rate: Dropout rate for RNNP layers.
        aux_output_layers: Layer IDs for auxiliary RNNP output sequences.
    """

    def __init__(
        self,
        idim: int,
        rnn_type: str,
        elayers: int,
        eunits: int,
        eprojs: int,
        subsample: np.ndarray,
        dropout_rate: float,
        aux_output_layers: Optional[List] = None,
    ):
        """Initialize RNNP module."""
        super().__init__()

        # A "b"-prefixed type (e.g. "blstm") selects a bidirectional RNN.
        bidir = rnn_type[0] == "b"

        for i in range(elayers):
            # First layer consumes raw features; later layers consume projections.
            input_dim = idim if i == 0 else eprojs

            rnn_layer = torch.nn.LSTM if "lstm" in rnn_type else torch.nn.GRU
            rnn = rnn_layer(
                input_dim, eunits, num_layers=1, bidirectional=bidir, batch_first=True
            )

            setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)

            # Per-layer projection back to eprojs dimensions.
            if bidir:
                setattr(self, "bt%d" % i, torch.nn.Linear(2 * eunits, eprojs))
            else:
                setattr(self, "bt%d" % i, torch.nn.Linear(eunits, eprojs))

        self.dropout = torch.nn.Dropout(p=dropout_rate)

        self.elayers = elayers
        self.eunits = eunits
        self.subsample = subsample
        self.rnn_type = rnn_type
        self.bidir = bidir

        # FIX: the original default was a mutable list (``= []``), shared across
        # all instances constructed with the default. ``None`` keeps each
        # instance's list independent while preserving the default behavior.
        self.aux_output_layers = (
            aux_output_layers if aux_output_layers is not None else []
        )

    def forward(
        self,
        rnn_input: torch.Tensor,
        rnn_len: torch.Tensor,
        prev_states: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
        """RNNP forward.

        Args:
            rnn_input: RNN input sequences. (B, T, D_in)
            rnn_len: RNN input sequences lengths. (B,)
            prev_states: RNN hidden states. [N x (B, T, D_proj)]

        Returns:
            rnn_output : RNN output sequences. (B, T, D_proj)
                with or without intermediate RNN output sequences.
                ((B, T, D_proj), [N x (B, T, D_proj)])
            rnn_len: RNN output sequences lengths. (B,)
            current_states: RNN hidden states. [N x (B, T, D_proj)]
        """
        aux_rnn_outputs = []
        aux_rnn_lens = []
        current_states = []

        for layer in range(self.elayers):
            if not isinstance(rnn_len, torch.Tensor):
                rnn_len = torch.tensor(rnn_len)

            # pack_padded_sequence requires lengths on CPU.
            pack_rnn_input = pack_padded_sequence(
                rnn_input, rnn_len.cpu(), batch_first=True
            )

            rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))

            if isinstance(rnn, (torch.nn.LSTM, torch.nn.GRU)):
                rnn.flatten_parameters()

            if prev_states is not None and rnn.bidirectional:
                # Backward-direction states are zeroed before being reused.
                prev_states = reset_backward_rnn_state(prev_states)

            pack_rnn_output, states = rnn(
                pack_rnn_input, hx=None if prev_states is None else prev_states[layer]
            )
            current_states.append(states)

            pad_rnn_output, rnn_len = pad_packed_sequence(
                pack_rnn_output, batch_first=True
            )

            # Time subsampling between layers (subsample[0] applies to input).
            sub = self.subsample[layer + 1]
            if sub > 1:
                pad_rnn_output = pad_rnn_output[:, ::sub]
                rnn_len = torch.tensor([int(i + 1) // sub for i in rnn_len])

            # Project every frame to eprojs dimensions.
            projection_layer = getattr(self, "bt%d" % layer)
            proj_rnn_output = projection_layer(
                pad_rnn_output.contiguous().view(-1, pad_rnn_output.size(2))
            )
            rnn_output = proj_rnn_output.view(
                pad_rnn_output.size(0), pad_rnn_output.size(1), -1
            )

            if layer in self.aux_output_layers:
                aux_rnn_outputs.append(rnn_output)
                aux_rnn_lens.append(rnn_len)

            if layer < self.elayers - 1:
                rnn_output = torch.tanh(self.dropout(rnn_output))

            rnn_input = rnn_output

        if aux_rnn_outputs:
            return (
                (rnn_output, aux_rnn_outputs),
                (rnn_len, aux_rnn_lens),
                current_states,
            )
        else:
            return rnn_output, rnn_len, current_states
class RNN(torch.nn.Module):
    """RNN module.

    Args:
        idim: Input dimension.
        rnn_type: RNN units type.
        elayers: Number of RNN layers.
        eunits: Number of units ((2 * eunits) if bidirectional)
        eprojs: Number of final projection units.
        dropout_rate: Dropout rate for RNN layers.
        aux_output_layers: List of layer IDs for auxiliary RNN output sequences.
    """

    def __init__(
        self,
        idim: int,
        rnn_type: str,
        elayers: int,
        eunits: int,
        eprojs: int,
        dropout_rate: float,
        aux_output_layers: Optional[List] = None,
    ):
        """Initialize RNN module."""
        super().__init__()

        # A "b"-prefixed type (e.g. "blstm") selects a bidirectional RNN.
        bidir = rnn_type[0] == "b"

        for i in range(elayers):
            input_dim = idim if i == 0 else eunits

            rnn_layer = torch.nn.LSTM if "lstm" in rnn_type else torch.nn.GRU
            rnn = rnn_layer(
                input_dim, eunits, num_layers=1, bidirectional=bidir, batch_first=True
            )

            setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)

        self.dropout = torch.nn.Dropout(p=dropout_rate)

        self.elayers = elayers
        self.eunits = eunits
        self.eprojs = eprojs
        self.rnn_type = rnn_type
        self.bidir = bidir

        # Single shared projection applied to the last layer (and aux taps).
        self.l_last = torch.nn.Linear(eunits, eprojs)

        # FIX: the original default was a mutable list (``= []``), shared across
        # all instances constructed with the default. ``None`` keeps each
        # instance's list independent while preserving the default behavior.
        self.aux_output_layers = (
            aux_output_layers if aux_output_layers is not None else []
        )

    def forward(
        self,
        rnn_input: torch.Tensor,
        rnn_len: torch.Tensor,
        prev_states: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
        """RNN forward.

        Args:
            rnn_input: RNN input sequences. (B, T, D_in)
            rnn_len: RNN input sequences lengths. (B,)
            prev_states: RNN hidden states. [N x (B, T, D_proj)]

        Returns:
            rnn_output : RNN output sequences. (B, T, D_proj)
                with or without intermediate RNN output sequences.
                ((B, T, D_proj), [N x (B, T, D_proj)])
            rnn_len: RNN output sequences lengths. (B,)
            current_states: RNN hidden states. [N x (B, T, D_proj)]
        """
        aux_rnn_outputs = []
        aux_rnn_lens = []
        current_states = []

        for layer in range(self.elayers):
            if not isinstance(rnn_len, torch.Tensor):
                rnn_len = torch.tensor(rnn_len)

            # pack_padded_sequence requires lengths on CPU.
            pack_rnn_input = pack_padded_sequence(
                rnn_input, rnn_len.cpu(), batch_first=True
            )

            rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))

            if isinstance(rnn, (torch.nn.LSTM, torch.nn.GRU)):
                rnn.flatten_parameters()

            if prev_states is not None and rnn.bidirectional:
                # Backward-direction states are zeroed before being reused.
                prev_states = reset_backward_rnn_state(prev_states)

            pack_rnn_output, states = rnn(
                pack_rnn_input, hx=None if prev_states is None else prev_states[layer]
            )
            current_states.append(states)

            rnn_output, rnn_len = pad_packed_sequence(
                pack_rnn_output, batch_first=True
            )

            if self.bidir:
                # Merge forward/backward directions by summation.
                rnn_output = (
                    rnn_output[:, :, : self.eunits] + rnn_output[:, :, self.eunits :]
                )

            if layer in self.aux_output_layers:
                # Auxiliary tap: project the intermediate output like the final one.
                aux_proj_rnn_output = torch.tanh(
                    self.l_last(rnn_output.contiguous().view(-1, rnn_output.size(2)))
                )
                aux_rnn_output = aux_proj_rnn_output.view(
                    rnn_output.size(0), rnn_output.size(1), -1
                )

                aux_rnn_outputs.append(aux_rnn_output)
                aux_rnn_lens.append(rnn_len)

            if layer < self.elayers - 1:
                rnn_input = self.dropout(rnn_output)

        # Final projection to eprojs dimensions.
        proj_rnn_output = torch.tanh(
            self.l_last(rnn_output.contiguous().view(-1, rnn_output.size(2)))
        )
        rnn_output = proj_rnn_output.view(
            rnn_output.size(0), rnn_output.size(1), -1
        )

        if aux_rnn_outputs:
            return (
                (rnn_output, aux_rnn_outputs),
                (rnn_len, aux_rnn_lens),
                current_states,
            )
        else:
            return rnn_output, rnn_len, current_states
def reset_backward_rnn_state(
    states: Union[torch.Tensor, List[Optional[torch.Tensor]]]
) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]:
    """Set backward BRNN states to zeroes.

    Mutates the given state tensor(s) in place: every odd-indexed slice along
    dim 0 (the backward-direction layers) is zeroed.

    Args:
        states: Encoder hidden states.

    Returns:
        states: Encoder hidden states with backward set to zero.
    """
    # Normalize to a list so one loop handles both input forms.
    targets = states if isinstance(states, list) else [states]

    for tensor in targets:
        tensor[1::2] = 0.0

    return states
class VGG2L(torch.nn.Module):
    """VGG-like module.

    Two conv/conv/max-pool stages that reduce both time and feature axes
    by a factor of 4 overall.

    Args:
        in_channel: number of input channels
    """

    def __init__(self, in_channel: int = 1):
        """Initialize VGG-like module."""
        super(VGG2L, self).__init__()

        # Two VGG-style stages: (64, 64) then (128, 128) channels,
        # each followed by 2x2 max pooling.
        self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
        self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)

        self.in_channel = in_channel

    def forward(
        self, feats: torch.Tensor, feats_len: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """VGG2L forward.

        Args:
            feats: Feature sequences. (B, F, D_feats)
            feats_len: Feature sequences lengths. (B, )

        Returns:
            vgg_out: VGG2L output sequences. (B, F // 4, 128 * D_feats // 4)
            vgg_out_len: VGG2L output sequences lengths. (B,)
        """
        batch = feats.size(0)
        frames = feats.size(1)

        # (B, F, D) -> (B, C, F, D // C) for 2-D convolution.
        x = feats.view(
            batch, frames, self.in_channel, feats.size(2) // self.in_channel
        ).transpose(1, 2)

        # Stage 1: conv, conv, pool (halves time and feature axes).
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x))
        x = F.max_pool2d(x, 2, stride=2, ceil_mode=True)

        # Stage 2: conv, conv, pool (halves them again).
        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_2(x))
        x = F.max_pool2d(x, 2, stride=2, ceil_mode=True)

        # (B, C, F', D') -> (B, F', C * D').
        x = x.transpose(1, 2)
        vgg_out = x.contiguous().view(
            x.size(0), x.size(1), x.size(2) * x.size(3)
        )

        if torch.is_tensor(feats_len):
            lens = feats_len.cpu().numpy()
        else:
            lens = np.array(feats_len, dtype=np.float32)

        # Mirror the two ceil-mode poolings on the lengths.
        half_lens = np.array(np.ceil(lens / 2), dtype=np.int64)
        vgg_out_len = np.array(
            np.ceil(np.array(half_lens, dtype=np.float32) / 2), dtype=np.int64
        ).tolist()

        return vgg_out, vgg_out_len, None
class Encoder(torch.nn.Module):
    """Encoder module.

    Args:
        idim: Input dimension.
        etype: Encoder units type.
        elayers: Number of encoder layers.
        eunits: Number of encoder units per layer.
        eprojs: Number of projection units per layer.
        subsample: Subsampling rate per layer.
        dropout_rate: Dropout rate for encoder layers.
        aux_enc_output_layers: Layer IDs for auxiliary encoder output sequences.
    """

    def __init__(
        self,
        idim: int,
        etype: str,
        elayers: int,
        eunits: int,
        eprojs: int,
        subsample: np.ndarray,
        dropout_rate: float = 0.0,
        aux_enc_output_layers: Optional[List] = None,
    ):
        """Initialize Encoder module."""
        super(Encoder, self).__init__()

        # FIX: the original default was a mutable list (``= []``), shared
        # across all instances built with the default. ``None`` keeps each
        # call independent while preserving the default behavior.
        if aux_enc_output_layers is None:
            aux_enc_output_layers = []

        # "vgg" prefix selects a VGG2L front-end; trailing "p" selects the
        # projected RNN variant (RNNP).
        rnn_type = etype.lstrip("vgg").rstrip("p")
        in_channel = 1

        if etype.startswith("vgg"):
            if etype[-1] == "p":
                self.enc = torch.nn.ModuleList(
                    [
                        VGG2L(in_channel),
                        RNNP(
                            get_vgg2l_odim(idim, in_channel=in_channel),
                            rnn_type,
                            elayers,
                            eunits,
                            eprojs,
                            subsample,
                            dropout_rate=dropout_rate,
                            aux_output_layers=aux_enc_output_layers,
                        ),
                    ]
                )
            else:
                self.enc = torch.nn.ModuleList(
                    [
                        VGG2L(in_channel),
                        RNN(
                            get_vgg2l_odim(idim, in_channel=in_channel),
                            rnn_type,
                            elayers,
                            eunits,
                            eprojs,
                            dropout_rate=dropout_rate,
                            aux_output_layers=aux_enc_output_layers,
                        ),
                    ]
                )

            # VGG2L reduces the time axis by 4.
            self.conv_subsampling_factor = 4
        else:
            if etype[-1] == "p":
                self.enc = torch.nn.ModuleList(
                    [
                        RNNP(
                            idim,
                            rnn_type,
                            elayers,
                            eunits,
                            eprojs,
                            subsample,
                            dropout_rate=dropout_rate,
                            aux_output_layers=aux_enc_output_layers,
                        )
                    ]
                )
            else:
                self.enc = torch.nn.ModuleList(
                    [
                        RNN(
                            idim,
                            rnn_type,
                            elayers,
                            eunits,
                            eprojs,
                            dropout_rate=dropout_rate,
                            aux_output_layers=aux_enc_output_layers,
                        )
                    ]
                )

            self.conv_subsampling_factor = 1

    def forward(
        self,
        feats: torch.Tensor,
        feats_len: torch.Tensor,
        prev_states: Optional[List[torch.Tensor]] = None,
    ):
        """Forward encoder.

        Args:
            feats: Feature sequences. (B, F, D_feats)
            feats_len: Feature sequences lengths. (B,)
            prev_states: Previous encoder hidden states. [N x (B, T, D_enc)]

        Returns:
            enc_out: Encoder output sequences. (B, T, D_enc)
                with or without encoder intermediate output sequences.
                ((B, T, D_enc), [N x (B, T, D_enc)])
            enc_out_len: Encoder output sequences lengths. (B,)
            current_states: Encoder hidden states. [N x (B, T, D_enc)]
        """
        if prev_states is None:
            prev_states = [None] * len(self.enc)
        assert len(prev_states) == len(self.enc)

        _enc_out = feats
        _enc_out_len = feats_len

        current_states = []
        for rnn_module, prev_state in zip(self.enc, prev_states):
            _enc_out, _enc_out_len, states = rnn_module(
                _enc_out,
                _enc_out_len,
                prev_states=prev_state,
            )
            current_states.append(states)

        if isinstance(_enc_out, tuple):
            # Auxiliary outputs were requested: unpack main and aux sequences.
            enc_out, aux_enc_out = _enc_out[0], _enc_out[1]
            enc_out_len, aux_enc_out_len = _enc_out_len[0], _enc_out_len[1]

            # Zero out padded frames of the main output.
            enc_out_mask = to_device(enc_out, make_pad_mask(enc_out_len).unsqueeze(-1))
            enc_out = enc_out.masked_fill(enc_out_mask, 0.0)

            # ... and of every auxiliary output.
            for i in range(len(aux_enc_out)):
                aux_mask = to_device(
                    aux_enc_out[i], make_pad_mask(aux_enc_out_len[i]).unsqueeze(-1)
                )
                aux_enc_out[i] = aux_enc_out[i].masked_fill(aux_mask, 0.0)

            return (
                (enc_out, aux_enc_out),
                (enc_out_len, aux_enc_out_len),
                current_states,
            )
        else:
            enc_out_mask = to_device(
                _enc_out, make_pad_mask(_enc_out_len).unsqueeze(-1)
            )

            return _enc_out.masked_fill(enc_out_mask, 0.0), _enc_out_len, current_states
def encoder_for(
    args: Namespace,
    idim: int,
    subsample: np.ndarray,
    aux_enc_output_layers: Optional[List] = None,
) -> torch.nn.Module:
    """Instantiate a RNN encoder with specified arguments.

    Args:
        args: The model arguments.
        idim: Input dimension.
        subsample: Subsampling rate per layer.
        aux_enc_output_layers: Layer IDs for auxiliary encoder output sequences.

    Returns:
        : Encoder module.
    """
    # FIX: avoid a shared mutable default list; None preserves the old
    # default behavior ([]) without cross-call sharing.
    if aux_enc_output_layers is None:
        aux_enc_output_layers = []

    return Encoder(
        idim,
        args.etype,
        args.elayers,
        args.eunits,
        args.eprojs,
        subsample,
        dropout_rate=args.dropout_rate,
        aux_enc_output_layers=aux_enc_output_layers,
    )
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/transducer_tasks.py
0 → 100644
View file @
60a2c57a
"""Module implementing Transducer main and auxiliary tasks."""
from
typing
import
Any
,
List
,
Optional
,
Tuple
import
torch
from
espnet.nets.pytorch_backend.nets_utils
import
pad_list
from
espnet.nets.pytorch_backend.transducer.joint_network
import
JointNetwork
from
espnet.nets.pytorch_backend.transformer.label_smoothing_loss
import
(
# noqa: H301
LabelSmoothingLoss
,
)
class TransducerTasks(torch.nn.Module):
    """Transducer tasks module.

    Bundles the joint network with the main Transducer loss and the optional
    auxiliary objectives (CTC, LM, auxiliary Transducer, symmetric KL div.).
    """

    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        joint_dim: int,
        output_dim: int,
        joint_activation_type: str = "tanh",
        transducer_loss_weight: float = 1.0,
        ctc_loss: bool = False,
        ctc_loss_weight: float = 0.5,
        ctc_loss_dropout_rate: float = 0.0,
        lm_loss: bool = False,
        lm_loss_weight: float = 0.5,
        lm_loss_smoothing_rate: float = 0.0,
        aux_transducer_loss: bool = False,
        aux_transducer_loss_weight: float = 0.2,
        aux_transducer_loss_mlp_dim: int = 320,
        aux_trans_loss_mlp_dropout_rate: float = 0.0,
        symm_kl_div_loss: bool = False,
        symm_kl_div_loss_weight: float = 0.2,
        fastemit_lambda: float = 0.0,
        blank_id: int = 0,
        ignore_id: int = -1,
        training: bool = False,
    ):
        """Initialize module for Transducer tasks.

        Args:
            encoder_dim: Encoder outputs dimension.
            decoder_dim: Decoder outputs dimension.
            joint_dim: Joint space dimension.
            output_dim: Output dimension.
            joint_activation_type: Type of activation for joint network.
            transducer_loss_weight: Weight for main transducer loss.
            ctc_loss: Compute CTC loss.
            ctc_loss_weight: Weight of CTC loss.
            ctc_loss_dropout_rate: Dropout rate for CTC loss inputs.
            lm_loss: Compute LM loss.
            lm_loss_weight: Weight of LM loss.
            lm_loss_smoothing_rate: Smoothing rate for LM loss' label smoothing.
            aux_transducer_loss: Compute auxiliary transducer loss.
            aux_transducer_loss_weight: Weight of auxiliary transducer loss.
            aux_transducer_loss_mlp_dim: Hidden dimension for aux. transducer MLP.
            aux_trans_loss_mlp_dropout_rate: Dropout rate for aux. transducer MLP.
            symm_kl_div_loss: Compute KL divergence loss.
            symm_kl_div_loss_weight: Weight of KL divergence loss.
            fastemit_lambda: Regularization parameter for FastEmit.
            blank_id: Blank symbol ID.
            ignore_id: Padding symbol ID.
            training: Whether the model was initialized in training or inference mode.
        """
        super().__init__()

        # Auxiliary objectives are training-only: disable them all at inference.
        if not training:
            ctc_loss, lm_loss, aux_transducer_loss, symm_kl_div_loss = (
                False,
                False,
                False,
                False,
            )

        self.joint_network = JointNetwork(
            output_dim, encoder_dim, decoder_dim, joint_dim, joint_activation_type
        )

        if training:
            # Imported lazily so inference does not require warprnnt_pytorch.
            from warprnnt_pytorch import RNNTLoss

            self.transducer_loss = RNNTLoss(
                blank=blank_id,
                reduction="sum",
                fastemit_lambda=fastemit_lambda,
            )

        if ctc_loss:
            self.ctc_lin = torch.nn.Linear(encoder_dim, output_dim)

            self.ctc_loss = torch.nn.CTCLoss(
                blank=blank_id,
                reduction="none",
                zero_infinity=True,
            )

        if aux_transducer_loss:
            # Maps auxiliary encoder outputs into the joint input space.
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(encoder_dim, aux_transducer_loss_mlp_dim),
                torch.nn.LayerNorm(aux_transducer_loss_mlp_dim),
                torch.nn.Dropout(p=aux_trans_loss_mlp_dropout_rate),
                torch.nn.ReLU(),
                torch.nn.Linear(aux_transducer_loss_mlp_dim, joint_dim),
            )

            # Symmetric KL requires the aux branch (see assert in forward()).
            if symm_kl_div_loss:
                self.kl_div = torch.nn.KLDivLoss(reduction="sum")

        if lm_loss:
            self.lm_lin = torch.nn.Linear(decoder_dim, output_dim)

            self.label_smoothing_loss = LabelSmoothingLoss(
                output_dim, ignore_id, lm_loss_smoothing_rate, normalize_length=False
            )

        self.output_dim = output_dim

        self.transducer_loss_weight = transducer_loss_weight

        self.use_ctc_loss = ctc_loss
        self.ctc_loss_weight = ctc_loss_weight
        self.ctc_dropout_rate = ctc_loss_dropout_rate

        self.use_lm_loss = lm_loss
        self.lm_loss_weight = lm_loss_weight

        self.use_aux_transducer_loss = aux_transducer_loss
        self.aux_transducer_loss_weight = aux_transducer_loss_weight

        self.use_symm_kl_div_loss = symm_kl_div_loss
        self.symm_kl_div_loss_weight = symm_kl_div_loss_weight

        self.blank_id = blank_id
        self.ignore_id = ignore_id

        # Cached target label sequences, populated by set_target().
        self.target = None

    def compute_transducer_loss(
        self,
        enc_out: torch.Tensor,
        dec_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Transducer loss.

        Args:
            enc_out: Encoder output sequences. (B, T, D_enc)
            dec_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, L)
            t_len: Time lengths. (B,)
            u_len: Label lengths. (B,)

        Returns:
            (joint_out, loss_trans):
                Joint output sequences. (B, T, U, D_joint),
                Transducer loss value.
        """
        joint_out = self.joint_network(enc_out.unsqueeze(2), dec_out.unsqueeze(1))

        loss_trans = self.transducer_loss(joint_out, target, t_len, u_len)
        # RNNTLoss uses reduction="sum": normalize by batch size.
        loss_trans /= joint_out.size(0)

        return joint_out, loss_trans

    def compute_ctc_loss(
        self,
        enc_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ):
        """Compute CTC loss.

        Args:
            enc_out: Encoder output sequences. (B, T, D_enc)
            target: Target character ID sequences. (B, U)
            t_len: Time lengths. (B,)
            u_len: Label lengths. (B,)

        Returns:
            : CTC loss value.
        """
        ctc_lin = self.ctc_lin(
            torch.nn.functional.dropout(
                enc_out.to(dtype=torch.float32), p=self.ctc_dropout_rate
            )
        )
        # CTCLoss expects (T, B, ...) log-probabilities.
        ctc_logp = torch.log_softmax(ctc_lin.transpose(0, 1), dim=-1)

        # Deterministic cuDNN for reproducible CTC computation.
        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = self.ctc_loss(ctc_logp, target, t_len, u_len)

        return loss_ctc.mean()

    def compute_aux_transducer_and_symm_kl_div_losses(
        self,
        aux_enc_out: torch.Tensor,
        dec_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        aux_t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute auxiliary Transducer loss and Jensen-Shannon divergence loss.

        Args:
            aux_enc_out: Encoder auxiliary output sequences. [N x (B, T_aux, D_enc_aux)]
            dec_out: Decoder output sequences. (B, U, D_dec)
            joint_out: Joint output sequences. (B, T, U, D_joint)
            target: Target character ID sequences. (B, L)
            aux_t_len: Auxiliary time lengths. [N x (B,)]
            u_len: True U lengths. (B,)

        Returns:
            : Auxiliary Transducer loss and KL divergence loss values.
        """
        aux_trans_loss = 0
        symm_kl_div_loss = 0

        num_aux_layers = len(aux_enc_out)
        B, T, U, D = joint_out.shape

        # Temporarily freeze the joint network so the auxiliary objectives
        # do not update its parameters (re-enabled after the loop).
        for p in self.joint_network.parameters():
            p.requires_grad = False

        for i, aux_enc_out_i in enumerate(aux_enc_out):
            aux_mlp = self.mlp(aux_enc_out_i)

            aux_joint_out = self.joint_network(
                aux_mlp.unsqueeze(2),
                dec_out.unsqueeze(1),
                is_aux=True,
            )

            if self.use_aux_transducer_loss:
                # Per-layer loss, normalized by batch size (sum reduction).
                aux_trans_loss += (
                    self.transducer_loss(
                        aux_joint_out,
                        target,
                        aux_t_len[i],
                        u_len,
                    )
                    / B
                )

            if self.use_symm_kl_div_loss:
                denom = B * T * U

                # Symmetric KL between main and auxiliary joint distributions.
                kl_main_aux = (
                    self.kl_div(
                        torch.log_softmax(joint_out, dim=-1),
                        torch.softmax(aux_joint_out, dim=-1),
                    )
                    / denom
                )

                kl_aux_main = (
                    self.kl_div(
                        torch.log_softmax(aux_joint_out, dim=-1),
                        torch.softmax(joint_out, dim=-1),
                    )
                    / denom
                )

                symm_kl_div_loss += kl_main_aux + kl_aux_main

        # Unfreeze the joint network for the main objective.
        for p in self.joint_network.parameters():
            p.requires_grad = True

        # Average the accumulated losses over the auxiliary layers.
        aux_trans_loss /= num_aux_layers

        if self.use_symm_kl_div_loss:
            symm_kl_div_loss /= num_aux_layers

        return aux_trans_loss, symm_kl_div_loss

    def compute_lm_loss(
        self,
        dec_out: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Forward LM loss.

        Args:
            dec_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, U)

        Returns:
            : LM loss value.
        """
        lm_lin = self.lm_lin(dec_out)

        lm_loss = self.label_smoothing_loss(lm_lin, target)

        return lm_loss

    def set_target(self, target: torch.Tensor):
        """Set target label ID sequences.

        Args:
            target: Target label ID sequences. (B, L)
        """
        self.target = target

    def get_target(self):
        """Get target label ID sequences.

        Returns:
            target: Target label ID sequences. (B, L)
        """
        return self.target

    def get_transducer_tasks_io(
        self,
        labels: torch.Tensor,
        enc_out_len: torch.Tensor,
        aux_enc_out_len: Optional[List],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[List], torch.Tensor]:
        """Get Transducer tasks inputs and outputs.

        Args:
            labels: Label ID sequences. (B, U)
            enc_out_len: Time lengths. (B,)
            aux_enc_out_len: Auxiliary time lengths. [N X (B,)]

        Returns:
            target: Target label ID sequences. (B, L)
            lm_loss_target: LM loss target label ID sequences. (B, U)
            t_len: Time lengths. (B,)
            aux_t_len: Auxiliary time lengths. [N x (B,)]
            u_len: Label lengths. (B,)
        """
        device = labels.device

        # Strip padding (ignore_id) from each label sequence.
        labels_unpad = [label[label != self.ignore_id] for label in labels]
        blank = labels[0].new([self.blank_id])

        target = pad_list(labels_unpad, self.blank_id).type(torch.int32).to(device)

        # LM targets: labels followed by blank, padded with ignore_id.
        lm_loss_target = (
            pad_list(
                [torch.cat([y, blank], dim=0) for y in labels_unpad], self.ignore_id
            )
            .type(torch.int64)
            .to(device)
        )

        self.set_target(target)

        # enc_out_len may be a mask (2-D) or a length vector (1-D).
        if enc_out_len.dim() > 1:
            enc_mask_unpad = [m[m != 0] for m in enc_out_len]
            enc_out_len = list(map(int, [m.size(0) for m in enc_mask_unpad]))
        else:
            enc_out_len = list(map(int, enc_out_len))

        t_len = torch.IntTensor(enc_out_len).to(device)
        u_len = torch.IntTensor([label.size(0) for label in labels_unpad]).to(device)

        if aux_enc_out_len:
            aux_t_len = []

            for i in range(len(aux_enc_out_len)):
                # Same mask-vs-lengths handling per auxiliary layer.
                if aux_enc_out_len[i].dim() > 1:
                    aux_mask_unpad = [aux[aux != 0] for aux in aux_enc_out_len[i]]

                    aux_t_len.append(
                        torch.IntTensor(
                            list(map(int, [aux.size(0) for aux in aux_mask_unpad]))
                        ).to(device)
                    )
                else:
                    aux_t_len.append(
                        torch.IntTensor(list(map(int, aux_enc_out_len[i]))).to(device)
                    )
        else:
            aux_t_len = aux_enc_out_len

        return target, lm_loss_target, t_len, aux_t_len, u_len

    def forward(
        self,
        enc_out: torch.Tensor,
        aux_enc_out: List[torch.Tensor],
        dec_out: torch.Tensor,
        labels: torch.Tensor,
        enc_out_len: torch.Tensor,
        aux_enc_out_len: torch.Tensor,
    ) -> Tuple[Any, Any, Any, Any, Any]:
        """Forward main and auxiliary task.

        Args:
            enc_out: Encoder output sequences. (B, T, D_enc)
            aux_enc_out: Encoder intermediate output sequences. (B, T_aux, D_enc_aux)
            dec_out: Decoder output sequences. (B, U, D_dec)
            labels: Label ID sequences. (B, U)
            enc_out_len: Time lengths. (B,)
            aux_enc_out_len: Auxiliary time lengths. [N x (B,)]

        Returns:
            : Weighted losses.
              (transducer loss, ctc loss, aux Transducer loss, KL div loss, LM loss)
        """
        # The symmetric KL objective needs the auxiliary joint outputs.
        if self.use_symm_kl_div_loss:
            assert self.use_aux_transducer_loss

        (trans_loss, ctc_loss, lm_loss, aux_trans_loss, symm_kl_div_loss) = (
            0.0,
            0.0,
            0.0,
            0.0,
            0.0,
        )

        target, lm_loss_target, t_len, aux_t_len, u_len = self.get_transducer_tasks_io(
            labels,
            enc_out_len,
            aux_enc_out_len,
        )

        joint_out, trans_loss = self.compute_transducer_loss(
            enc_out, dec_out, target, t_len, u_len
        )

        if self.use_ctc_loss:
            ctc_loss = self.compute_ctc_loss(enc_out, target, t_len, u_len)

        if self.use_aux_transducer_loss:
            (
                aux_trans_loss,
                symm_kl_div_loss,
            ) = self.compute_aux_transducer_and_symm_kl_div_losses(
                aux_enc_out,
                dec_out,
                joint_out,
                target,
                aux_t_len,
                u_len,
            )

        if self.use_lm_loss:
            lm_loss = self.compute_lm_loss(dec_out, lm_loss_target)

        return (
            self.transducer_loss_weight * trans_loss,
            self.ctc_loss_weight * ctc_loss,
            self.aux_transducer_loss_weight * aux_trans_loss,
            self.symm_kl_div_loss_weight * symm_kl_div_loss,
            self.lm_loss_weight * lm_loss,
        )
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transducer/transformer_decoder_layer.py
0 → 100644
View file @
60a2c57a
"""Transformer decoder layer definition for custom Transducer model."""
from
typing
import
Optional
import
torch
from
espnet.nets.pytorch_backend.transformer.attention
import
MultiHeadedAttention
from
espnet.nets.pytorch_backend.transformer.layer_norm
import
LayerNorm
from
espnet.nets.pytorch_backend.transformer.positionwise_feed_forward
import
(
PositionwiseFeedForward
,
)
class TransformerDecoderLayer(torch.nn.Module):
    """Transformer decoder layer module for custom Transducer model.

    Args:
        hdim: Hidden dimension.
        self_attention: Self-attention module.
        feed_forward: Feed forward module.
        dropout_rate: Dropout rate.
    """

    def __init__(
        self,
        hdim: int,
        self_attention: MultiHeadedAttention,
        feed_forward: PositionwiseFeedForward,
        dropout_rate: float,
    ):
        """Construct a DecoderLayer object."""
        super().__init__()

        self.self_attention = self_attention
        self.feed_forward = feed_forward

        # Pre-norm: normalization is applied before each sub-block.
        self.norm1 = LayerNorm(hdim)
        self.norm2 = LayerNorm(hdim)

        self.dropout = torch.nn.Dropout(dropout_rate)

        self.hdim = hdim

    def forward(
        self,
        sequence: torch.Tensor,
        mask: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
    ):
        """Compute previous decoder output sequences.

        Args:
            sequence: Transformer input sequences. (B, U, D_dec)
            mask: Transformer input mask sequences. (B, U)
            cache: Cached decoder output sequences. (B, (U - 1), D_dec)

        Returns:
            sequence: Transformer output sequences. (B, U, D_dec)
            mask: Transformer output mask sequences. (B, U)
        """
        residual = sequence
        sequence = self.norm1(sequence)

        if cache is None:
            sequence_q = sequence
        else:
            # Incremental decoding: only the newest position is used as the
            # attention query; the cache must contain all earlier positions.
            batch = sequence.shape[0]
            prev_len = sequence.shape[1] - 1

            assert cache.shape == (
                batch,
                prev_len,
                self.hdim,
            ), f"{cache.shape} == {(batch, prev_len, self.hdim)}"

            sequence_q = sequence[:, -1:, :]
            residual = residual[:, -1:, :]

            if mask is not None:
                mask = mask[:, -1:, :]

        # Self-attention sub-block with residual connection.
        sequence = residual + self.dropout(
            self.self_attention(sequence_q, sequence, sequence, mask)
        )

        # Feed-forward sub-block with residual connection.
        residual = sequence
        sequence = self.norm2(sequence)

        sequence = residual + self.dropout(self.feed_forward(sequence))

        if cache is not None:
            # Return the full sequence (cached prefix + new position).
            sequence = torch.cat([cache, sequence], dim=1)

        return sequence, mask
Prev
1
…
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment