Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
sunzhq2
yidong-infer
Commits
60a2c57a
Commit
60a2c57a
authored
Jan 27, 2026
by
sunzhq2
Committed by
xuxo
Jan 27, 2026
Browse files
update conformer
parent
4a699441
Changes
216
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4273 additions
and
0 deletions
+4273
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/chainer_backend/transformer/subsampling.py
...ib/espnet/nets/chainer_backend/transformer/subsampling.py
+115
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/chainer_backend/transformer/training.py
...d/lib/espnet/nets/chainer_backend/transformer/training.py
+321
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/ctc_prefix_score.py
...202304_20240621/build/lib/espnet/nets/ctc_prefix_score.py
+357
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/e2e_asr_common.py
...v.202304_20240621/build/lib/espnet/nets/e2e_asr_common.py
+253
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/e2e_mt_common.py
...-v.202304_20240621/build/lib/espnet/nets/e2e_mt_common.py
+74
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/lm_interface.py
...t-v.202304_20240621/build/lib/espnet/nets/lm_interface.py
+86
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/mt_interface.py
...t-v.202304_20240621/build/lib/espnet/nets/mt_interface.py
+94
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/__init__.py
...0240621/build/lib/espnet/nets/pytorch_backend/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/__init__.py
...ild/lib/espnet/nets/pytorch_backend/conformer/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/argument.py
...ild/lib/espnet/nets/pytorch_backend/conformer/argument.py
+87
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/contextual_block_encoder_layer.py
...torch_backend/conformer/contextual_block_encoder_layer.py
+310
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/convolution.py
.../lib/espnet/nets/pytorch_backend/conformer/convolution.py
+79
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/encoder.py
...uild/lib/espnet/nets/pytorch_backend/conformer/encoder.py
+299
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/encoder_layer.py
...ib/espnet/nets/pytorch_backend/conformer/encoder_layer.py
+179
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/swish.py
.../build/lib/espnet/nets/pytorch_backend/conformer/swish.py
+18
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/ctc.py
...304_20240621/build/lib/espnet/nets/pytorch_backend/ctc.py
+268
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/e2e_asr.py
...20240621/build/lib/espnet/nets/pytorch_backend/e2e_asr.py
+545
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/e2e_asr_conformer.py
...uild/lib/espnet/nets/pytorch_backend/e2e_asr_conformer.py
+80
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/e2e_asr_maskctc.py
.../build/lib/espnet/nets/pytorch_backend/e2e_asr_maskctc.py
+277
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/e2e_asr_mix.py
...0621/build/lib/espnet/nets/pytorch_backend/e2e_asr_mix.py
+829
-0
No files found.
Too many changes to show.
To preserve performance only
216 of 216+
files are displayed.
Plain diff
Email patch
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/chainer_backend/transformer/subsampling.py
0 → 100644
View file @
60a2c57a
# encoding: utf-8
"""Class Declaration of Transformer's Input layers."""
import
logging
import
chainer
import
chainer.functions
as
F
import
chainer.links
as
L
import
numpy
as
np
from
espnet.nets.chainer_backend.transformer.embedding
import
PositionalEncoding
class Conv2dSubsampling(chainer.Chain):
    """Convolutional 2D subsampling (to 1/4 length).

    Two stride-2 3x3 convolutions reduce the time axis by a factor of 4,
    followed by a Linear projection and positional encoding.

    :param int channels: number of channels of both conv layers
    :param int idim: input dim of the final Linear projection
        (the flattened channel x frequency size after the two convs)
    :param int dims: output dim of the Linear projection / positional encoding
    :param float dropout: dropout rate passed to the positional encoding
    :param initialW: weight-initializer factory; called as ``initialW(scale=...)``
    :param initial_bias: bias-initializer factory; called as ``initial_bias(scale=...)``
    """

    def __init__(
        self, channels, idim, dims, dropout=0.1, initialW=None, initial_bias=None
    ):
        """Initialize Conv2dSubsampling."""
        super(Conv2dSubsampling, self).__init__()
        self.dropout = dropout
        with self.init_scope():
            # Standard deviation for Conv2D with 1 channel and kernel 3 x 3.
            n = 1 * 3 * 3
            stvd = 1.0 / np.sqrt(n)
            self.conv1 = L.Convolution2D(
                1,
                channels,
                3,
                stride=2,
                pad=1,
                initialW=initialW(scale=stvd),
                initial_bias=initial_bias(scale=stvd),
            )
            # Scale for the second conv follows fan-in = channels * 3 * 3.
            n = channels * 3 * 3
            stvd = 1.0 / np.sqrt(n)
            self.conv2 = L.Convolution2D(
                channels,
                channels,
                3,
                stride=2,
                pad=1,
                initialW=initialW(scale=stvd),
                initial_bias=initial_bias(scale=stvd),
            )
            stvd = 1.0 / np.sqrt(dims)
            self.out = L.Linear(
                idim,
                dims,
                initialW=initialW(scale=stvd),
                initial_bias=initial_bias(scale=stvd),
            )
            self.pe = PositionalEncoding(dims, dropout)

    def forward(self, xs, ilens):
        """Subsample x.

        :param chainer.Variable xs: input tensor (batch, time, freq);
            a channel axis is inserted before the convs
        :param ilens: input lengths (batch,)
        :return: subsampled xs with positional encoding, and updated ilens
        """
        # Add a singleton channel axis: (B, T, F) -> (B, 1, T, F).
        xs = self.xp.array(xs[:, None])
        xs = F.relu(self.conv1(xs))
        xs = F.relu(self.conv2(xs))
        batch, _, length, _ = xs.shape
        # Flatten (channels, freq) per frame and project to `dims`.
        xs = self.out(F.swapaxes(xs, 1, 2).reshape(batch * length, -1))
        xs = self.pe(xs.reshape(batch, length, -1))
        # change ilens accordingly: each stride-2 conv halves time (ceil).
        ilens = np.ceil(np.array(ilens, dtype=np.float32) / 2).astype(np.int64)
        ilens = np.ceil(np.array(ilens, dtype=np.float32) / 2).astype(np.int64)
        return xs, ilens
class LinearSampling(chainer.Chain):
    """Linear 1D subsampling.

    A single Linear projection plus positional encoding; the sequence
    lengths are returned unchanged.

    :param int idim: input dim
    :param int dims: output dim
    :param float dropout: dropout rate for the positional encoding
    :param initialW: weight-initializer factory; called as ``initialW(scale=...)``
    :param initial_bias: bias-initializer factory; called as ``initial_bias(scale=...)``
    """

    def __init__(self, idim, dims, dropout=0.1, initialW=None, initial_bias=None):
        """Initialize LinearSampling."""
        super(LinearSampling, self).__init__()
        scale_std = 1.0 / np.sqrt(dims)
        self.dropout = dropout
        with self.init_scope():
            self.linear = L.Linear(
                idim,
                dims,
                initialW=initialW(scale=scale_std),
                initial_bias=initial_bias(scale=scale_std),
            )
            self.pe = PositionalEncoding(dims, dropout)

    def forward(self, xs, ilens):
        """Project xs and add positional encoding.

        :param chainer.Variable xs: input tensor (batch, time, idim)
        :param ilens: input lengths, passed through unchanged
        :return: projected xs with positional encoding, and ilens
        """
        logging.info(xs.shape)
        projected = self.linear(xs, n_batch_axes=2)
        logging.info(projected.shape)
        encoded = self.pe(projected)
        return encoded, ilens
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/chainer_backend/transformer/training.py
0 → 100644
View file @
60a2c57a
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Class Declaration of Transformer's Training Subprocess."""
import
collections
import
logging
import
math
import
numpy
as
np
from
chainer
import
cuda
from
chainer
import
functions
as
F
from
chainer
import
training
from
chainer.training
import
extension
from
chainer.training.updaters.multiprocess_parallel_updater
import
(
gather_grads
,
gather_params
,
scatter_grads
,
)
# copied from https://github.com/chainer/chainer/blob/master/chainer/optimizer.py
def sum_sqnorm(arr):
    """Calculate the norm of the array.

    Args:
        arr (numpy.ndarray): iterable of (possibly None) gradient arrays,
            each living on CPU or some GPU device.

    Returns:
        Float: Sum of the norm calculated from the given array.
    """
    # Accumulate one squared-norm sum per device, then combine.
    sq_sum = collections.defaultdict(float)
    for x in arr:
        with cuda.get_device_from_array(x) as dev:
            if x is not None:
                x = x.ravel()
                s = x.dot(x)
                sq_sum[int(dev)] += s
    return sum([float(i) for i in sq_sum.values()])
class CustomUpdater(training.StandardUpdater):
    """Custom updater for chainer.

    Args:
        train_iter (iterator | dict[str, iterator]): Dataset iterator for the
            training dataset. It can also be a dictionary that maps strings to
            iterators. If this is just an iterator, then the iterator is
            registered by the name ``'main'``.
        optimizer (optimizer | dict[str, optimizer]): Optimizer to update
            parameters. It can also be a dictionary that maps strings to
            optimizers. If this is just an optimizer, then the optimizer is
            registered by the name ``'main'``.
        converter (espnet.asr.chainer_backend.asr.CustomConverter): Converter
            function to build input arrays. Each batch extracted by the main
            iterator and the ``device`` option are passed to this function.
            :func:`chainer.dataset.concat_examples` is used by default.
        device (int or dict): The destination device info to send variables. In the
            case of cpu or single gpu, `device=-1 or 0`, respectively.
            In the case of multi-gpu, `device={"main":0, "sub_1": 1, ...}`.
        accum_grad (int): The number of gradient accumulation. If set to 2, the
            network parameters will be updated once in twice,
            i.e. actual batchsize will be doubled.
    """

    def __init__(self, train_iter, optimizer, converter, device, accum_grad=1):
        """Initialize Custom Updater."""
        super(CustomUpdater, self).__init__(
            train_iter, optimizer, converter=converter, device=device
        )
        self.accum_grad = accum_grad
        # Counts forward passes since the last parameter update.
        self.forward_count = 0
        # True until the first update; used to clear stale gradients once.
        self.start = True
        self.device = device
        logging.debug("using custom converter for transformer")

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Process main update routine for Custom Updater."""
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")
        # Get batch and convert into variables
        batch = train_iter.next()
        x = self.converter(batch, self.device)
        if self.start:
            optimizer.target.cleargrads()
            self.start = False
        # Compute the loss at this time step and accumulate it
        # (divided by accum_grad so the accumulated gradient is an average).
        loss = optimizer.target(*x) / self.accum_grad
        loss.backward()  # Backprop
        self.forward_count += 1
        # Defer the parameter update until accum_grad batches are accumulated.
        if self.forward_count != self.accum_grad:
            return
        self.forward_count = 0
        # compute the gradient norm to check if it is normal or not
        grad_norm = np.sqrt(
            sum_sqnorm([p.grad for p in optimizer.target.params(False)])
        )
        logging.info("grad norm={}".format(grad_norm))
        if math.isnan(grad_norm):
            logging.warning("grad norm is nan. Do not update model.")
        else:
            optimizer.update()
        optimizer.target.cleargrads()  # Clear the parameter gradients

    def update(self):
        """Update step for Custom Updater."""
        self.update_core()
        # Only advance the iteration counter when parameters were updated.
        if self.forward_count == 0:
            self.iteration += 1
class CustomParallelUpdater(training.updaters.MultiprocessParallelUpdater):
    """Custom Parallel Updater for chainer.

    Defines the main update routine.

    Args:
        train_iter (iterator | dict[str, iterator]): Dataset iterator for the
            training dataset. It can also be a dictionary that maps strings to
            iterators. If this is just an iterator, then the iterator is
            registered by the name ``'main'``.
        optimizer (optimizer | dict[str, optimizer]): Optimizer to update
            parameters. It can also be a dictionary that maps strings to
            optimizers. If this is just an optimizer, then the optimizer is
            registered by the name ``'main'``.
        converter (espnet.asr.chainer_backend.asr.CustomConverter): Converter
            function to build input arrays. Each batch extracted by the main
            iterator and the ``device`` option are passed to this function.
            :func:`chainer.dataset.concat_examples` is used by default.
        device (torch.device): Device to which the training data is sent.
            Negative value indicates the host memory (CPU).
        accum_grad (int): The number of gradient accumulation. If set to 2, the
            network parameters will be updated once in twice,
            i.e. actual batchsize will be doubled.
    """

    def __init__(self, train_iters, optimizer, converter, devices, accum_grad=1):
        """Initialize custom parallel updater."""
        # Imported lazily so CPU-only setups never require cupy.
        from cupy.cuda import nccl

        super(CustomParallelUpdater, self).__init__(
            train_iters, optimizer, converter=converter, devices=devices
        )
        self.accum_grad = accum_grad
        self.forward_count = 0
        self.nccl = nccl
        logging.debug("using custom parallel updater for transformer")

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Process main update routine for Custom Parallel Updater."""
        self.setup_workers()

        self._send_message(("update", None))
        with cuda.Device(self._devices[0]):
            # For reducing memory
            optimizer = self.get_optimizer("main")
            batch = self.get_iterator("main").next()
            x = self.converter(batch, self._devices[0])

            loss = self._master(*x) / self.accum_grad
            loss.backward()

            # NCCL: reduce grads from all workers into the master device.
            null_stream = cuda.Stream.null
            if self.comm is not None:
                gg = gather_grads(self._master)
                self.comm.reduce(
                    gg.data.ptr,
                    gg.data.ptr,
                    gg.size,
                    self.nccl.NCCL_FLOAT,
                    self.nccl.NCCL_SUM,
                    0,
                    null_stream.ptr,
                )
                scatter_grads(self._master, gg)
                del gg

            # update parameters (deferred until accum_grad forward passes).
            self.forward_count += 1
            if self.forward_count != self.accum_grad:
                return
            self.forward_count = 0

            # check gradient value
            grad_norm = np.sqrt(
                sum_sqnorm([p.grad for p in optimizer.target.params(False)])
            )
            logging.info("grad norm={}".format(grad_norm))

            # update
            if math.isnan(grad_norm):
                logging.warning("grad norm is nan. Do not update model.")
            else:
                optimizer.update()
            self._master.cleargrads()

            # Broadcast the updated master parameters back to all workers.
            if self.comm is not None:
                gp = gather_params(self._master)
                self.comm.bcast(
                    gp.data.ptr, gp.size, self.nccl.NCCL_FLOAT, 0, null_stream.ptr
                )

    def update(self):
        """Update step for Custom Parallel Updater."""
        self.update_core()
        # Only advance the iteration counter when parameters were updated.
        if self.forward_count == 0:
            self.iteration += 1
class VaswaniRule(extension.Extension):
    """Trainer extension to shift an optimizer attribute magically by Vaswani.

    Implements the "Noam" learning-rate schedule:
    ``lr = scale * d^-0.5 * min(t^-0.5, t * warmup_steps^-1.5)``.

    Args:
        attr (str): Name of the attribute to shift.
        d (int): Model dimension; the schedule scales with ``d ** -0.5``.
        warmup_steps (int): Number of linear warmup steps.
        init (float): Initial value of the attribute. If it is ``None``, the
            extension computes the warmup-start value at the first call and
            uses it as the initial value.
        target (float): Target value of the attribute. If the attribute reaches
            this value, the shift stops.
        optimizer (~chainer.Optimizer): Target optimizer to adjust the
            attribute. If it is ``None``, the main optimizer of the updater is
            used.
        scale (float): Extra multiplier applied to the schedule.
    """

    def __init__(
        self,
        attr,
        d,
        warmup_steps=4000,
        init=None,
        target=None,
        optimizer=None,
        scale=1.0,
    ):
        """Initialize Vaswani rule extension."""
        self._attr = attr
        # Precompute the constant factors of the schedule.
        self._d_inv05 = d ** (-0.5) * scale
        self._warmup_steps_inv15 = warmup_steps ** (-1.5)
        self._init = init
        self._target = target
        self._optimizer = optimizer
        self._t = 0
        # Last value written; restored when resuming from a snapshot.
        self._last_value = None

    def initialize(self, trainer):
        """Initialize Optimizer values."""
        optimizer = self._get_optimizer(trainer)
        # ensure that _init is set
        if self._init is None:
            self._init = self._d_inv05 * (1.0 * self._warmup_steps_inv15)
        if self._last_value is not None:
            # resuming from a snapshot
            self._update_value(optimizer, self._last_value)
        else:
            self._update_value(optimizer, self._init)

    def __call__(self, trainer):
        """Forward extension."""
        self._t += 1
        optimizer = self._get_optimizer(trainer)
        value = self._d_inv05 * min(
            self._t ** (-0.5), self._t * self._warmup_steps_inv15
        )
        self._update_value(optimizer, value)

    def serialize(self, serializer):
        """Serialize extension."""
        self._t = serializer("_t", self._t)
        self._last_value = serializer("_last_value", self._last_value)

    def _get_optimizer(self, trainer):
        """Obtain optimizer from trainer."""
        return self._optimizer or trainer.updater.get_optimizer("main")

    def _update_value(self, optimizer, value):
        """Update requested variable values."""
        setattr(optimizer, self._attr, value)
        self._last_value = value
class CustomConverter(object):
    """Custom Converter.

    Pads a mini-batch of input sequences and collects their lengths.

    Args:
        subsampling_factor (int): The subsampling factor.
    """

    def __init__(self):
        """Initialize subsampling."""
        pass

    def __call__(self, batch, device):
        """Perform padding and length extraction.

        Args:
            batch (list): Batch that will be padded.
            device (chainer.backend.Device): CPU or GPU device.

        Returns:
            chainer.Variable: xp.array that are padded and subsampled from batch.
            xp.array: xp.array of the length of the mini-batches.
            chainer.Variable: xp.array that are padded and subsampled from batch.
        """
        # For transformer, data is processed in CPU.
        # batch should be located in list
        assert len(batch) == 1
        inputs, targets = batch[0]
        padded = F.pad_sequence(inputs, padding=-1).data
        # get batch of lengths of input sequences
        lengths = np.array([seq.shape[0] for seq in padded], dtype=np.int32)
        return padded, lengths, targets
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/ctc_prefix_score.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import
numpy
as
np
import
torch
class CTCPrefixScoreTH(object):
    """Batch processing of CTCPrefixScore

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the label probabilities for multiple
    hypotheses simultaneously
    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
    """

    def __init__(self, x, xlens, blank, eos, margin=0):
        """Construct CTC prefix scorer

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        :param torch.Tensor xlens: input lengths (B,)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param int margin: margin parameter for windowing (0 means no windowing)
        """
        # In the comment lines,
        # we assume T: input_length, B: batch size, W: beam width, O: output dim.
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.batch = x.size(0)
        self.input_length = x.size(1)
        self.odim = x.size(2)
        self.dtype = x.dtype
        self.device = (
            torch.device("cuda:%d" % x.get_device())
            if x.is_cuda
            else torch.device("cpu")
        )
        # Pad the rest of posteriors in the batch: beyond each utterance's
        # length, all labels get logzero except blank which gets log-prob 0.
        # TODO(takaaki-hori): need a better way without for-loops
        for i, l in enumerate(xlens):
            if l < self.input_length:
                x[i, l:, :] = self.logzero
                x[i, l:, blank] = 0
        # Reshape input x
        xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
        # xb: the blank posterior broadcast over the label axis.
        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
        self.x = torch.stack([xn, xb])  # (2, T, B, O)
        self.end_frames = torch.as_tensor(xlens) - 1

        # Setup CTC windowing
        self.margin = margin
        if margin > 0:
            self.frame_ids = torch.arange(
                self.input_length, dtype=self.dtype, device=self.device
            )
        # Base indices for index conversion
        self.idx_bh = None
        self.idx_b = torch.arange(self.batch, device=self.device)
        self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)

    def __call__(self, y, state, scoring_ids=None, att_w=None):
        """Compute CTC prefix scores for next labels

        :param list y: prefix label sequences
        :param tuple state: previous CTC state
        :param torch.Tensor scoring_ids: label ids pre-selected for scoring (BW, S)
        :param torch.Tensor att_w: attention weights to decide CTC window
        :return new_state, ctc_local_scores (BW, O)
        """
        output_length = len(y[0]) - 1  # ignore sos
        last_ids = [yi[-1] for yi in y]  # last output label ids
        n_bh = len(last_ids)  # batch * hyps
        n_hyps = n_bh // self.batch  # assuming each utterance has the same # of hyps
        self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
        # prepare state info
        if state is None:
            r_prev = torch.full(
                (self.input_length, 2, self.batch, n_hyps),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            # Blank path of the empty prefix: cumulative blank log-probs.
            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, n_bh)
            s_prev = 0.0
            f_min_prev = 0
            f_max_prev = 1
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

        # select input dimensions for scoring
        if self.scoring_num > 0:
            # scoring_idmap maps (hyp, label) -> column in the reduced S axis,
            # or -1 if the label is not scored for that hypothesis.
            scoring_idmap = torch.full(
                (n_bh, self.odim), -1, dtype=torch.long, device=self.device
            )
            snum = self.scoring_num
            if self.idx_bh is None or n_bh > len(self.idx_bh):
                self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(
                snum, device=self.device
            )
            scoring_idx = (
                scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)
            ).view(-1)
            x_ = torch.index_select(
                self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx
            ).view(2, -1, n_bh, snum)
        else:
            scoring_ids = None
            scoring_idmap = None
            snum = self.odim
            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)

        # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
        # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
        r = torch.full(
            (self.input_length, 2, n_bh, snum),
            self.logzero,
            dtype=self.dtype,
            device=self.device,
        )
        if output_length == 0:
            r[0, 0] = x_[0, 0]

        r_sum = torch.logsumexp(r_prev, 1)
        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
        # phi must use only the blank path when the candidate label repeats
        # the last emitted label.
        if scoring_ids is not None:
            for idx in range(n_bh):
                pos = scoring_idmap[idx, last_ids[idx]]
                if pos >= 0:
                    log_phi[:, idx, pos] = r_prev[:, 1, idx]
        else:
            for idx in range(n_bh):
                log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]

        # decide start and end frames based on attention weights
        if att_w is not None and self.margin > 0:
            f_arg = torch.matmul(att_w, self.frame_ids)
            f_min = max(int(f_arg.min().cpu()), f_min_prev)
            f_max = max(int(f_arg.max().cpu()), f_max_prev)
            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
            end = min(f_max + self.margin, self.input_length)
        else:
            f_min = f_max = 0
            start = max(output_length, 1)
            end = self.input_length

        # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
        for t in range(start, end):
            rp = r[t - 1]
            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(
                2, 2, n_bh, snum
            )
            r[t] = torch.logsumexp(rr, 1) + x_[:, t]

        # compute log prefix probabilities log(psi)
        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
        if scoring_ids is not None:
            log_psi = torch.full(
                (n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device
            )
            log_psi_ = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )
            # Scatter the reduced S-axis scores back into the full O axis.
            for si in range(n_bh):
                log_psi[si, scoring_ids[si]] = log_psi_[si]
        else:
            log_psi = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )

        # The eos score is the total probability of the prefix itself.
        for si in range(n_bh):
            log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si]

        # exclude blank probs
        log_psi[:, self.blank] = self.logzero

        return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap)

    def index_select_state(self, state, best_ids):
        """Select CTC states according to best ids

        :param state : CTC state
        :param best_ids : index numbers selected by beam pruning (B, W)
        :return selected_state
        """
        r, s, f_min, f_max, scoring_idmap = state
        # convert ids to BHO space
        n_bh = len(s)
        n_hyps = n_bh // self.batch
        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
        # select hypothesis scores
        s_new = torch.index_select(s.view(-1), 0, vidx)
        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
        # convert ids to BHS space (S: scoring_num)
        if scoring_idmap is not None:
            snum = self.scoring_num
            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(
                -1
            )
            label_ids = torch.fmod(best_ids, self.odim).view(-1)
            score_idx = scoring_idmap[hyp_idx, label_ids]
            # Unscored labels map to -1; clamp to 0 so index_select stays valid.
            score_idx[score_idx == -1] = 0
            vidx = score_idx + hyp_idx * snum
        else:
            snum = self.odim
        # select forward probabilities
        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(
            -1, 2, n_bh
        )
        return r_new, s_new, f_min, f_max

    def extend_prob(self, x):
        """Extend CTC prob.

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        """
        if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
            # Pad the rest of posteriors in the batch
            # TODO(takaaki-hori): need a better way without for-loops
            xlens = [x.size(1)]
            for i, l in enumerate(xlens):
                if l < self.input_length:
                    x[i, l:, :] = self.logzero
                    x[i, l:, self.blank] = 0
            tmp_x = self.x
            xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
            self.x = torch.stack([xn, xb])  # (2, T, B, O)
            # Keep the previously computed posteriors for the old frames.
            self.x[:, : tmp_x.shape[1], :, :] = tmp_x
            self.input_length = x.size(1)
            self.end_frames = torch.as_tensor(xlens) - 1

    def extend_state(self, state):
        """Compute CTC prefix state.

        :param state : CTC state
        :return ctc_state
        """
        if state is None:
            # nothing to do
            return state
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state

            r_prev_new = torch.full(
                (self.input_length, 2),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            start = max(r_prev.shape[0], 1)
            r_prev_new[0:start] = r_prev
            # Extend the blank path over the newly appended frames.
            for t in range(start, self.input_length):
                r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]

            return (r_prev_new, s_prev, f_min_prev, f_max_prev)
class CTCPrefixScore(object):
    """Compute CTC label sequence scores

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the probabilities of multiple labels
    simultaneously
    """

    def __init__(self, x, blank, eos, xp):
        """Construct a single-utterance CTC prefix scorer.

        :param x: label posterior sequence (T, O)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param xp: array module (numpy or cupy)
        """
        self.xp = xp
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.input_length = len(x)
        self.x = x

    def initial_state(self):
        """Obtain an initial CTC state

        :return: CTC state
        """
        # initial CTC state is made of a frame x 2 tensor that corresponds to
        # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
        # superscripts n and b (non-blank and blank), respectively.
        r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
        r[0, 1] = self.x[0, self.blank]
        for i in range(1, self.input_length):
            r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
        return r

    def __call__(self, y, cs, r_prev):
        """Compute CTC prefix scores for next labels

        :param y     : prefix label sequence
        :param cs    : array of next labels
        :param r_prev: previous CTC state
        :return ctc_scores, ctc_states
        """
        # initialize CTC states
        output_length = len(y) - 1  # ignore sos
        # new CTC states are prepared as a frame x (n or b) x n_labels tensor
        # that corresponds to r_t^n(h) and r_t^b(h).
        r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
        xs = self.x[:, cs]
        if output_length == 0:
            r[0, 0] = xs[0]
            r[0, 1] = self.logzero
        else:
            # A prefix longer than t cannot be emitted within t frames.
            r[output_length - 1] = self.logzero

        # prepare forward probabilities for the last label
        r_sum = self.xp.logaddexp(
            r_prev[:, 0], r_prev[:, 1]
        )  # log(r_t^n(g) + r_t^b(g))
        last = y[-1]
        if output_length > 0 and last in cs:
            # For a repeated label, phi may only come through the blank path.
            log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
            for i in range(len(cs)):
                log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
        else:
            log_phi = r_sum

        # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
        # and log prefix probabilities log(psi)
        start = max(output_length, 1)
        log_psi = r[start - 1, 0]
        for t in range(start, self.input_length):
            r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
            r[t, 1] = (
                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
            )
            log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])

        # get P(...eos|X) that ends with the prefix itself
        eos_pos = self.xp.where(cs == self.eos)[0]
        if len(eos_pos) > 0:
            log_psi[eos_pos] = r_sum[-1]  # log(r_T^n(g) + r_T^b(g))

        # exclude blank probs
        blank_pos = self.xp.where(cs == self.blank)[0]
        if len(blank_pos) > 0:
            log_psi[blank_pos] = self.logzero

        # return the log prefix probability and CTC states, where the label axis
        # of the CTC states is moved to the first axis to slice it easily
        return log_psi, self.xp.rollaxis(r, 2)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/e2e_asr_common.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Common functions for ASR."""
import
json
import
logging
import
sys
from
itertools
import
groupby
import
numpy
as
np
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection.

    described in Eq. (50) of S. Watanabe et al
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"

    :param ended_hyps: list of dicts with "score" and "yseq" keys
    :param i: current output length
    :param M: number of recent lengths to inspect
    :param D_end: score-gap threshold (log domain)
    :return: True if decoding can stop, False otherwise
    """
    if not ended_hyps:
        return False

    # Globally best ended hypothesis so far.
    best_hyp = max(ended_hyps, key=lambda hyp: hyp["score"])

    # Count how many of the last M lengths only produced clearly worse hypotheses.
    count = 0
    for m in range(M):
        target_length = i - m
        same_length = [
            hyp for hyp in ended_hyps if len(hyp["yseq"]) == target_length
        ]
        if same_length:
            best_same_length = max(same_length, key=lambda hyp: hyp["score"])
            if best_same_length["score"] - best_hyp["score"] < D_end:
                count += 1

    return count == M
# TODO(takaaki-hori): add different smoothing methods
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
    """Obtain label distribution for loss smoothing.

    :param odim: output (vocabulary) dimension
    :param lsm_type: smoothing type; only "unigram" is supported
    :param blank: blank label id whose count is zeroed out
    :param transcript: path to a transcript JSON file with an "utts" section
    :return: normalized label distribution (float32 array of size odim)
    """
    if transcript is not None:
        with open(transcript, "rb") as f:
            trans_json = json.load(f)["utts"]

    if lsm_type == "unigram":
        assert transcript is not None, (
            "transcript is required for %s label smoothing" % lsm_type
        )
        labelcount = np.zeros(odim)
        for k, v in trans_json.items():
            ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
            # to avoid an error when there is no text in an utterance
            if len(ids) > 0:
                labelcount[ids] += 1
        # NOTE(review): len(transcript) is the length of the file-path string,
        # not the number of utterances — looks suspicious; presumably
        # len(trans_json) was intended. Confirm before changing.
        labelcount[odim - 1] = len(transcript)  # count <eos>
        labelcount[labelcount == 0] = 1  # flooring
        labelcount[blank] = 0  # remove counts for blank
        labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
    else:
        logging.error("Error: unexpected label smoothing type: %s" % lsm_type)
        sys.exit()

    return labeldist
def get_vgg2l_odim(idim, in_channel=3, out_channel=128):
    """Return the output size of the VGG frontend.

    :param idim: total input feature size (freq bins x in_channel)
    :param in_channel: input channel size
    :param out_channel: output channel size
    :return: output size
    :rtype int
    """
    freq = idim / in_channel
    # Two max-pooling layers, each halving the frequency axis (ceil).
    for _ in range(2):
        freq = np.ceil(np.array(freq, dtype=np.float32) / 2)
    # Flattened size = remaining frequency bins times the number of channels.
    return int(freq) * out_channel
class ErrorCalculator(object):
    """Calculate CER and WER for E2E_ASR and CTC models during training.

    :param y_hats: numpy array with predicted text
    :param y_pads: numpy array with true (target) text
    :param char_list: vocabulary list (index -> token string)
    :param sym_space: space symbol (token string)
    :param sym_blank: blank symbol (token string)
    :return:
    """

    def __init__(
        self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
    ):
        """Construct an ErrorCalculator object.

        ``report_cer`` / ``report_wer`` control whether the non-CTC
        __call__ path computes each metric; when both are False it
        returns (None, None) without any decoding work.
        """
        super(ErrorCalculator, self).__init__()
        self.report_cer = report_cer
        self.report_wer = report_wer
        self.char_list = char_list
        self.space = sym_space
        self.blank = sym_blank
        # NOTE (Shih-Lun): else case is for OpenAI Whisper ASR model,
        # which doesn't use <blank> token
        if self.blank in self.char_list:
            self.idx_blank = self.char_list.index(self.blank)
        else:
            self.idx_blank = None
        if self.space in self.char_list:
            self.idx_space = self.char_list.index(self.space)
        else:
            self.idx_space = None

    def __call__(self, ys_hat, ys_pad, is_ctc=False):
        """Calculate sentence-level WER/CER score.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :param bool is_ctc: calculate CER score for CTC
        :return: sentence-level WER score
        :rtype float
        :return: sentence-level CER score
        :rtype float
        """
        cer, wer = None, None
        # NOTE: the CTC branch returns a single CER value (or None),
        # not the (cer, wer) pair of the attention branch below.
        if is_ctc:
            return self.calculate_cer_ctc(ys_hat, ys_pad)
        elif not self.report_cer and not self.report_wer:
            return cer, wer

        seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
        if self.report_cer:
            cer = self.calculate_cer(seqs_hat, seqs_true)

        if self.report_wer:
            wer = self.calculate_wer(seqs_hat, seqs_true)
        return cer, wer

    def calculate_cer_ctc(self, ys_hat, ys_pad):
        """Calculate sentence-level CER score for CTC.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: average sentence-level CER score
        :rtype float
        """
        import editdistance

        cers, char_ref_lens = [], []
        for i, y in enumerate(ys_hat):
            # collapse consecutive repeats (standard CTC decoding step)
            y_hat = [x[0] for x in groupby(y)]
            y_true = ys_pad[i]
            seq_hat, seq_true = [], []
            for idx in y_hat:
                idx = int(idx)
                # skip padding (-1), blank, and space indices
                if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                    seq_hat.append(self.char_list[int(idx)])

            for idx in y_true:
                idx = int(idx)
                if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
                    seq_true.append(self.char_list[int(idx)])

            hyp_chars = "".join(seq_hat)
            ref_chars = "".join(seq_true)
            # utterances with an empty reference are excluded from the average
            if len(ref_chars) > 0:
                cers.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

        # None when no utterance had a non-empty reference
        cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
        return cer_ctc

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index to character.

        :param torch.Tensor seqs_hat: prediction (batch, seqlen)
        :param torch.Tensor seqs_true: reference (batch, seqlen)
        :return: token list of prediction
        :rtype list
        :return: token list of reference
        :rtype list
        """
        seqs_hat, seqs_true = [], []
        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]
            # first padding position in the reference marks its true length
            eos_true = np.where(y_true == -1)[0]
            ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
            # NOTE: padding index (-1) in y_true is used to pad y_hat
            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            # map the space token to a literal " " and drop blank tokens
            seq_hat_text = "".join(seq_hat).replace(self.space, " ")
            seq_hat_text = seq_hat_text.replace(self.blank, "")
            seq_true_text = "".join(seq_true).replace(self.space, " ")
            seqs_hat.append(seq_hat_text)
            seqs_true.append(seq_true_text)
        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level CER score
        :rtype float
        """
        import editdistance

        char_eds, char_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            # spaces are ignored for character error rate
            hyp_chars = seq_hat_text.replace(" ", "")
            ref_chars = seq_true_text.replace(" ", "")
            char_eds.append(editdistance.eval(hyp_chars, ref_chars))
            char_ref_lens.append(len(ref_chars))
        # NOTE(review): raises ZeroDivisionError if every reference is
        # empty — presumably callers guarantee non-empty references; verify
        return float(sum(char_eds)) / sum(char_ref_lens)

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score.

        :param list seqs_hat: prediction
        :param list seqs_true: reference
        :return: average sentence-level WER score
        :rtype float
        """
        import editdistance

        word_eds, word_ref_lens = [], []
        for i, seq_hat_text in enumerate(seqs_hat):
            seq_true_text = seqs_true[i]
            # edit distance over whitespace-separated words
            hyp_words = seq_hat_text.split()
            ref_words = seq_true_text.split()
            word_eds.append(editdistance.eval(hyp_words, ref_words))
            word_ref_lens.append(len(ref_words))
        return float(sum(word_eds)) / sum(word_ref_lens)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/e2e_mt_common.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Common functions for ST and MT."""
import
nltk
import
numpy
as
np
class ErrorCalculator(object):
    """Calculate BLEU for ST and MT models during training.

    :param y_hats: numpy array with predicted text
    :param y_pads: numpy array with true (target) text
    :param char_list: vocabulary list
    :param sym_space: space symbol
    :param sym_pad: pad symbol
    :param report_bleu: report BLEU score if True
    """

    def __init__(self, char_list, sym_space, sym_pad, report_bleu=False):
        """Construct an ErrorCalculator object."""
        super().__init__()
        self.char_list = char_list
        self.space = sym_space
        self.pad = sym_pad
        self.report_bleu = report_bleu
        # cache the index of the space token, if the vocabulary has one
        self.idx_space = (
            self.char_list.index(self.space) if self.space in self.char_list else None
        )

    def __call__(self, ys_hat, ys_pad):
        """Calculate corpus-level BLEU score.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: corpus-level BLEU score in a mini-batch (None when disabled)
        :rtype float
        """
        if not self.report_bleu:
            return None
        return self.calculate_corpus_bleu(ys_hat, ys_pad)

    def calculate_corpus_bleu(self, ys_hat, ys_pad):
        """Calculate corpus-level BLEU score in a mini-batch.

        :param torch.Tensor seqs_hat: prediction (batch, seqlen)
        :param torch.Tensor seqs_true: reference (batch, seqlen)
        :return: corpus-level BLEU score (in percent)
        :rtype float
        """
        hyp_texts, ref_texts = [], []
        for i, hyp_ids in enumerate(ys_hat):
            ref_ids = ys_pad[i]
            # the first padding position (-1) marks the reference length
            pad_pos = np.where(ref_ids == -1)[0]
            ref_len = pad_pos[0] if len(pad_pos) > 0 else len(ref_ids)
            # NOTE: padding index (-1) in the reference is used to cut the
            # hypothesis because the hypothesis batch is not padded with -1
            hyp_tokens = [self.char_list[int(t)] for t in hyp_ids[:ref_len]]
            ref_tokens = [self.char_list[int(t)] for t in ref_ids if int(t) != -1]
            # render to text: space token -> " ", pad token stripped from hyp
            hyp_text = "".join(hyp_tokens).replace(self.space, " ").replace(self.pad, "")
            ref_text = "".join(ref_tokens).replace(self.space, " ")
            hyp_texts.append(hyp_text)
            ref_texts.append(ref_text)
        score = nltk.bleu_score.corpus_bleu([[r] for r in ref_texts], hyp_texts)
        return score * 100
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/lm_interface.py
0 → 100644
View file @
60a2c57a
"""Language model interface."""
import
argparse
from
espnet.nets.scorer_interface
import
ScorerInterface
from
espnet.utils.dynamic_import
import
dynamic_import
from
espnet.utils.fill_missing_args
import
fill_missing_args
class LMInterface(ScorerInterface):
    """LM Interface for ESPnet model implementation."""

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser."""
        return parser

    @classmethod
    def build(cls, n_vocab: int, **kwargs):
        """Initialize this class with python-level args.

        Args:
            n_vocab (int): The number of vocabulary.

        Returns:
            LMInterface: A new instance of LMInterface.
        """
        # local import to avoid cyclic import in lm_train
        from espnet.bin.lm_train import get_parser

        args = argparse.Namespace(**kwargs)
        # fill options absent from kwargs: first from the lm_train CLI
        # defaults, then from this class' own argument defaults
        args = fill_missing_args(args, lambda parser: get_parser(parser, required=False))
        args = fill_missing_args(args, cls.add_arguments)
        return cls(n_vocab, args)

    def forward(self, x, t):
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)
        """
        raise NotImplementedError("forward method is not implemented")
# Registry of shorthand LM names per backend, consumed by dynamic_import_lm.
# Values are "module.path:ClassName" strings resolvable by dynamic_import.
predefined_lms = {
    "pytorch": {
        "default": "espnet.nets.pytorch_backend.lm.default:DefaultRNNLM",
        "seq_rnn": "espnet.nets.pytorch_backend.lm.seq_rnn:SequentialRNNLM",
        "transformer": "espnet.nets.pytorch_backend.lm.transformer:TransformerLM",
    },
    "chainer": {"default": "espnet.lm.chainer_backend.lm:DefaultRNNLM"},
}
def dynamic_import_lm(module, backend):
    """Import LM class dynamically.

    Args:
        module (str): module_name:class_name or alias in `predefined_lms`
        backend (str): NN backend. e.g., pytorch, chainer

    Returns:
        type: LM class
    """
    # resolve aliases for the chosen backend; unknown backends fall back
    # to an empty alias table so only fully-qualified names work
    aliases = predefined_lms.get(backend, dict())
    model_class = dynamic_import(module, aliases)
    assert issubclass(
        model_class, LMInterface
    ), f"{module} does not implement LMInterface"
    return model_class
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/mt_interface.py
0 → 100644
View file @
60a2c57a
"""MT Interface module."""
import
argparse
from
espnet.bin.asr_train
import
get_parser
from
espnet.utils.fill_missing_args
import
fill_missing_args
class MTInterface:
    """MT Interface for ESPnet model implementation."""

    @staticmethod
    def add_arguments(parser):
        """Add arguments to parser."""
        return parser

    @classmethod
    def build(cls, idim: int, odim: int, **kwargs):
        """Initialize this class with python-level args.

        Args:
            idim (int): The number of an input feature dim.
            odim (int): The number of output vocab.

        Returns:
            MTInterface: A new instance of the model class.
        """
        args = argparse.Namespace(**kwargs)
        # fill options absent from kwargs: first from the asr_train CLI
        # defaults, then from this class' own argument defaults
        args = fill_missing_args(args, lambda parser: get_parser(parser, required=False))
        args = fill_missing_args(args, cls.add_arguments)
        return cls(idim, odim, args)

    def forward(self, xs, ilens, ys):
        """Compute loss for training.

        :param xs:
            For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim)
            For chainer, list of source sequences chainer.Variable
        :param ilens: batch of lengths of source sequences (B)
            For pytorch, torch.Tensor
            For chainer, list of int
        :param ys:
            For pytorch, batch of padded source sequences torch.Tensor (B, Lmax)
            For chainer, list of source sequences chainer.Variable
        :return: loss value
        :rtype: torch.Tensor for pytorch, chainer.Variable for chainer
        """
        raise NotImplementedError("forward method is not implemented")

    def translate(self, x, trans_args, char_list=None, rnnlm=None):
        """Translate x for evaluation.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("translate method is not implemented")

    def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        :param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("Batch decoding is not supported yet.")

    def calculate_all_attentions(self, xs, ilens, ys):
        """Calculate attention.

        :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
        :param ndarray ilens: batch of lengths of input sequences (B)
        :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
        :return: attention weights (B, Lmax, Tmax)
        :rtype: float ndarray
        """
        raise NotImplementedError("calculate_all_attentions method is not implemented")

    @property
    def attention_plot_class(self):
        """Get attention plot class."""
        from espnet.asr.asr_utils import PlotAttentionReport

        return PlotAttentionReport
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/argument.py
0 → 100644
View file @
60a2c57a
# Copyright 2020 Hirofumi Inaguma
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Conformer common arguments."""
import
logging
from
distutils.util
import
strtobool
def add_arguments_conformer_common(group):
    """Add Conformer common arguments to an argparse argument group.

    :param group: argparse argument group to extend
    :return: the same group, with the Conformer options registered
    """
    group.add_argument(
        "--transformer-encoder-pos-enc-layer-type",
        type=str,
        default="abs_pos",
        choices=["abs_pos", "scaled_abs_pos", "rel_pos"],
        help="Transformer encoder positional encoding layer type",
    )
    group.add_argument(
        "--transformer-encoder-activation-type",
        type=str,
        default="swish",
        choices=["relu", "hardtanh", "selu", "swish"],
        help="Transformer encoder activation function type",
    )
    group.add_argument(
        "--macaron-style",
        default=False,
        type=strtobool,
        help="Whether to use macaron style for positionwise layer",
    )
    # Attention
    group.add_argument(
        "--zero-triu",
        default=False,
        type=strtobool,
        # typo fix: "uppper" -> "upper"
        help="If true, zero the upper triangular part of attention matrix.",
    )
    # Relative positional encoding
    group.add_argument(
        "--rel-pos-type",
        type=str,
        default="legacy",
        choices=["legacy", "latest"],
        help="Whether to use the latest relative positional encoding or the legacy one."
        "The legacy relative positional encoding will be deprecated in the future."
        "More Details can be found in https://github.com/espnet/espnet/pull/2816.",
    )
    # CNN module
    group.add_argument(
        "--use-cnn-module",
        default=False,
        type=strtobool,
        help="Use convolution module or not",
    )
    group.add_argument(
        "--cnn-module-kernel",
        default=31,
        type=int,
        help="Kernel size of convolution module.",
    )
    return group
def verify_rel_pos_type(args):
    """Verify the relative positional encoding type for compatibility.

    Args:
        args (Namespace): original arguments

    Returns:
        args (Namespace): modified arguments
    """
    rel_pos_type = getattr(args, "rel_pos_type", None)
    # missing or "legacy" rel_pos_type: rewrite "rel_*" layer types to the
    # legacy implementations for backward compatibility
    if rel_pos_type in (None, "legacy"):
        if args.transformer_encoder_pos_enc_layer_type == "rel_pos":
            args.transformer_encoder_pos_enc_layer_type = "legacy_rel_pos"
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        if args.transformer_encoder_selfattn_layer_type == "rel_selfattn":
            args.transformer_encoder_selfattn_layer_type = "legacy_rel_selfattn"
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
    return args
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/contextual_block_encoder_layer.py
0 → 100644
View file @
60a2c57a
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 21 16:57:31 2021.
@author: Keqi Deng (UCAS)
"""
import
torch
from
torch
import
nn
from
espnet.nets.pytorch_backend.transformer.layer_norm
import
LayerNorm
class ContextualBlockEncoderLayer(nn.Module):
    """Contextual Block Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvlutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        total_layer_num (int): Total number of layers
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        total_layer_num,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an EncoderLayer object."""
        super(ContextualBlockEncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            # macaron style halves both feed-forward residual contributions
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        self.total_layer_num = total_layer_num
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)

    def forward(
        self,
        x,
        mask,
        infer_mode=False,
        past_ctx=None,
        next_ctx=None,
        is_short_segment=False,
        layer_idx=0,
        cache=None,
    ):
        """Calculate forward propagation.

        Dispatches to the training path unless the module is in eval mode
        AND infer_mode is requested; both paths return a 7-tuple so the
        layer can be chained with the same signature.
        """
        if self.training or not infer_mode:
            return self.forward_train(x, mask, past_ctx, next_ctx, layer_idx, cache)
        else:
            return self.forward_infer(
                x, mask, past_ctx, next_ctx, is_short_segment, layer_idx, cache
            )

    def forward_train(
        self, x, mask, past_ctx=None, next_ctx=None, layer_idx=0, cache=None
    ):
        """Compute encoded features.

        Args:
            x_input (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            past_ctx (torch.Tensor): Previous contextual vector
            next_ctx (torch.Tensor): Next contextual vector
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
            cur_ctx (torch.Tensor): Current contextual vector
            next_ctx (torch.Tensor): Next contextual vector
            layer_idx (int): layer index number
        """
        nbatch = x.size(0)
        nblock = x.size(1)

        if past_ctx is not None:
            if next_ctx is None:
                # store all context vectors in one tensor
                next_ctx = past_ctx.new_zeros(
                    nbatch, nblock, self.total_layer_num, x.size(-1)
                )
            else:
                # overwrite each block's first frame with this layer's
                # context vector from the previous pass
                x[:, :, 0] = past_ctx[:, :, layer_idx]

        # reshape ( nbatch, nblock, block_size + 2, dim )
        # -> ( nbatch * nblock, block_size + 2, dim )
        x = x.view(-1, x.size(-2), x.size(-1))
        if mask is not None:
            mask = mask.view(-1, mask.size(-2), mask.size(-1))

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            # with a cache, only the newest frame is used as the query
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
        if not self.normalize_before:
            x = self.norm1(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        layer_idx += 1

        # reshape ( nbatch * nblock, block_size + 2, dim )
        # -> ( nbatch, nblock, block_size + 2, dim )
        # NOTE(review): squeeze(1) drops the block axis when nblock == 1,
        # while the ctx update below indexes x as 4-D — presumably the
        # update is only reached with next_ctx tensors shaped to match;
        # confirm against the calling encoder
        x = x.view(nbatch, -1, x.size(-2), x.size(-1)).squeeze(1)
        if mask is not None:
            mask = mask.view(nbatch, -1, mask.size(-2), mask.size(-1)).squeeze(1)

        if next_ctx is not None and layer_idx < self.total_layer_num:
            # context for block 0 comes from block 0's own last frame;
            # blocks 1.. take the previous block's last frame
            next_ctx[:, 0, layer_idx, :] = x[:, 0, -1, :]
            next_ctx[:, 1:, layer_idx, :] = x[:, 0:-1, -1, :]

        return x, mask, False, next_ctx, next_ctx, False, layer_idx

    def forward_infer(
        self,
        x,
        mask,
        past_ctx=None,
        next_ctx=None,
        is_short_segment=False,
        layer_idx=0,
        cache=None,
    ):
        """Compute encoded features.

        Args:
            x_input (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            past_ctx (torch.Tensor): Previous contextual vector
            next_ctx (torch.Tensor): Next contextual vector
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, 1, time).
            cur_ctx (torch.Tensor): Current contextual vector
            next_ctx (torch.Tensor): Next contextual vector
            layer_idx (int): layer index number
        """
        nbatch = x.size(0)
        nblock = x.size(1)

        # if layer_idx == 0, next_ctx has to be None
        if layer_idx == 0:
            assert next_ctx is None
            # one context vector per layer, filled as layers run
            next_ctx = x.new_zeros(nbatch, self.total_layer_num, x.size(-1))

        # reshape ( nbatch, nblock, block_size + 2, dim )
        # -> ( nbatch * nblock, block_size + 2, dim )
        x = x.view(-1, x.size(-2), x.size(-1))
        if mask is not None:
            mask = mask.view(-1, mask.size(-2), mask.size(-1))

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            # with a cache, only the newest frame is used as the query
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
        if not self.normalize_before:
            x = self.norm1(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        # reshape ( nbatch * nblock, block_size + 2, dim )
        # -> ( nbatch, nblock, block_size + 2, dim )
        x = x.view(nbatch, nblock, x.size(-2), x.size(-1))
        if mask is not None:
            mask = mask.view(nbatch, nblock, mask.size(-2), mask.size(-1))

        # Propagate context information (the last frame of each block)
        # to the first frame
        # of the next block
        if not is_short_segment:
            if past_ctx is None:
                # First block of an utterance
                x[:, 0, 0, :] = x[:, 0, -1, :]
            else:
                x[:, 0, 0, :] = past_ctx[:, layer_idx, :]
            if nblock > 1:
                x[:, 1:, 0, :] = x[:, 0:-1, -1, :]
            # remember this layer's last context frame for the next segment
            next_ctx[:, layer_idx, :] = x[:, -1, -1, :]
        else:
            # short segments carry no context across segments
            next_ctx = None

        return x, mask, True, past_ctx, next_ctx, is_short_segment, layer_idx + 1
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/convolution.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""ConvolutionModule definition."""
from
torch
import
nn
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    Pointwise conv + GLU, depthwise conv, batch norm, activation,
    and a final pointwise projection, applied along the time axis.

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of the depthwise conv layer
            (must be odd to keep the sequence length unchanged).
    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # an odd kernel is required for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = nn.Conv1d(
            channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels, channels, kernel_size=1, stride=1, padding=0, bias=bias
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # work in (batch, channels, time) layout for Conv1d
        out = x.transpose(1, 2)
        # pointwise expansion to 2*channels, then GLU gates back to channels
        out = nn.functional.glu(self.pointwise_conv1(out), dim=1)
        # per-channel temporal convolution
        out = self.depthwise_conv(out)
        out = self.activation(self.norm(out))
        out = self.pointwise_conv2(out)
        return out.transpose(1, 2)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/encoder.py
0 → 100644
View file @
60a2c57a
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder definition."""
import
logging
import
torch
from
espnet.nets.pytorch_backend.conformer.convolution
import
ConvolutionModule
from
espnet.nets.pytorch_backend.conformer.encoder_layer
import
EncoderLayer
from
espnet.nets.pytorch_backend.nets_utils
import
get_activation
from
espnet.nets.pytorch_backend.transducer.vgg2l
import
VGG2L
from
espnet.nets.pytorch_backend.transformer.attention
import
(
LegacyRelPositionMultiHeadedAttention
,
MultiHeadedAttention
,
RelPositionMultiHeadedAttention
,
)
from
espnet.nets.pytorch_backend.transformer.embedding
import
(
LegacyRelPositionalEncoding
,
PositionalEncoding
,
RelPositionalEncoding
,
ScaledPositionalEncoding
,
)
from
espnet.nets.pytorch_backend.transformer.layer_norm
import
LayerNorm
from
espnet.nets.pytorch_backend.transformer.multi_layer_conv
import
(
Conv1dLinear
,
MultiLayeredConv1d
,
)
from
espnet.nets.pytorch_backend.transformer.positionwise_feed_forward
import
(
PositionwiseFeedForward
,
)
from
espnet.nets.pytorch_backend.transformer.repeat
import
repeat
from
espnet.nets.pytorch_backend.transformer.subsampling
import
Conv2dSubsampling
class Encoder(torch.nn.Module):
    """Conformer encoder module.

    Args:
        idim (int): Input dimension.
        attention_dim (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        dropout_rate (float): Dropout rate.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        attention_dropout_rate (float): Dropout rate in attention.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        pos_enc_layer_type (str): Encoder positional encoding layer type.
        selfattention_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
        intermediate_layers (Union[List[int], None]): indices of intermediate CTC layer.
            indices start from 1.
            if not None, intermediate outputs are returned (which changes return type
            signature.)
        ctc_softmax (Union[callable, None]): Callable applied to intermediate encoder
            outputs; enables self-conditioning when not None.
        conditioning_layer_dim (Union[int, None]): Input dimension of the linear layer
            that projects ``ctc_softmax`` outputs back to ``attention_dim``.
    """

    def __init__(
        self,
        idim,
        attention_dim=256,
        attention_heads=4,
        linear_units=2048,
        num_blocks=6,
        dropout_rate=0.1,
        positional_dropout_rate=0.1,
        attention_dropout_rate=0.0,
        input_layer="conv2d",
        normalize_before=True,
        concat_after=False,
        positionwise_layer_type="linear",
        positionwise_conv_kernel_size=1,
        macaron_style=False,
        pos_enc_layer_type="abs_pos",
        selfattention_layer_type="selfattn",
        activation_type="swish",
        use_cnn_module=False,
        zero_triu=False,
        cnn_module_kernel=31,
        padding_idx=-1,
        stochastic_depth_rate=0.0,
        intermediate_layers=None,
        ctc_softmax=None,
        conditioning_layer_dim=None,
    ):
        """Construct an Encoder object."""
        super(Encoder, self).__init__()

        activation = get_activation(activation_type)

        # Positional encoding class selection; relative variants must pair
        # with the matching self-attention layer type.
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            pos_enc_class = LegacyRelPositionalEncoding
            assert selfattention_layer_type == "legacy_rel_selfattn"
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        # Input embedding / subsampling front-end. Conv front-ends subsample
        # time by 4; the factor is recorded for callers.
        self.conv_subsampling_factor = 1
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(idim, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                idim,
                attention_dim,
                dropout_rate,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
            self.conv_subsampling_factor = 4
        elif input_layer == "vgg2l":
            self.embed = VGG2L(idim, attention_dim)
            self.conv_subsampling_factor = 4
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(attention_dim, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

        # self-attention module definition
        if selfattention_layer_type == "selfattn":
            logging.info("encoder self-attention layer type = self-attention")
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "rel_selfattn":
            logging.info("encoder self-attention layer type = relative self-attention")
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                attention_dim,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

        # feed-forward module definition
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                attention_dim,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)

        # Stochastic depth probability grows linearly with layer index, so
        # deeper layers are skipped more often during training.
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                attention_dim,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
                stochastic_depth_rate * float(1 + lnum) / num_blocks,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)

        self.intermediate_layers = intermediate_layers
        # Self-conditioning (intermediate CTC feedback) is enabled by passing
        # a ctc_softmax callable together with conditioning_layer_dim.
        self.use_conditioning = True if ctc_softmax is not None else False
        if self.use_conditioning:
            self.ctc_softmax = ctc_softmax
            self.conditioning_layer = torch.nn.Linear(
                conditioning_layer_dim, attention_dim
            )

    def forward(self, xs, masks):
        """Encode input sequence.

        Args:
            xs (torch.Tensor): Input tensor (#batch, time, idim).
            masks (torch.Tensor): Mask tensor (#batch, 1, time).

        Returns:
            torch.Tensor: Output tensor (#batch, time, attention_dim).
            torch.Tensor: Mask tensor (#batch, 1, time).
        """
        # Conv front-ends also subsample the mask, so they take (xs, masks).
        if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
            xs, masks = self.embed(xs, masks)
        else:
            xs = self.embed(xs)

        if self.intermediate_layers is None:
            xs, masks = self.encoders(xs, masks)
        else:
            # Run layer by layer so intermediate outputs can be tapped.
            intermediate_outputs = []
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs, masks = encoder_layer(xs, masks)

                if (
                    self.intermediate_layers is not None
                    and layer_idx + 1 in self.intermediate_layers
                ):
                    # intermediate branches also require normalization.
                    encoder_output = xs
                    # With relative positional encoding, xs is (tensor, pos_emb).
                    if isinstance(encoder_output, tuple):
                        encoder_output = encoder_output[0]

                    if self.normalize_before:
                        encoder_output = self.after_norm(encoder_output)

                    intermediate_outputs.append(encoder_output)

                    if self.use_conditioning:
                        # Feed intermediate CTC posteriors back into the stream.
                        intermediate_result = self.ctc_softmax(encoder_output)

                        if isinstance(xs, tuple):
                            x, pos_emb = xs[0], xs[1]
                            x = x + self.conditioning_layer(intermediate_result)
                            xs = (x, pos_emb)
                        else:
                            xs = xs + self.conditioning_layer(intermediate_result)

        if isinstance(xs, tuple):
            xs = xs[0]

        if self.normalize_before:
            xs = self.after_norm(xs)

        if self.intermediate_layers is not None:
            return xs, masks, intermediate_outputs
        return xs, masks
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/encoder_layer.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder self-attention layer definition."""
import
torch
from
torch
import
nn
from
espnet.nets.pytorch_backend.transformer.layer_norm
import
LayerNorm
class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvlutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FNN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            # Macaron-style halves both feed-forward residual contributions.
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, 1, time).
        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            # Entire layer is skipped; still honor the cache contract so the
            # output length matches the non-skipped path.
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask

        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
                self.feed_forward_macaron(x)
            )
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            # With a cache, only the newest frame is used as the query;
            # keys/values still cover the full sequence.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if pos_emb is not None:
            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
            self.feed_forward(x)
        )
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/conformer/swish.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Swish() activation function for Conformer."""
import
torch
class Swish(torch.nn.Module):
    """Swish activation module: ``swish(x) = x * sigmoid(x)``."""

    def forward(self, x):
        """Return the Swish activation of the input tensor."""
        gate = torch.sigmoid(x)
        return gate * x
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/ctc.py
0 → 100644
View file @
60a2c57a
import
logging
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
packaging.version
import
parse
as
V
from
espnet.nets.pytorch_backend.nets_utils
import
to_device
class CTC(torch.nn.Module):
    """CTC module

    :param int odim: dimension of outputs
    :param int eprojs: number of encoder projection units
    :param float dropout_rate: dropout rate (0.0 ~ 1.0)
    :param str ctc_type: builtin
    :param bool reduce: reduce the CTC loss into a scalar
    """

    def __init__(self, odim, eprojs, dropout_rate, ctc_type="builtin", reduce=True):
        super().__init__()
        self.dropout_rate = dropout_rate
        self.loss = None
        # Linear projection from encoder features to output token logits.
        self.ctc_lo = torch.nn.Linear(eprojs, odim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.probs = None  # for visualization

        # In case of Pytorch >= 1.7.0, CTC will be always builtin
        self.ctc_type = (
            ctc_type if V(torch.__version__) < V("1.7.0") else "builtin"
        )

        if ctc_type != self.ctc_type:
            logging.warning(f"CTC was set to {self.ctc_type} due to PyTorch version.")

        if self.ctc_type == "builtin":
            reduction_type = "sum" if reduce else "none"
            self.ctc_loss = torch.nn.CTCLoss(
                reduction=reduction_type, zero_infinity=True
            )
        elif self.ctc_type == "cudnnctc":
            reduction_type = "sum" if reduce else "none"
            self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)
        elif self.ctc_type == "gtnctc":
            # Imported lazily so gtn is only required when actually used.
            from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction

            self.ctc_loss = GTNCTCLossFunction.apply
        else:
            raise ValueError(
                'ctc_type must be "builtin" or "gtnctc": {}'.format(self.ctc_type)
            )

        # Label id used to mark padding positions in ys_pad.
        self.ignore_id = -1
        self.reduce = reduce

    def loss_fn(self, th_pred, th_target, th_ilen, th_olen):
        # Dispatch to the configured CTC loss implementation.
        if self.ctc_type in ["builtin", "cudnnctc"]:
            # torch.nn.CTCLoss expects log-probabilities.
            th_pred = th_pred.log_softmax(2)
            # Use the deterministic CuDNN implementation of CTC loss to avoid
            #  [issue#17798](https://github.com/pytorch/pytorch/issues/17798)
            with torch.backends.cudnn.flags(deterministic=True):
                loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen)
            # Batch-size average
            loss = loss / th_pred.size(1)
            return loss
        elif self.ctc_type == "gtnctc":
            # gtn takes plain Python lists of target ids.
            targets = [t.tolist() for t in th_target]
            log_probs = torch.nn.functional.log_softmax(th_pred, dim=2)
            return self.ctc_loss(log_probs, targets, th_ilen, 0, "none")
        else:
            raise NotImplementedError

    def forward(self, hs_pad, hlens, ys_pad):
        """CTC forward

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
        :param torch.Tensor ys_pad:
            batch of padded character id sequence tensor (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        """
        # TODO(kan-bayashi): need to make more smart way
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys

        # zero padding for hs
        ys_hat = self.ctc_lo(self.dropout(hs_pad))
        if self.ctc_type != "gtnctc":
            # (B, Tmax, odim) -> (Tmax, B, odim) as expected by CTCLoss.
            ys_hat = ys_hat.transpose(0, 1)

        if self.ctc_type == "builtin":
            olens = to_device(ys_hat, torch.LongTensor([len(s) for s in ys]))
            hlens = hlens.long()
            ys_pad = torch.cat(ys)  # without this the code breaks for asr_mix
            self.loss = self.loss_fn(ys_hat, ys_pad, hlens, olens)
        else:
            self.loss = None
            hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32))
            olens = torch.from_numpy(
                np.fromiter((x.size(0) for x in ys), dtype=np.int32)
            )
            # zero padding for ys
            ys_true = torch.cat(ys).cpu().int()  # batch x olen
            # get ctc loss
            # expected shape of seqLength x batchSize x alphabet_size
            dtype = ys_hat.dtype
            if self.ctc_type == "cudnnctc":
                # use GPU when using the cuDNN implementation
                ys_true = to_device(hs_pad, ys_true)
            if self.ctc_type == "gtnctc":
                # keep as list for gtn
                ys_true = ys
            self.loss = to_device(
                hs_pad, self.loss_fn(ys_hat, ys_true, hlens, olens)
            ).to(dtype=dtype)

        # get length info
        logging.info(
            self.__class__.__name__
            + " input lengths: "
            + "".join(str(hlens).split("\n"))
        )
        logging.info(
            self.__class__.__name__
            + " output lengths: "
            + "".join(str(olens).split("\n"))
        )

        if self.reduce:
            self.loss = self.loss.sum()
            logging.info("ctc loss:" + str(float(self.loss)))

        return self.loss

    def softmax(self, hs_pad):
        """softmax of frame activations

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: log softmax applied 3d tensor (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        # Cached in self.probs for later visualization.
        self.probs = F.softmax(self.ctc_lo(hs_pad), dim=2)
        return self.probs

    def log_softmax(self, hs_pad):
        """log_softmax of frame activations

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: log softmax applied 3d tensor (B, Tmax, odim)
        :rtype: torch.Tensor
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad):
        """argmax of frame activations

        :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
        :return: argmax applied 2d tensor (B, Tmax)
        :rtype: torch.Tensor
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)

    def forced_align(self, h, y, blank_id=0):
        """forced alignment.

        Viterbi alignment over the blank-interpolated label sequence.

        :param torch.Tensor h: hidden state sequence, 2d tensor (T, D)
        :param torch.Tensor y: id sequence tensor 1d tensor (L)
        :param int y: blank symbol index
        :return: best alignment results
        :rtype: list
        """

        def interpolate_blank(label, blank_id=0):
            """Insert blank token between every two label token."""
            label = np.expand_dims(label, 1)
            blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
            label = np.concatenate([blanks, label], axis=1)
            label = label.reshape(-1)
            label = np.append(label, label[0])
            return label

        lpz = self.log_softmax(h)
        lpz = lpz.squeeze(0)

        y_int = interpolate_blank(y, blank_id)

        logdelta = np.zeros((lpz.size(0), len(y_int))) - 100000000000.0  # log of zero
        state_path = (
            np.zeros((lpz.size(0), len(y_int)), dtype=np.int16) - 1
        )  # state path

        # Initialization: alignment may start at the leading blank or the
        # first real label.
        logdelta[0, 0] = lpz[0][y_int[0]]
        logdelta[0, 1] = lpz[0][y_int[1]]

        for t in range(1, lpz.size(0)):
            for s in range(len(y_int)):
                # Blanks and repeated labels cannot be entered via a skip
                # transition, so only same/previous states are candidates.
                if y_int[s] == blank_id or s < 2 or y_int[s] == y_int[s - 2]:
                    candidates = np.array(
                        [logdelta[t - 1, s], logdelta[t - 1, s - 1]]
                    )
                    prev_state = [s, s - 1]
                else:
                    candidates = np.array(
                        [
                            logdelta[t - 1, s],
                            logdelta[t - 1, s - 1],
                            logdelta[t - 1, s - 2],
                        ]
                    )
                    prev_state = [s, s - 1, s - 2]
                logdelta[t, s] = np.max(candidates) + lpz[t][y_int[s]]
                state_path[t, s] = prev_state[np.argmax(candidates)]

        # Backtrack from the best of the two valid terminal states.
        state_seq = -1 * np.ones((lpz.size(0), 1), dtype=np.int16)

        candidates = np.array(
            [logdelta[-1, len(y_int) - 1], logdelta[-1, len(y_int) - 2]]
        )
        prev_state = [len(y_int) - 1, len(y_int) - 2]
        state_seq[-1] = prev_state[np.argmax(candidates)]
        for t in range(lpz.size(0) - 2, -1, -1):
            state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]

        output_state_seq = []
        for t in range(0, lpz.size(0)):
            output_state_seq.append(y_int[state_seq[t, 0]])

        return output_state_seq
def ctc_for(args, odim, reduce=True):
    """Returns the CTC module for the given args and output dimension

    :param Namespace args: the program args
    :param int odim : The output dimension
    :param bool reduce : return the CTC loss in a scalar
    :return: the corresponding CTC module
    """
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    if num_encs == 1:
        # compatible with single encoder asr mode
        return CTC(
            odim, args.eprojs, args.dropout_rate, ctc_type=args.ctc_type, reduce=reduce
        )
    if num_encs < 1:
        raise ValueError(
            "Number of encoders needs to be more than one. {}".format(num_encs)
        )

    # Multi-encoder mode: build one CTC per encoder, or a single shared one.
    modules = torch.nn.ModuleList()
    if args.share_ctc:
        # use dropout_rate of the first encoder
        rates = [args.dropout_rate[0]]
    else:
        rates = [args.dropout_rate[idx] for idx in range(num_encs)]
    for rate in rates:
        modules.append(
            CTC(odim, args.eprojs, rate, ctc_type=args.ctc_type, reduce=reduce)
        )
    return modules
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/e2e_asr.py
0 → 100644
View file @
60a2c57a
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""RNN sequence-to-sequence speech recognition model (pytorch)."""
import
argparse
import
logging
import
math
import
os
from
itertools
import
groupby
import
chainer
import
numpy
as
np
import
torch
from
chainer
import
reporter
from
espnet.nets.asr_interface
import
ASRInterface
from
espnet.nets.e2e_asr_common
import
label_smoothing_dist
from
espnet.nets.pytorch_backend.ctc
import
ctc_for
from
espnet.nets.pytorch_backend.frontends.feature_transform
import
(
# noqa: H301
feature_transform_for
,
)
from
espnet.nets.pytorch_backend.frontends.frontend
import
frontend_for
from
espnet.nets.pytorch_backend.initialization
import
(
lecun_normal_init_parameters
,
set_forget_bias_to_one
,
)
from
espnet.nets.pytorch_backend.nets_utils
import
(
get_subsample
,
pad_list
,
to_device
,
to_torch_tensor
,
)
from
espnet.nets.pytorch_backend.rnn.argument
import
(
# noqa: H301
add_arguments_rnn_attention_common
,
add_arguments_rnn_decoder_common
,
add_arguments_rnn_encoder_common
,
)
from
espnet.nets.pytorch_backend.rnn.attentions
import
att_for
from
espnet.nets.pytorch_backend.rnn.decoders
import
decoder_for
from
espnet.nets.pytorch_backend.rnn.encoders
import
encoder_for
from
espnet.nets.scorers.ctc
import
CTCPrefixScorer
from
espnet.utils.fill_missing_args
import
fill_missing_args
CTC_LOSS_THRESHOLD
=
10000
class Reporter(chainer.Chain):
    """A chainer reporter wrapper."""

    def report(self, loss_ctc, loss_att, acc, cer_ctc, cer, wer, mtl_loss):
        """Report at every step."""
        # Per-metric reports, in a fixed order.
        metrics = (
            ("loss_ctc", loss_ctc),
            ("loss_att", loss_att),
            ("acc", acc),
            ("cer_ctc", cer_ctc),
            ("cer", cer),
            ("wer", wer),
        )
        for key, value in metrics:
            reporter.report({key: value}, self)
        logging.info("mtl loss:" + str(mtl_loss))
        reporter.report({"loss": mtl_loss}, self)
class E2E(ASRInterface, torch.nn.Module):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        E2E.encoder_add_arguments(parser)
        E2E.attention_add_arguments(parser)
        E2E.decoder_add_arguments(parser)
        return parser

    @staticmethod
    def encoder_add_arguments(parser):
        """Add arguments for the encoder."""
        group = parser.add_argument_group("E2E encoder setting")
        group = add_arguments_rnn_encoder_common(group)
        return parser

    @staticmethod
    def attention_add_arguments(parser):
        """Add arguments for the attention."""
        group = parser.add_argument_group("E2E attention setting")
        group = add_arguments_rnn_attention_common(group)
        return parser

    @staticmethod
    def decoder_add_arguments(parser):
        """Add arguments for the decoder."""
        group = parser.add_argument_group("E2E decoder setting")
        group = add_arguments_rnn_decoder_common(group)
        return parser

    def get_total_subsampling_factor(self):
        """Get total subsampling factor."""
        if isinstance(self.enc, torch.nn.ModuleList):
            # multi-encoder case: use the first encoder's conv subsampling
            return self.enc[0].conv_subsampling_factor * int(np.prod(self.subsample))
        else:
            return self.enc.conv_subsampling_factor * int(np.prod(self.subsample))

    def __init__(self, idim, odim, args):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super(E2E, self).__init__()
        torch.nn.Module.__init__(self)

        # fill missing arguments for compatibility
        args = fill_missing_args(args, self.add_arguments)

        self.mtlalpha = args.mtlalpha
        assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
        self.etype = args.etype
        self.verbose = args.verbose
        # NOTE: for self.build method
        args.char_list = getattr(args, "char_list", None)
        self.char_list = args.char_list
        self.outdir = args.outdir
        self.space = args.sym_space
        self.blank = args.sym_blank
        self.reporter = Reporter()

        # below means the last number becomes eos/sos ID
        # note that sos/eos IDs are identical
        self.sos = odim - 1
        self.eos = odim - 1

        # subsample info
        self.subsample = get_subsample(args, mode="asr", arch="rnn")

        # label smoothing info
        if args.lsm_type and os.path.isfile(args.train_json):
            logging.info("Use label smoothing with " + args.lsm_type)
            labeldist = label_smoothing_dist(
                odim, args.lsm_type, transcript=args.train_json
            )
        else:
            labeldist = None

        if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
            self.frontend = frontend_for(args, idim)
            self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
            # the frontend changes the feature dimension seen by the encoder
            idim = args.n_mels
        else:
            self.frontend = None

        # encoder
        self.enc = encoder_for(args, idim, self.subsample)
        # ctc
        self.ctc = ctc_for(args, odim)
        # attention
        self.att = att_for(args)
        # decoder
        self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

        # weight initialization
        self.init_like_chainer()

        # options for beam search
        if args.report_cer or args.report_wer:
            recog_args = {
                "beam_size": args.beam_size,
                "penalty": args.penalty,
                "ctc_weight": args.ctc_weight,
                "maxlenratio": args.maxlenratio,
                "minlenratio": args.minlenratio,
                "lm_weight": args.lm_weight,
                "rnnlm": args.rnnlm,
                "nbest": args.nbest,
                "space": args.sym_space,
                "blank": args.sym_blank,
            }

            self.recog_args = argparse.Namespace(**recog_args)
            self.report_cer = args.report_cer
            self.report_wer = args.report_wer
        else:
            self.report_cer = False
            self.report_wer = False
        self.rnnlm = None

        self.logzero = -10000000000.0
        self.loss = None
        self.acc = None

    def init_like_chainer(self):
        """Initialize weight like chainer.

        chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
        pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
        however, there are two exceptions as far as I know.
        - EmbedID.W ~ Normal(0, 1)
        - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
        """
        lecun_normal_init_parameters(self)
        # exceptions
        # embed weight ~ Normal(0, 1)
        self.dec.embed.weight.data.normal_(0, 1)
        # forget-bias = 1.0
        # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
        for i in range(len(self.dec.decoder)):
            set_forget_bias_to_one(self.dec.decoder[i].bias_ih)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: loss value
        :rtype: torch.Tensor
        """
        # imported lazily so the package works without editdistance installed
        # unless forward is actually called
        import editdistance

        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)

        # 2. CTC loss (skipped entirely when mtlalpha == 0, i.e. attention-only)
        if self.mtlalpha == 0:
            self.loss_ctc = None
        else:
            self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad)

        # 3. attention loss (skipped when mtlalpha == 1, i.e. CTC-only)
        if self.mtlalpha == 1:
            self.loss_att, acc = None, None
        else:
            self.loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
        self.acc = acc

        # 4. compute cer without beam search (greedy CTC argmax path)
        if self.mtlalpha == 0 or self.char_list is None:
            cer_ctc = None
        else:
            cers = []

            y_hats = self.ctc.argmax(hs_pad).data
            for i, y in enumerate(y_hats):
                # collapse repeated CTC frames; -1 entries are padding
                y_hat = [x[0] for x in groupby(y)]
                y_true = ys_pad[i]

                seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(self.space, " ")
                seq_hat_text = seq_hat_text.replace(self.blank, "")
                seq_true_text = "".join(seq_true).replace(self.space, " ")

                hyp_chars = seq_hat_text.replace(" ", "")
                ref_chars = seq_true_text.replace(" ", "")
                if len(ref_chars) > 0:
                    cers.append(
                        editdistance.eval(hyp_chars, ref_chars) / len(ref_chars)
                    )

            cer_ctc = sum(cers) / len(cers) if cers else None

        # 5. compute cer/wer (beam-search based; only in eval mode when enabled)
        if self.training or not (self.report_cer or self.report_wer):
            cer, wer = 0.0, 0.0
            # oracle_cer, oracle_wer = 0.0, 0.0
        else:
            if self.recog_args.ctc_weight > 0.0:
                lpz = self.ctc.log_softmax(hs_pad).data
            else:
                lpz = None

            word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], []
            nbest_hyps = self.dec.recognize_beam_batch(
                hs_pad,
                torch.tensor(hlens),
                lpz,
                self.recog_args,
                self.char_list,
                self.rnnlm,
            )
            # remove <sos> and <eos>
            y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps]
            for i, y_hat in enumerate(y_hats):
                y_true = ys_pad[i]

                seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ")
                seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
                seq_true_text = "".join(seq_true).replace(self.recog_args.space, " ")

                hyp_words = seq_hat_text.split()
                ref_words = seq_true_text.split()
                word_eds.append(editdistance.eval(hyp_words, ref_words))
                word_ref_lens.append(len(ref_words))
                hyp_chars = seq_hat_text.replace(" ", "")
                ref_chars = seq_true_text.replace(" ", "")
                char_eds.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

            wer = (
                0.0
                if not self.report_wer
                else float(sum(word_eds)) / sum(word_ref_lens)
            )
            cer = (
                0.0
                if not self.report_cer
                else float(sum(char_eds)) / sum(char_ref_lens)
            )

        # combine CTC and attention losses with interpolation weight alpha
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = self.loss_att
            loss_att_data = float(self.loss_att)
            loss_ctc_data = None
        elif alpha == 1:
            self.loss = self.loss_ctc
            loss_att_data = None
            loss_ctc_data = float(self.loss_ctc)
        else:
            self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_att
            loss_att_data = float(self.loss_att)
            loss_ctc_data = float(self.loss_ctc)

        loss_data = float(self.loss)
        # only report finite, non-divergent losses
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss

    def scorers(self):
        """Scorers."""
        return dict(decoder=self.dec, ctc=CTCPrefixScorer(self.ctc, self.eos))

    def encode(self, x):
        """Encode acoustic features.

        NOTE: puts the module into eval mode as a side effect.

        :param ndarray x: input acoustic feature (T, D)
        :return: encoder outputs
        :rtype: torch.Tensor
        """
        self.eval()
        ilens = [x.shape[0]]

        # subsample frame
        x = x[:: self.subsample[0], :]
        p = next(self.parameters())
        h = torch.as_tensor(x, device=p.device, dtype=p.dtype)
        # make a utt list (1) to use the same interface for encoder
        hs = h.contiguous().unsqueeze(0)

        # 0. Frontend
        if self.frontend is not None:
            enhanced, hlens, mask = self.frontend(hs, ilens)
            hs, hlens = self.feature_transform(enhanced, hlens)
        else:
            hs, hlens = hs, ilens

        # 1. encoder
        hs, _, _ = self.enc(hs, hlens)
        return hs.squeeze(0)

    def recognize(self, x, recog_args, char_list, rnnlm=None):
        """E2E beam search.

        :param ndarray x: input acoustic feature (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        hs = self.encode(x).unsqueeze(0)
        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(hs)[0]
        else:
            lpz = None

        # 2. Decoder
        # decode the first utterance
        y = self.dec.recognize_beam(hs[0], lpz, recog_args, char_list, rnnlm)
        return y

    def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
        """E2E batch beam search.

        :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...]
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frame
        xs = [xx[:: self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)

        # 0. Frontend
        if self.frontend is not None:
            enhanced, hlens, mask = self.frontend(xs_pad, ilens)
            hs_pad, hlens = self.feature_transform(enhanced, hlens)
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)

        # calculate log P(z_t|X) for CTC scores
        if recog_args.ctc_weight > 0.0:
            lpz = self.ctc.log_softmax(hs_pad)
            normalize_score = False
        else:
            lpz = None
            normalize_score = True

        # 2. Decoder
        hlens = torch.tensor(list(map(int, hlens)))  # make sure hlens is tensor
        y = self.dec.recognize_beam_batch(
            hs_pad,
            hlens,
            lpz,
            recog_args,
            char_list,
            rnnlm,
            normalize_score=normalize_score,
        )

        # restore training mode if we were training before
        if prev:
            self.train()
        return y

    def enhance(self, xs):
        """Forward only in the frontend stage.

        :param ndarray xs: input acoustic feature (T, C, F)
        :return: enhanced feature
        :rtype: torch.Tensor
        """
        if self.frontend is None:
            raise RuntimeError("Frontend does't exist")
        prev = self.training
        self.eval()
        ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

        # subsample frame
        xs = [xx[:: self.subsample[0], :] for xx in xs]
        xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
        xs_pad = pad_list(xs, 0.0)
        enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
        if prev:
            self.train()
        return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens

    def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
        """E2E attention calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: attention weights with the following shape,
            1) multi-head case => attention weights (B, H, Lmax, Tmax),
            2) other case => attention weights (B, Lmax, Tmax).
        :rtype: float ndarray
        """
        self.eval()
        with torch.no_grad():
            # 0. Frontend
            if self.frontend is not None:
                hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
            else:
                hs_pad, hlens = xs_pad, ilens

            # 1. Encoder
            hpad, hlens, _ = self.enc(hs_pad, hlens)

            # 2. Decoder
            att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad)
        self.train()
        return att_ws

    def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad):
        """E2E CTC probability calculation.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax)
        :return: CTC probability (B, Tmax, vocab)
        :rtype: float ndarray
        """
        probs = None
        # no CTC branch when mtlalpha == 0
        if self.mtlalpha == 0:
            return probs

        self.eval()
        with torch.no_grad():
            # 0. Frontend
            if self.frontend is not None:
                hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
                hs_pad, hlens = self.feature_transform(hs_pad, hlens)
            else:
                hs_pad, hlens = xs_pad, ilens

            # 1. Encoder
            hpad, hlens, _ = self.enc(hs_pad, hlens)

            # 2. CTC probs
            probs = self.ctc.softmax(hpad).cpu().numpy()
        self.train()
        return probs

    def subsample_frames(self, x):
        """Subsample speech frames in the encoder."""
        # subsample frame
        x = x[:: self.subsample[0], :]
        ilen = [x.shape[0]]
        h = to_device(self, torch.from_numpy(np.array(x, dtype=np.float32)))
        h.contiguous()
        return h, ilen
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/e2e_asr_conformer.py
0 → 100644
View file @
60a2c57a
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
Conformer speech recognition model (pytorch).
It is a fusion of `e2e_asr_transformer.py`
Refer to: https://arxiv.org/abs/2005.08100
"""
from
espnet.nets.pytorch_backend.conformer.argument
import
(
# noqa: H301
add_arguments_conformer_common
,
verify_rel_pos_type
,
)
from
espnet.nets.pytorch_backend.conformer.encoder
import
Encoder
from
espnet.nets.pytorch_backend.e2e_asr_transformer
import
E2E
as
E2ETransformer
class E2E(E2ETransformer):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        E2ETransformer.add_arguments(parser)
        E2E.add_conformer_arguments(parser)
        return parser

    @staticmethod
    def add_conformer_arguments(parser):
        """Add arguments for conformer model."""
        group = parser.add_argument_group("conformer model specific setting")
        group = add_arguments_conformer_common(group)
        return parser

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        super().__init__(idim, odim, args, ignore_id)
        if args.transformer_attn_dropout_rate is None:
            args.transformer_attn_dropout_rate = args.dropout_rate

        # Check the relative positional encoding type
        args = verify_rel_pos_type(args)

        # Replace the transformer encoder built by the parent constructor with
        # a conformer encoder.
        # NOTE(review): self.intermediate_ctc_layers and self.ctc are presumably
        # set by E2ETransformer.__init__ — not visible here, confirm in
        # e2e_asr_transformer.py.
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            pos_enc_layer_type=args.transformer_encoder_pos_enc_layer_type,
            selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
            activation_type=args.transformer_encoder_activation_type,
            macaron_style=args.macaron_style,
            use_cnn_module=args.use_cnn_module,
            zero_triu=args.zero_triu,
            cnn_module_kernel=args.cnn_module_kernel,
            stochastic_depth_rate=args.stochastic_depth_rate,
            intermediate_layers=self.intermediate_ctc_layers,
            ctc_softmax=self.ctc.softmax if args.self_conditioning else None,
            conditioning_layer_dim=odim,
        )
        # re-initialize weights for the newly created encoder
        self.reset_parameters(args)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/e2e_asr_maskctc.py
0 → 100644
View file @
60a2c57a
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Waseda University (Yosuke Higuchi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
Mask CTC based non-autoregressive speech recognition model (pytorch).
See https://arxiv.org/abs/2005.08700 for the detail.
"""
import
logging
import
math
from
distutils.util
import
strtobool
from
itertools
import
groupby
import
numpy
import
torch
from
espnet.nets.pytorch_backend.conformer.argument
import
(
# noqa: H301
add_arguments_conformer_common
,
)
from
espnet.nets.pytorch_backend.conformer.encoder
import
Encoder
from
espnet.nets.pytorch_backend.e2e_asr
import
CTC_LOSS_THRESHOLD
from
espnet.nets.pytorch_backend.e2e_asr_transformer
import
E2E
as
E2ETransformer
from
espnet.nets.pytorch_backend.maskctc.add_mask_token
import
mask_uniform
from
espnet.nets.pytorch_backend.maskctc.mask
import
square_mask
from
espnet.nets.pytorch_backend.nets_utils
import
make_non_pad_mask
,
th_accuracy
class E2E(E2ETransformer):
    """E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments."""
        E2ETransformer.add_arguments(parser)
        E2E.add_maskctc_arguments(parser)
        return parser

    @staticmethod
    def add_maskctc_arguments(parser):
        """Add arguments for maskctc model."""
        group = parser.add_argument_group("maskctc specific setting")
        group.add_argument(
            "--maskctc-use-conformer-encoder",
            default=False,
            type=strtobool,
        )
        group = add_arguments_conformer_common(group)
        return parser

    def __init__(self, idim, odim, args, ignore_id=-1):
        """Construct an E2E object.

        :param int idim: dimension of inputs
        :param int odim: dimension of outputs
        :param Namespace args: argument Namespace containing options
        """
        odim += 1  # for the mask token

        super().__init__(idim, odim, args, ignore_id)
        assert 0.0 <= self.mtlalpha < 1.0, "mtlalpha should be [0.0, 1.0)"

        # the extra (last) token id is the mask; sos/eos shift down by one
        self.mask_token = odim - 1
        self.sos = odim - 2
        self.eos = odim - 2
        self.odim = odim

        self.intermediate_ctc_weight = args.intermediate_ctc_weight
        self.intermediate_ctc_layers = None
        if args.intermediate_ctc_layer != "":
            self.intermediate_ctc_layers = [
                int(i) for i in args.intermediate_ctc_layer.split(",")
            ]

        if args.maskctc_use_conformer_encoder:
            if args.transformer_attn_dropout_rate is None:
                args.transformer_attn_dropout_rate = args.conformer_dropout_rate
            # swap the parent's transformer encoder for a conformer encoder
            self.encoder = Encoder(
                idim=idim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.eunits,
                num_blocks=args.elayers,
                input_layer=args.transformer_input_layer,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                attention_dropout_rate=args.transformer_attn_dropout_rate,
                pos_enc_layer_type=args.transformer_encoder_pos_enc_layer_type,
                selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
                activation_type=args.transformer_encoder_activation_type,
                macaron_style=args.macaron_style,
                use_cnn_module=args.use_cnn_module,
                cnn_module_kernel=args.cnn_module_kernel,
                stochastic_depth_rate=args.stochastic_depth_rate,
                intermediate_layers=self.intermediate_ctc_layers,
            )
            self.reset_parameters(args)

    def forward(self, xs_pad, ilens, ys_pad):
        """E2E forward.

        :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of source sequences (B)
        :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
        :return: ctc loss value
        :rtype: torch.Tensor
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy in attention decoder
        :rtype: float
        """
        # 1. forward encoder
        xs_pad = xs_pad[:, : max(ilens)]  # for data parallel
        src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2)
        if self.intermediate_ctc_layers:
            hs_pad, hs_mask, hs_intermediates = self.encoder(xs_pad, src_mask)
        else:
            hs_pad, hs_mask = self.encoder(xs_pad, src_mask)
        self.hs_pad = hs_pad

        # 2. forward decoder (targets with tokens randomly replaced by the mask)
        ys_in_pad, ys_out_pad = mask_uniform(
            ys_pad, self.mask_token, self.eos, self.ignore_id
        )
        ys_mask = square_mask(ys_in_pad, self.eos)
        pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask)
        self.pred_pad = pred_pad

        # 3. compute attention loss
        loss_att = self.criterion(pred_pad, ys_out_pad)
        self.acc = th_accuracy(
            pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id
        )

        # 4. compute ctc loss
        loss_ctc, cer_ctc = None, None
        loss_intermediate_ctc = 0.0
        if self.mtlalpha > 0:
            batch_size = xs_pad.size(0)
            hs_len = hs_mask.view(batch_size, -1).sum(1)
            loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad)
            if self.error_calculator is not None:
                ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data
                cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
            # for visualization
            if not self.training:
                self.ctc.softmax(hs_pad)

            if self.intermediate_ctc_weight > 0 and self.intermediate_ctc_layers:
                for hs_intermediate in hs_intermediates:
                    # assuming hs_intermediates and hs_pad has same length / padding
                    loss_inter = self.ctc(
                        hs_intermediate.view(batch_size, -1, self.adim), hs_len, ys_pad
                    )
                    loss_intermediate_ctc += loss_inter

                loss_intermediate_ctc /= len(self.intermediate_ctc_layers)

        # 5. compute cer/wer
        if self.training or self.error_calculator is None or self.decoder is None:
            cer, wer = None, None
        else:
            ys_hat = pred_pad.argmax(dim=-1)
            cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        # combine losses; intermediate CTC takes its share out of attention
        alpha = self.mtlalpha
        if alpha == 0:
            self.loss = loss_att
            loss_att_data = float(loss_att)
            loss_ctc_data = None
        else:
            self.loss = (
                alpha * loss_ctc
                + self.intermediate_ctc_weight * loss_intermediate_ctc
                + (1 - alpha - self.intermediate_ctc_weight) * loss_att
            )
            loss_att_data = float(loss_att)
            loss_ctc_data = float(loss_ctc)

        loss_data = float(self.loss)
        if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
            self.reporter.report(
                loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
            )
        else:
            logging.warning("loss (=%f) is not correct", loss_data)
        return self.loss

    def recognize(self, x, recog_args, char_list=None, rnnlm=None):
        """Recognize input speech.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param Namespace recog_args: argument Namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: decoding result
        :rtype: list
        """

        def num2str(char_list, mask_token, mask_char="_"):
            # build a converter that renders masked positions as mask_char
            def f(yl):
                cl = [char_list[y] if y != mask_token else mask_char for y in yl]
                return "".join(cl).replace("<space>", " ")

            return f

        n2s = num2str(char_list, self.mask_token)

        self.eval()
        h = self.encode(x).unsqueeze(0)

        input_len = h.squeeze(0)
        logging.info("input lengths: " + str(input_len.size(0)))

        # greedy ctc outputs
        ctc_probs, ctc_ids = torch.exp(self.ctc.log_softmax(h)).max(dim=-1)
        y_hat = torch.stack([x[0] for x in groupby(ctc_ids[0])])
        y_idx = torch.nonzero(y_hat != 0).squeeze(-1)

        # calculate token-level ctc probabilities by taking
        # the maximum probability of consecutive frames with
        # the same ctc symbols
        probs_hat = []
        cnt = 0
        for i, y in enumerate(y_hat.tolist()):
            probs_hat.append(-1)
            while cnt < ctc_ids.shape[1] and y == ctc_ids[0][cnt]:
                if probs_hat[i] < ctc_probs[0][cnt]:
                    probs_hat[i] = ctc_probs[0][cnt].item()
                cnt += 1
        probs_hat = torch.from_numpy(numpy.array(probs_hat))

        # mask ctc outputs based on ctc probabilities
        p_thres = recog_args.maskctc_probability_threshold
        mask_idx = torch.nonzero(probs_hat[y_idx] < p_thres).squeeze(-1)
        confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1)
        mask_num = len(mask_idx)
        y_in = torch.zeros(1, len(y_idx), dtype=torch.long) + self.mask_token
        y_in[0][confident_idx] = y_hat[y_idx][confident_idx]

        logging.info("ctc:{}".format(n2s(y_in[0].tolist())))

        # iterative decoding: each round fills in the most confident
        # (mask_num // num_iter) masked positions
        if not mask_num == 0:
            K = recog_args.maskctc_n_iterations
            num_iter = K if mask_num >= K and K > 0 else mask_num

            for t in range(num_iter - 1):
                pred, _ = self.decoder(y_in, None, h, None)
                pred_score, pred_id = pred[0][mask_idx].max(dim=-1)
                cand = torch.topk(pred_score, mask_num // num_iter, -1)[1]
                y_in[0][mask_idx[cand]] = pred_id[cand]
                mask_idx = torch.nonzero(y_in[0] == self.mask_token).squeeze(-1)

                logging.info("msk:{}".format(n2s(y_in[0].tolist())))

            # predict leftover masks (|masks| < mask_num // num_iter)
            pred, pred_mask = self.decoder(y_in, None, h, None)
            y_in[0][mask_idx] = pred[0][mask_idx].argmax(dim=-1)

            logging.info("msk:{}".format(n2s(y_in[0].tolist())))

        ret = y_in.tolist()[0]
        hyp = {"score": 0.0, "yseq": [self.sos] + ret + [self.eos]}

        return [hyp]
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/e2e_asr_mix.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
"""
This script is used to construct End-to-End models of multi-speaker ASR.
Copyright 2017 Johns Hopkins University (Shinji Watanabe)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import
argparse
import
logging
import
math
import
os
import
sys
from
itertools
import
groupby
import
numpy
as
np
import
torch
from
espnet.nets.asr_interface
import
ASRInterface
from
espnet.nets.e2e_asr_common
import
get_vgg2l_odim
,
label_smoothing_dist
from
espnet.nets.pytorch_backend.ctc
import
ctc_for
from
espnet.nets.pytorch_backend.e2e_asr
import
E2E
as
E2EASR
from
espnet.nets.pytorch_backend.e2e_asr
import
Reporter
from
espnet.nets.pytorch_backend.frontends.feature_transform
import
(
# noqa: H301
feature_transform_for
,
)
from
espnet.nets.pytorch_backend.frontends.frontend
import
frontend_for
from
espnet.nets.pytorch_backend.initialization
import
(
lecun_normal_init_parameters
,
set_forget_bias_to_one
,
)
from
espnet.nets.pytorch_backend.nets_utils
import
(
get_subsample
,
make_pad_mask
,
pad_list
,
to_device
,
to_torch_tensor
,
)
from
espnet.nets.pytorch_backend.rnn.attentions
import
att_for
from
espnet.nets.pytorch_backend.rnn.decoders
import
decoder_for
from
espnet.nets.pytorch_backend.rnn.encoders
import
RNNP
,
VGG2L
from
espnet.nets.pytorch_backend.rnn.encoders
import
encoder_for
as
encoder_for_single
# Loss values at or above this threshold (or NaN) are considered divergent and
# are not reported to the training reporter.
CTC_LOSS_THRESHOLD = 10000
class PIT(object):
    """Permutation Invariant Training (PIT) module.

    :parameter int num_spkrs: number of speakers for PIT process (2 or 3)
    """

    def __init__(self, num_spkrs):
        """Initialize PIT module."""
        self.num_spkrs = num_spkrs

        # Enumerate all speaker orderings via DFS, e.g.
        # [[0, 1], [1, 0]] or
        # [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 1, 0], [2, 0, 1]]
        self.perm_choices = []
        base_order = np.linspace(0, num_spkrs - 1, num_spkrs, dtype=np.int64)
        self.permutationDFS(base_order, 0)

        # Flat indices into the (num_spkrs * num_spkrs) pairwise-loss vector
        # for each permutation, e.g.
        # [[0, 3], [1, 2]] or
        # [[0, 4, 8], [0, 5, 7], [1, 3, 8], [1, 5, 6], [2, 4, 6], [2, 3, 7]]
        row_offsets = np.linspace(
            0, num_spkrs * (num_spkrs - 1), num_spkrs, dtype=np.int64
        ).reshape(1, num_spkrs)
        self.loss_perm_idx = (row_offsets + np.array(self.perm_choices)).tolist()

    def min_pit_sample(self, loss):
        """Compute the PIT loss for each sample.

        :param 1-D torch.Tensor loss: list of losses for one sample,
            including [h1r1, h1r2, h2r1, h2r2] or
            [h1r1, h1r2, h1r3, h2r1, h2r2, h2r3, h3r1, h3r2, h3r3]
        :return minimum loss of best permutation
        :rtype torch.Tensor (1)
        :return the best permutation
        :rtype List: len=2
        """
        # Average hypothesis-reference loss of every candidate permutation.
        per_perm = [torch.sum(loss[idx]) for idx in self.loss_perm_idx]
        score_perms = torch.stack(per_perm) / self.num_spkrs

        # Pick the permutation with the smallest averaged loss.
        perm_loss, min_idx = torch.min(score_perms, 0)
        return perm_loss, self.perm_choices[min_idx]

    def pit_process(self, losses):
        """Compute the PIT loss for a batch.

        :param torch.Tensor losses: losses (B, 1|4|9)
        :return minimum losses of a batch with best permutation
        :rtype torch.Tensor (B)
        :return the best permutation
        :rtype torch.LongTensor (B, 1|2|3)
        """
        n_samples = losses.size(0)
        results = [self.min_pit_sample(losses[b]) for b in range(n_samples)]

        loss_perm = torch.stack([best for best, _ in results], dim=0).to(
            losses.device
        )  # (B)
        permutation = (
            torch.tensor([order for _, order in results]).long().to(losses.device)
        )

        return torch.mean(loss_perm), permutation

    def permutationDFS(self, source, start):
        """Get permutations with DFS.

        The final result is all permutations of the 'source' sequence.
        e.g. [[1, 2], [2, 1]] or
        [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 2, 1], [3, 1, 2]]

        :param np.ndarray source: (num_spkrs, 1), e.g. [1, 2, ..., N]
        :param int start: the start point to permute
        """
        if start == len(source) - 1:  # a complete permutation has been built
            self.perm_choices.append(source.tolist())
        for pos in range(start, len(source)):
            # place element `pos` at `start`, recurse, then undo the swap
            source[start], source[pos] = source[pos], source[start]
            self.permutationDFS(source, start + 1)
            source[start], source[pos] = source[pos], source[start]
class
E2E
(
ASRInterface
,
torch
.
nn
.
Module
):
"""E2E module.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
@
staticmethod
def
add_arguments
(
parser
):
"""Add arguments."""
E2EASR
.
encoder_add_arguments
(
parser
)
E2E
.
encoder_mix_add_arguments
(
parser
)
E2EASR
.
attention_add_arguments
(
parser
)
E2EASR
.
decoder_add_arguments
(
parser
)
return
parser
@staticmethod
def encoder_mix_add_arguments(parser):
    """Add arguments for the multi-speaker encoder.

    :param argparse.ArgumentParser parser: parser to extend
    :return: the same parser with the multi-speaker encoder group added
    """
    group = parser.add_argument_group("E2E encoder setting for multi-speaker")
    # asr-mix encoder
    group.add_argument(
        "--spa",
        action="store_true",
        help="Enable speaker parallel attention "
        "for multi-speaker speech recognition task.",
    )
    group.add_argument(
        "--elayers-sd",
        default=4,
        type=int,
        # NOTE: trailing space added so the implicitly-concatenated string
        # literals no longer render as "layersfor" in --help output.
        help="Number of speaker differentiate encoder layers "
        "for multi-speaker speech recognition task.",
    )
    return parser
def get_total_subsampling_factor(self):
    """Get total subsampling factor.

    Combines the convolutional front-end subsampling with the product of
    the per-layer frame-subsampling rates.
    """
    conv_factor = self.enc.conv_subsampling_factor
    frame_factor = int(np.prod(self.subsample))
    return conv_factor * frame_factor
def __init__(self, idim, odim, args):
    """Initialize multi-speaker E2E module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    super(E2E, self).__init__()
    torch.nn.Module.__init__(self)
    # multi-task weight: alpha * CTC + (1 - alpha) * attention (see forward)
    self.mtlalpha = args.mtlalpha
    assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
    self.etype = args.etype
    self.verbose = args.verbose
    # NOTE: for self.build method
    args.char_list = getattr(args, "char_list", None)
    self.char_list = args.char_list
    self.outdir = args.outdir
    self.space = args.sym_space
    self.blank = args.sym_blank
    self.reporter = Reporter()
    self.num_spkrs = args.num_spkrs
    # speaker parallel attention flag (one attention module per speaker)
    self.spa = args.spa
    # permutation-invariant training helper used for the CTC loss
    self.pit = PIT(self.num_spkrs)

    # below means the last number becomes eos/sos ID
    # note that sos/eos IDs are identical
    self.sos = odim - 1
    self.eos = odim - 1

    # subsample info
    self.subsample = get_subsample(args, mode="asr", arch="rnn_mix")

    # label smoothing info
    if args.lsm_type and os.path.isfile(args.train_json):
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(
            odim, args.lsm_type, transcript=args.train_json
        )
    else:
        labeldist = None

    if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
        self.frontend = frontend_for(args, idim)
        self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
        # frontend output feature dimension replaces the raw input dimension
        idim = args.n_mels
    else:
        self.frontend = None

    # encoder
    self.enc = encoder_for(args, idim, self.subsample)
    # ctc (reduce=False so per-sample losses survive for PIT permutation)
    self.ctc = ctc_for(args, odim, reduce=False)
    # attention: one module per speaker when speaker-parallel attention is on
    num_att = self.num_spkrs if args.spa else 1
    self.att = att_for(args, num_att)
    # decoder
    self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

    # weight initialization
    self.init_like_chainer()

    # options for beam search
    if "report_cer" in vars(args) and (args.report_cer or args.report_wer):
        recog_args = {
            "beam_size": args.beam_size,
            "penalty": args.penalty,
            "ctc_weight": args.ctc_weight,
            "maxlenratio": args.maxlenratio,
            "minlenratio": args.minlenratio,
            "lm_weight": args.lm_weight,
            "rnnlm": args.rnnlm,
            "nbest": args.nbest,
            "space": args.sym_space,
            "blank": args.sym_blank,
        }

        self.recog_args = argparse.Namespace(**recog_args)
        self.report_cer = args.report_cer
        self.report_wer = args.report_wer
    else:
        self.report_cer = False
        self.report_wer = False
    self.rnnlm = None

    self.logzero = -10000000000.0
    self.loss = None
    self.acc = None
def init_like_chainer(self):
    """Initialize weight like chainer.

    chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0
    pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5)
    however, there are two exceptions as far as I know.
    - EmbedID.W ~ Normal(0, 1)
    - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM)
    """
    lecun_normal_init_parameters(self)

    # Exception 1: embedding weights ~ Normal(0, 1)
    self.dec.embed.weight.data.normal_(0, 1)

    # Exception 2: forget-gate bias = 1.0 for every decoder LSTM cell
    # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745
    for lstm_cell in self.dec.decoder:
        set_forget_bias_to_one(lstm_cell.bias_ih)
def forward(self, xs_pad, ilens, ys_pad):
    """E2E forward.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad:
        batch of padded character id sequence tensor (B, num_spkrs, Lmax)
    :return: ctc loss value
    :rtype: torch.Tensor
    :return: attention loss value
    :rtype: torch.Tensor
    :return: accuracy in attention decoder
    :rtype: float
    """
    # imported lazily so the module loads without editdistance installed
    import editdistance

    # 0. Frontend
    if self.frontend is not None:
        hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
        if isinstance(hs_pad, list):
            # frontend separated the mixture into one stream per speaker
            hlens_n = [None] * self.num_spkrs
            for i in range(self.num_spkrs):
                hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
            hlens = hlens_n
        else:
            hs_pad, hlens = self.feature_transform(hs_pad, hlens)
    else:
        hs_pad, hlens = xs_pad, ilens

    # 1. Encoder
    if not isinstance(hs_pad, list):
        # single-channel input xs_pad (single- or multi-speaker)
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)
    else:
        # multi-channel multi-speaker input xs_pad
        for i in range(self.num_spkrs):
            hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

    # 2. CTC loss
    if self.mtlalpha == 0:
        loss_ctc, min_perm = None, None
    else:
        if not isinstance(hs_pad, list):
            # single-speaker input xs_pad
            loss_ctc = torch.mean(self.ctc(hs_pad, hlens, ys_pad))
        else:
            # multi-speaker input xs_pad: evaluate CTC for every
            # (hypothesis stream, reference) pairing, then let PIT pick
            # the permutation with the minimum total loss per sample.
            ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
            loss_ctc_perm = torch.stack(
                [
                    self.ctc(
                        hs_pad[i // self.num_spkrs],
                        hlens[i // self.num_spkrs],
                        ys_pad[i % self.num_spkrs],
                    )
                    for i in range(self.num_spkrs**2)
                ],
                dim=1,
            )  # (B, num_spkrs^2)
            loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm)
            logging.info("ctc loss:" + str(float(loss_ctc)))

    # 3. attention loss
    if self.mtlalpha == 1:
        loss_att = None
        acc = None
    else:
        if not isinstance(hs_pad, list):
            # single-speaker input xs_pad
            loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad)
        else:
            # reorder references per sample using the CTC-selected permutation
            # NOTE(review): min_perm comes from the CTC branch above, so this
            # path appears to assume mtlalpha > 0 — confirm against callers.
            for i in range(ys_pad.size(1)):  # B
                ys_pad[:, i] = ys_pad[min_perm[i], i]
            rslt = [
                self.dec(hs_pad[i], hlens[i], ys_pad[i], strm_idx=i)
                for i in range(self.num_spkrs)
            ]
            # average the per-speaker attention losses / accuracies
            loss_att = sum([r[0] for r in rslt]) / float(len(rslt))
            acc = sum([r[1] for r in rslt]) / float(len(rslt))
    self.acc = acc

    # 4. compute cer without beam search
    if self.mtlalpha == 0 or self.char_list is None:
        cer_ctc = None
    else:
        cers = []
        # NOTE(review): this branch indexes hs_pad[ns] / ys_pad[ns], which
        # assumes the multi-speaker (list) case — verify single-speaker use.
        for ns in range(self.num_spkrs):
            y_hats = self.ctc.argmax(hs_pad[ns]).data
            for i, y in enumerate(y_hats):
                # collapse repeated CTC frame labels to a label sequence
                y_hat = [x[0] for x in groupby(y)]
                y_true = ys_pad[ns][i]

                # -1 marks padding; drop it before mapping ids to characters
                seq_hat = [
                    self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                ]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(self.space, " ")
                seq_hat_text = seq_hat_text.replace(self.blank, "")
                seq_true_text = "".join(seq_true).replace(self.space, " ")

                hyp_chars = seq_hat_text.replace(" ", "")
                ref_chars = seq_true_text.replace(" ", "")
                if len(ref_chars) > 0:
                    cers.append(
                        editdistance.eval(hyp_chars, ref_chars) / len(ref_chars)
                    )

        cer_ctc = sum(cers) / len(cers) if cers else None

    # 5. compute cer/wer
    if (
        self.training
        or not (self.report_cer or self.report_wer)
        or not isinstance(hs_pad, list)
    ):
        cer, wer = 0.0, 0.0
    else:
        if self.recog_args.ctc_weight > 0.0:
            lpz = [
                self.ctc.log_softmax(hs_pad[i]).data
                for i in range(self.num_spkrs)
            ]
        else:
            lpz = None

        word_eds, char_eds, word_ref_lens, char_ref_lens = [], [], [], []
        nbest_hyps = [
            self.dec.recognize_beam_batch(
                hs_pad[i],
                torch.tensor(hlens[i]),
                lpz[i],
                self.recog_args,
                self.char_list,
                self.rnnlm,
                strm_idx=i,
            )
            for i in range(self.num_spkrs)
        ]
        # remove <sos> and <eos>
        y_hats = [
            [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps[i]]
            for i in range(self.num_spkrs)
        ]
        for i in range(len(y_hats[0])):
            hyp_words = []
            hyp_chars = []
            ref_words = []
            ref_chars = []

            for ns in range(self.num_spkrs):
                y_hat = y_hats[ns][i]
                y_true = ys_pad[ns][i]

                seq_hat = [
                    self.char_list[int(idx)] for idx in y_hat if int(idx) != -1
                ]
                seq_true = [
                    self.char_list[int(idx)] for idx in y_true if int(idx) != -1
                ]
                seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ")
                seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "")
                seq_true_text = "".join(seq_true).replace(
                    self.recog_args.space, " "
                )

                hyp_words.append(seq_hat_text.split())
                ref_words.append(seq_true_text.split())
                hyp_chars.append(seq_hat_text.replace(" ", ""))
                ref_chars.append(seq_true_text.replace(" ", ""))

            # edit distances for all hypothesis/reference pairings, so the
            # PIT helper can pick the permutation with the minimum distance
            tmp_word_ed = [
                editdistance.eval(
                    hyp_words[ns // self.num_spkrs],
                    ref_words[ns % self.num_spkrs],
                )
                for ns in range(self.num_spkrs**2)
            ]  # h1r1,h1r2,h2r1,h2r2
            tmp_char_ed = [
                editdistance.eval(
                    hyp_chars[ns // self.num_spkrs],
                    ref_chars[ns % self.num_spkrs],
                )
                for ns in range(self.num_spkrs**2)
            ]  # h1r1,h1r2,h2r1,h2r2

            word_eds.append(
                self.pit.min_pit_sample(torch.tensor(tmp_word_ed))[0]
            )
            word_ref_lens.append(len(sum(ref_words, [])))
            char_eds.append(
                self.pit.min_pit_sample(torch.tensor(tmp_char_ed))[0]
            )
            char_ref_lens.append(len("".join(ref_chars)))

        wer = (
            0.0
            if not self.report_wer
            else float(sum(word_eds)) / sum(word_ref_lens)
        )
        cer = (
            0.0
            if not self.report_cer
            else float(sum(char_eds)) / sum(char_ref_lens)
        )

    # combine the two losses with the multi-task weight
    alpha = self.mtlalpha
    if alpha == 0:
        self.loss = loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = None
    elif alpha == 1:
        self.loss = loss_ctc
        loss_att_data = None
        loss_ctc_data = float(loss_ctc)
    else:
        self.loss = alpha * loss_ctc + (1 - alpha) * loss_att
        loss_att_data = float(loss_att)
        loss_ctc_data = float(loss_ctc)

    loss_data = float(self.loss)
    if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data):
        self.reporter.report(
            loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data
        )
    else:
        logging.warning("loss (=%f) is not correct", loss_data)
    return self.loss
def recognize(self, x, recog_args, char_list, rnnlm=None):
    """E2E beam search.

    :param ndarray x: input acoustic feature (T, D)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    # switch to inference mode; the previous mode is restored at the end
    prev = self.training
    self.eval()
    ilens = [x.shape[0]]

    # subsample frame
    x = x[:: self.subsample[0], :]
    h = to_device(self, to_torch_tensor(x).float())
    # make a utt list (1) to use the same interface for encoder
    hs = h.contiguous().unsqueeze(0)

    # 0. Frontend
    if self.frontend is not None:
        hs, hlens, mask = self.frontend(hs, ilens)
        # frontend yields one stream per speaker; transform each
        hlens_n = [None] * self.num_spkrs
        for i in range(self.num_spkrs):
            hs[i], hlens_n[i] = self.feature_transform(hs[i], hlens)
        hlens = hlens_n
    else:
        hs, hlens = hs, ilens

    # 1. Encoder
    if not isinstance(hs, list):
        # single-channel multi-speaker input x
        hs, hlens, _ = self.enc(hs, hlens)
    else:
        # multi-channel multi-speaker input x
        for i in range(self.num_spkrs):
            hs[i], hlens[i], _ = self.enc(hs[i], hlens[i])

    # calculate log P(z_t|X) for CTC scores
    if recog_args.ctc_weight > 0.0:
        lpz = [self.ctc.log_softmax(i)[0] for i in hs]
    else:
        lpz = None

    # 2. decoder
    # decode the first utterance
    y = [
        self.dec.recognize_beam(
            hs[i][0], lpz[i], recog_args, char_list, rnnlm, strm_idx=i
        )
        for i in range(self.num_spkrs)
    ]

    if prev:
        self.train()
    return y
def recognize_batch(self, xs, recog_args, char_list, rnnlm=None):
    """E2E beam search.

    :param ndarray xs: input acoustic feature (T, D)
    :param Namespace recog_args: argument Namespace containing options
    :param list char_list: list of characters
    :param torch.nn.Module rnnlm: language model module
    :return: N-best decoding results
    :rtype: list
    """
    # switch to inference mode; the previous mode is restored at the end
    prev = self.training
    self.eval()
    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

    # subsample frame
    xs = [xx[:: self.subsample[0], :] for xx in xs]
    xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)

    # 0. Frontend
    if self.frontend is not None:
        hs_pad, hlens, mask = self.frontend(xs_pad, ilens)
        # frontend yields one stream per speaker; transform each
        hlens_n = [None] * self.num_spkrs
        for i in range(self.num_spkrs):
            hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
        hlens = hlens_n
    else:
        hs_pad, hlens = xs_pad, ilens

    # 1. Encoder
    if not isinstance(hs_pad, list):
        # single-channel multi-speaker input x
        hs_pad, hlens, _ = self.enc(hs_pad, hlens)
    else:
        # multi-channel multi-speaker input x
        for i in range(self.num_spkrs):
            hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

    # calculate log P(z_t|X) for CTC scores
    if recog_args.ctc_weight > 0.0:
        lpz = [self.ctc.log_softmax(hs_pad[i]) for i in range(self.num_spkrs)]
        normalize_score = False
    else:
        lpz = None
        normalize_score = True

    # 2. decoder
    y = [
        self.dec.recognize_beam_batch(
            hs_pad[i],
            hlens[i],
            lpz[i],
            recog_args,
            char_list,
            rnnlm,
            normalize_score=normalize_score,
            strm_idx=i,
        )
        for i in range(self.num_spkrs)
    ]

    if prev:
        self.train()
    return y
def enhance(self, xs):
    """Forward only the frontend stage.

    :param ndarray xs: input acoustic feature (T, C, F)
    :return: enhanced features, estimated masks, and input lengths
        (lists when the frontend separates multiple speakers)
    :raises RuntimeError: if no frontend was configured at construction
    """
    if self.frontend is None:
        raise RuntimeError("Frontend doesn't exist")
    # switch to inference mode; the previous mode is restored below
    prev = self.training
    self.eval()
    ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64)

    # subsample frame
    xs = [xx[:: self.subsample[0], :] for xx in xs]
    xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs]
    xs_pad = pad_list(xs, 0.0)
    enhanced, hlensm, mask = self.frontend(xs_pad, ilens)
    if prev:
        self.train()

    if isinstance(enhanced, (tuple, list)):
        # multi-speaker output: convert every per-speaker tensor to numpy
        enhanced = list(enhanced)
        mask = list(mask)
        for idx in range(len(enhanced)):  # number of speakers
            enhanced[idx] = enhanced[idx].cpu().numpy()
            mask[idx] = mask[idx].cpu().numpy()
        return enhanced, mask, ilens
    return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens
def calculate_all_attentions(self, xs_pad, ilens, ys_pad):
    """E2E attention calculation.

    :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
    :param torch.Tensor ilens: batch of lengths of input sequences (B)
    :param torch.Tensor ys_pad:
        batch of padded character id sequence tensor (B, num_spkrs, Lmax)
    :return: attention weights with the following shape,
        1) multi-head case => attention weights (B, H, Lmax, Tmax),
        2) other case => attention weights (B, Lmax, Tmax).
    :rtype: float ndarray
    """
    with torch.no_grad():
        # 0. Frontend
        if self.frontend is not None:
            hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens)
            # frontend yields one stream per speaker; transform each
            hlens_n = [None] * self.num_spkrs
            for i in range(self.num_spkrs):
                hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens)
            hlens = hlens_n
        else:
            hs_pad, hlens = xs_pad, ilens

        # 1. Encoder
        if not isinstance(hs_pad, list):
            # single-channel multi-speaker input x
            hs_pad, hlens, _ = self.enc(hs_pad, hlens)
        else:
            # multi-channel multi-speaker input x
            for i in range(self.num_spkrs):
                hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i])

        # Permutation: use PIT over CTC losses to align hypothesis streams
        # with references before computing attention weights.
        ys_pad = ys_pad.transpose(0, 1)  # (num_spkrs, B, Lmax)
        if self.num_spkrs <= 3:
            loss_ctc = torch.stack(
                [
                    self.ctc(
                        hs_pad[i // self.num_spkrs],
                        hlens[i // self.num_spkrs],
                        ys_pad[i % self.num_spkrs],
                    )
                    for i in range(self.num_spkrs**2)
                ],
                1,
            )  # (B, num_spkrs^2)
            loss_ctc, min_perm = self.pit.pit_process(loss_ctc)
        # NOTE(review): min_perm is only assigned in the num_spkrs <= 3
        # branch above; this loop appears to assume that case — confirm.
        for i in range(ys_pad.size(1)):  # B
            ys_pad[:, i] = ys_pad[min_perm[i], i]

        # 2. Decoder
        att_ws = [
            self.dec.calculate_all_attentions(
                hs_pad[i], hlens[i], ys_pad[i], strm_idx=i
            )
            for i in range(self.num_spkrs)
        ]

    return att_ws
class EncoderMix(torch.nn.Module):
    """Encoder module for the case of multi-speaker mixture speech.

    :param str etype: type of encoder network
    :param int idim: number of dimensions of encoder network
    :param int elayers_sd:
        number of layers of speaker differentiate part in encoder network
    :param int elayers_rec:
        number of layers of shared recognition part in encoder network
    :param int eunits: number of lstm units of encoder network
    :param int eprojs: number of projection units of encoder network
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param int in_channel: number of input channels
    :param int num_spkrs: number of number of speakers
    """

    def __init__(
        self,
        etype,
        idim,
        elayers_sd,
        elayers_rec,
        eunits,
        eprojs,
        subsample,
        dropout,
        num_spkrs=2,
        in_channel=1,
    ):
        """Initialize the encoder of single-channel multi-speaker ASR."""
        super(EncoderMix, self).__init__()
        # strip the "vgg" prefix / projection suffix to get the RNN type
        typ = etype.lstrip("vgg").rstrip("p")
        if typ not in ["lstm", "gru", "blstm", "bgru"]:
            logging.error(
                "Error: need to specify an appropriate encoder architecture"
            )

        # only the "vgg...p" family (VGG front-end + projection RNN) is
        # supported; anything else logs an error and exits below
        if etype.startswith("vgg"):
            if etype[-1] == "p":
                # shared mixture encoder (CNN front-end)
                self.enc_mix = torch.nn.ModuleList([VGG2L(in_channel)])
                # one speaker-differentiating RNN stack per speaker
                self.enc_sd = torch.nn.ModuleList(
                    [
                        torch.nn.ModuleList(
                            [
                                RNNP(
                                    get_vgg2l_odim(idim, in_channel=in_channel),
                                    elayers_sd,
                                    eunits,
                                    eprojs,
                                    subsample[: elayers_sd + 1],
                                    dropout,
                                    typ=typ,
                                )
                            ]
                        )
                        for i in range(num_spkrs)
                    ]
                )
                # recognition encoder shared across speakers; note the
                # subsample list is split between the SD and Rec parts
                self.enc_rec = torch.nn.ModuleList(
                    [
                        RNNP(
                            eprojs,
                            elayers_rec,
                            eunits,
                            eprojs,
                            subsample[elayers_sd:],
                            dropout,
                            typ=typ,
                        )
                    ]
                )
                logging.info("Use CNN-VGG + B" + typ.upper() + "P for encoder")
            else:
                logging.error(
                    f"Error: need to specify an appropriate encoder architecture. "
                    f"Illegal name {etype}"
                )
                sys.exit()
        else:
            logging.error(
                f"Error: need to specify an appropriate encoder architecture. "
                f"Illegal name {etype}"
            )
            sys.exit()

        self.num_spkrs = num_spkrs

    def forward(self, xs_pad, ilens):
        """Encodermix forward.

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: list: batch of hidden state sequences [num_spkrs x (B, Tmax, eprojs)]
        :rtype: torch.Tensor
        """
        # mixture encoder
        for module in self.enc_mix:
            xs_pad, ilens, _ = module(xs_pad, ilens)

        # SD and Rec encoder: fan the shared mixture features out into one
        # stream per speaker, then encode each stream
        xs_pad_sd = [xs_pad for i in range(self.num_spkrs)]
        ilens_sd = [ilens for i in range(self.num_spkrs)]
        for ns in range(self.num_spkrs):
            # Encoder_SD: speaker differentiate encoder
            for module in self.enc_sd[ns]:
                xs_pad_sd[ns], ilens_sd[ns], _ = module(
                    xs_pad_sd[ns], ilens_sd[ns]
                )
            # Encoder_Rec: recognition encoder
            for module in self.enc_rec:
                xs_pad_sd[ns], ilens_sd[ns], _ = module(
                    xs_pad_sd[ns], ilens_sd[ns]
                )

        # make mask to remove bias value in padded part
        mask = to_device(xs_pad, make_pad_mask(ilens_sd[0]).unsqueeze(-1))

        return [x.masked_fill(mask, 0.0) for x in xs_pad_sd], ilens_sd, None
def encoder_for(args, idim, subsample):
    """Construct the encoder.

    Returns a single-speaker encoder when a frontend is in use (the
    frontend already separates the mixture into per-speaker streams),
    otherwise a multi-speaker ``EncoderMix``.
    """
    # use getattr to keep compatibility
    # with frontend, the mixed speech are separated as streams for each speaker
    if getattr(args, "use_frontend", False):
        return encoder_for_single(args, idim, subsample)
    return EncoderMix(
        args.etype,
        idim,
        args.elayers_sd,
        args.elayers,
        args.eunits,
        args.eprojs,
        subsample,
        args.dropout_rate,
        args.num_spkrs,
    )
Prev
1
2
3
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment