Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
sunzhq2
yidong-infer
Commits
60a2c57a
Commit
60a2c57a
authored
Jan 27, 2026
by
sunzhq2
Committed by
xuxo
Jan 27, 2026
Browse files
update conformer
parent
4a699441
Changes
216
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4767 additions
and
0 deletions
+4767
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/lm/lm_utils.py
.../espnet-v.202304_20240621/build/lib/espnet/lm/lm_utils.py
+292
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/lm/pytorch_backend/__init__.py
..._20240621/build/lib/espnet/lm/pytorch_backend/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/lm/pytorch_backend/extlm.py
...304_20240621/build/lib/espnet/lm/pytorch_backend/extlm.py
+218
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/lm/pytorch_backend/lm.py
...202304_20240621/build/lib/espnet/lm/pytorch_backend/lm.py
+407
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/mt/__init__.py
.../espnet-v.202304_20240621/build/lib/espnet/mt/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/mt/mt_utils.py
.../espnet-v.202304_20240621/build/lib/espnet/mt/mt_utils.py
+83
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/mt/pytorch_backend/__init__.py
..._20240621/build/lib/espnet/mt/pytorch_backend/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/mt/pytorch_backend/mt.py
...202304_20240621/build/lib/espnet/mt/pytorch_backend/mt.py
+593
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/__init__.py
...spnet-v.202304_20240621/build/lib/espnet/nets/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/asr_interface.py
...-v.202304_20240621/build/lib/espnet/nets/asr_interface.py
+172
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/batch_beam_search.py
...02304_20240621/build/lib/espnet/nets/batch_beam_search.py
+353
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/batch_beam_search_online.py
...0240621/build/lib/espnet/nets/batch_beam_search_online.py
+309
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/batch_beam_search_online_sim.py
...621/build/lib/espnet/nets/batch_beam_search_online_sim.py
+279
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/beam_search.py
...et-v.202304_20240621/build/lib/espnet/nets/beam_search.py
+541
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/beam_search_timesync.py
...04_20240621/build/lib/espnet/nets/beam_search_timesync.py
+239
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/beam_search_transducer.py
..._20240621/build/lib/espnet/nets/beam_search_transducer.py
+896
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/chainer_backend/__init__.py
...0240621/build/lib/espnet/nets/chainer_backend/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/chainer_backend/asr_interface.py
...21/build/lib/espnet/nets/chainer_backend/asr_interface.py
+29
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/chainer_backend/ctc.py
...304_20240621/build/lib/espnet/nets/chainer_backend/ctc.py
+103
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/chainer_backend/deterministic_embed_id.py
...lib/espnet/nets/chainer_backend/deterministic_embed_id.py
+248
-0
No files found.
Too many changes to show.
To preserve performance only
216 of 216+
files are displayed.
Plain diff
Email patch
conformer/espnet-v.202304_20240621/build/lib/espnet/lm/lm_utils.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# This code is ported from the following implementation written in Torch.
# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
import
logging
import
os
import
random
import
chainer
import
h5py
import
numpy
as
np
from
chainer.training
import
extension
from
tqdm
import
tqdm
def load_dataset(path, label_dict, outdir=None):
    """Load an LM text dataset, optionally caching it as HDF5.

    When ``outdir`` is given, a previously dumped HDF5 cache is loaded if it
    exists; otherwise the text file is tokenized and the result is dumped so
    the next call can skip tokenization.

    Args:
        path (str): The path of an input text dataset file.
        label_dict (dict[str, int]): Mapping from token label string to ID.
        outdir (str): Output directory for the HDF5 cache (no caching if None).

    Returns:
        tuple[list[np.ndarray], int, int]: token ID sequences (np.int32 via
        `read_tokens`), the token count and the OOV count via `count_tokens`.
    """
    filename = None
    if outdir is not None:
        os.makedirs(outdir, exist_ok=True)
        # cache file lives next to the other outputs, named after the input
        filename = outdir + "/" + os.path.basename(path) + ".h5"
        if os.path.exists(filename):
            logging.info(f"loading binary dataset: {filename}")
            f = h5py.File(filename, "r")
            return f["data"][:], f["n_tokens"][()], f["n_oovs"][()]
    else:
        logging.info("skip dump/load HDF5 because the output dir is not specified")
    logging.info(f"reading text dataset: {path}")
    sequences = read_tokens(path, label_dict)
    n_tokens, n_oovs = count_tokens(sequences, label_dict["<unk>"])
    if outdir is not None:
        logging.info(f"saving binary dataset: {filename}")
        with h5py.File(filename, "w") as f:
            # http://docs.h5py.org/en/stable/special.html#arbitrary-vlen-data
            dset = f.create_dataset(
                "data", (len(sequences),), dtype=h5py.special_dtype(vlen=np.int32)
            )
            dset[:] = sequences
            f["n_tokens"] = n_tokens
            f["n_oovs"] = n_oovs
    return sequences, n_tokens, n_oovs
def read_tokens(filename, label_dict):
    """Read tokens as a sequence of sentences.

    :param str filename : The name of the input file
    :param dict label_dict : dictionary that maps token label string to its ID number
    :return list of ID sequences (one np.int32 array per line)
    :rtype list
    """
    data = []
    unk = label_dict["<unk>"]
    # Use a context manager so the file handle is always closed;
    # the original passed a bare open() into tqdm and leaked the handle.
    with open(filename, "r", encoding="utf-8") as f:
        for ln in tqdm(f):
            # unknown labels fall back to the <unk> ID
            data.append(
                np.array(
                    [label_dict.get(label, unk) for label in ln.split()],
                    dtype=np.int32,
                )
            )
    return data
def count_tokens(data, unk_id=None):
    """Count tokens and oovs in token ID sequences.

    Args:
        data (list[np.ndarray]): list of token ID sequences
        unk_id (int): ID of unknown token; OOVs are not counted when None

    Returns:
        tuple: tuple of number of token occurrences and number of oov tokens
    """
    # total number of IDs across every sequence
    total = sum(len(seq) for seq in data)
    oovs = 0
    if unk_id is not None:
        # an OOV is any position equal to the <unk> ID
        oovs = sum(int(np.count_nonzero(seq == unk_id)) for seq in data)
    return total, oovs
def compute_perplexity(result):
    """Computes and add the perplexity to the LogReport

    :param dict result: The current observations
    """
    # Rewrite the LogReport result dict in place, adding perplexity values.
    mean_loss = result["main/loss"] / result["main/count"]
    result["perplexity"] = np.exp(mean_loss)
    if "validation/main/loss" in result:
        result["val_perplexity"] = np.exp(result["validation/main/loss"])
class ParallelSentenceIterator(chainer.dataset.Iterator):
    """Dataset iterator to create a batch of sentences.

    This iterator returns a pair of sentences, where one token is shifted
    between the sentences like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
    Sentence batches are made in order of longer sentences, and then
    randomly shuffled.
    """

    def __init__(
        self, dataset, batch_size, max_length=0, sos=0, eos=0, repeat=True, shuffle=True
    ):
        self.dataset = dataset
        self.batch_size = batch_size  # batch size
        # Number of completed sweeps over the dataset. In this case, it is
        # incremented if every word is visited at least once after the last
        # increment.
        self.epoch = 0
        # True if the epoch is incremented at the last iteration.
        self.is_new_epoch = False
        self.repeat = repeat
        n_sentences = len(dataset)
        self.batch_indices = []
        # make mini-batches
        if batch_size > 1:
            # sort sentence indices by descending length so each batch
            # groups sentences of similar length
            order = sorted(range(len(dataset)), key=lambda i: -len(dataset[i]))
            start = 0
            while start < n_sentences:
                end = min(start + batch_size, n_sentences)
                # batch size is automatically reduced if the sentence length
                # is larger than max_length
                if max_length > 0:
                    sent_length = len(dataset[order[start]])
                    end = min(
                        end,
                        start + max(batch_size // (sent_length // max_length + 1), 1),
                    )
                self.batch_indices.append(np.array(order[start:end]))
                start = end
            if shuffle:
                # shuffle batches
                random.shuffle(self.batch_indices)
        else:
            self.batch_indices = [np.array([i]) for i in range(n_sentences)]
        # NOTE: this is not a count of parameter updates. It is just a count of
        # calls of ``__next__``.
        self.iteration = 0
        self.sos = sos
        self.eos = eos
        # use -1 instead of None internally
        self._previous_epoch_detail = -1.0

    def __next__(self):
        # This iterator returns a list representing a mini-batch. Each item
        # indicates a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
        # represented by token IDs.
        n_batches = len(self.batch_indices)
        if not self.repeat and self.iteration >= n_batches:
            # If not self.repeat, this iterator stops at the end of the first
            # epoch (i.e., when all words are visited once).
            raise StopIteration
        batch = [
            (
                np.append([self.sos], self.dataset[idx]),
                np.append(self.dataset[idx], [self.eos]),
            )
            for idx in self.batch_indices[self.iteration % n_batches]
        ]
        self._previous_epoch_detail = self.epoch_detail
        self.iteration += 1
        epoch = self.iteration // n_batches
        self.is_new_epoch = self.epoch < epoch
        if self.is_new_epoch:
            self.epoch = epoch
        return batch

    def start_shuffle(self):
        """Shuffle the prepared mini-batches in place."""
        random.shuffle(self.batch_indices)

    @property
    def epoch_detail(self):
        # Floating point version of epoch.
        return self.iteration / len(self.batch_indices)

    @property
    def previous_epoch_detail(self):
        # None until the first call of __next__ (internally stored as -1).
        if self._previous_epoch_detail < 0:
            return None
        return self._previous_epoch_detail

    def serialize(self, serializer):
        # It is important to serialize the state to be recovered on resume.
        self.iteration = serializer("iteration", self.iteration)
        self.epoch = serializer("epoch", self.epoch)
        try:
            self._previous_epoch_detail = serializer(
                "previous_epoch_detail", self._previous_epoch_detail
            )
        except KeyError:
            # guess previous_epoch_detail for older version
            # NOTE(review): `self.current_position` is not set anywhere in this
            # class — this legacy fallback likely raises AttributeError; confirm
            # against the chainer snapshots it is meant to support.
            self._previous_epoch_detail = self.epoch + (
                self.current_position - 1
            ) / len(self.batch_indices)
            if self.epoch_detail > 0:
                self._previous_epoch_detail = max(self._previous_epoch_detail, 0.0)
            else:
                self._previous_epoch_detail = -1.0
class MakeSymlinkToBestModel(extension.Extension):
    """Extension that makes a symbolic link to the best model

    :param str key: Key of value
    :param str prefix: Prefix of model files and link target
    :param str suffix: Suffix of link target
    """

    def __init__(self, key, prefix="model", suffix="best"):
        super(MakeSymlinkToBestModel, self).__init__()
        self.best_model = -1  # epoch of the best model so far (-1: none yet)
        self.min_loss = 0.0  # best (lowest) observed value of `key`
        self.key = key
        self.prefix = prefix
        self.suffix = suffix

    def __call__(self, trainer):
        observation = trainer.observation
        if self.key not in observation:
            return
        loss = observation[self.key]
        # first observation always wins; afterwards only a strict improvement
        if self.best_model == -1 or loss < self.min_loss:
            self.min_loss = loss
            self.best_model = trainer.updater.epoch
            src = "%s.%d" % (self.prefix, self.best_model)
            dest = os.path.join(trainer.out, "%s.%s" % (self.prefix, self.suffix))
            # replace any stale link (lexists also catches dangling symlinks)
            if os.path.lexists(dest):
                os.remove(dest)
            os.symlink(src, dest)
            logging.info("best model is " + src)

    def serialize(self, serializer):
        # Save or restore the extension state depending on serializer direction.
        if isinstance(serializer, chainer.serializer.Serializer):
            serializer("_best_model", self.best_model)
            serializer("_min_loss", self.min_loss)
            serializer("_key", self.key)
            serializer("_prefix", self.prefix)
            serializer("_suffix", self.suffix)
        else:
            self.best_model = serializer("_best_model", -1)
            self.min_loss = serializer("_min_loss", 0.0)
            self.key = serializer("_key", "")
            self.prefix = serializer("_prefix", "model")
            self.suffix = serializer("_suffix", "best")
# TODO(Hori): currently it only works with character-word level LM.
# need to consider any types of subwords-to-word mapping.
def make_lexical_tree(word_dict, subword_dict, word_unk):
    """Make a lexical tree to compute word-level probabilities.

    Each tree node is a 3-element list
    ``[dict(subword_id -> child node), word_id, (start-1, end)]`` where the
    last element is the (exclusive-start, inclusive-end) range of word IDs
    reachable through that node, and ``word_id`` is >= 0 only at word ends.

    :param dict word_dict: mapping from word string to word ID
    :param dict subword_dict: mapping from subword (e.g. character) to ID
    :param int word_unk: word ID of <unk>, skipped when building the tree
    :return: the root node of the lexical tree
    """
    root = [{}, -1, None]
    for w, wid in word_dict.items():
        if wid > 0 and wid != word_unk:  # skip <blank> and <unk>
            # skip words containing a subword missing from subword_dict
            # (idiom fix: any() instead of `True in [list comprehension]`)
            if any(c not in subword_dict for c in w):
                continue
            succ = root[0]  # get successors from root node
            for i, c in enumerate(w):
                cid = subword_dict[c]
                if cid not in succ:
                    # if next node does not exist, make a new node
                    succ[cid] = [{}, -1, (wid - 1, wid)]
                else:
                    # widen the node's word-ID range to include this word
                    prev = succ[cid][2]
                    succ[cid][2] = (min(prev[0], wid - 1), max(prev[1], wid))
                if i == len(w) - 1:  # if word end, set word id
                    succ[cid][1] = wid
                succ = succ[cid][0]  # move to the child successors
    return root
conformer/espnet-v.202304_20240621/build/lib/espnet/lm/pytorch_backend/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/lm/pytorch_backend/extlm.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# Copyright 2018 Mitsubishi Electric Research Laboratories (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
espnet.lm.lm_utils
import
make_lexical_tree
from
espnet.nets.pytorch_backend.nets_utils
import
to_device
# Definition of a multi-level (subword/word) language model
class MultiLevelLM(nn.Module):
    """Multi-level (subword/word) language model.

    Combines a word-level LM with a subword-level LM by walking a lexical
    tree: inside a word the subword LM scores transitions, and at word
    boundaries (<space>/<eos>) the word LM probability replaces the
    accumulated subword score.
    """

    # score used to forbid a label (log-domain "zero")
    logzero = -10000000000.0
    zero = 1.0e-10

    def __init__(
        self,
        wordlm,
        subwordlm,
        word_dict,
        subword_dict,
        subwordlm_weight=0.8,
        oov_penalty=1.0,
        open_vocab=True,
    ):
        """Initialize the multi-level LM.

        :param wordlm: word-level LM callable as ``lm(state, x)`` -> (state, logits)
        :param subwordlm: subword-level LM with the same calling convention
        :param dict word_dict: word label -> word ID (must contain <eos>, <unk>)
        :param dict subword_dict: subword label -> ID (must contain <space>, <eos>)
        :param float subwordlm_weight: weight on subword LM log-probs
        :param float oov_penalty: multiplicative penalty for OOV words
        :param bool open_vocab: allow paths outside the lexical tree
        """
        super(MultiLevelLM, self).__init__()
        self.wordlm = wordlm
        self.subwordlm = subwordlm
        self.word_eos = word_dict["<eos>"]
        self.word_unk = word_dict["<unk>"]
        self.var_word_eos = torch.LongTensor([self.word_eos])
        self.var_word_unk = torch.LongTensor([self.word_unk])
        self.space = subword_dict["<space>"]
        self.eos = subword_dict["<eos>"]
        self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
        self.log_oov_penalty = math.log(oov_penalty)
        self.open_vocab = open_vocab
        self.subword_dict_size = len(subword_dict)
        self.subwordlm_weight = subwordlm_weight
        self.normalized = True

    def forward(self, state, x):
        # update state with input label x
        if state is None:  # make initial states and log-prob vectors
            self.var_word_eos = to_device(x, self.var_word_eos)
            # BUG FIX: the original assigned to_device(x, self.var_word_eos)
            # here, silently replacing the <unk> tensor with the <eos> tensor.
            self.var_word_unk = to_device(x, self.var_word_unk)
            wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
            wlm_logprobs = F.log_softmax(z_wlm, dim=1)
            clm_state, z_clm = self.subwordlm(None, x)
            log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
            new_node = self.lexroot
            clm_logprob = 0.0
            xi = self.space
        else:
            clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                if node is not None and node[1] >= 0:
                    # check if the node is word end
                    w = to_device(x, torch.LongTensor([node[1]]))
                else:  # this node is not a word end, which means <unk>
                    w = self.var_word_unk
                # update wordlm state and log-prob vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                wlm_logprobs = F.log_softmax(z_wlm, dim=1)
                new_node = self.lexroot  # move to the tree root
                clm_logprob = 0.0
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
                clm_logprob += log_y[0, xi]
            elif self.open_vocab:
                # if no path in the tree, enter open-vocabulary mode
                new_node = None
                clm_logprob += log_y[0, xi]
            else:
                # if open_vocab flag is disabled, return 0 probabilities
                log_y = to_device(
                    x, torch.full((1, self.subword_dict_size), self.logzero)
                )
                return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.0), log_y

            clm_state, z_clm = self.subwordlm(clm_state, x)
            log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight

        # apply word-level probabilies for <space> and <eos> labels
        if xi != self.space:
            if new_node is not None and new_node[1] >= 0:
                # if new node is word end, substitute the word-level score
                # for the accumulated subword-level score
                wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
            else:
                wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
            log_y[:, self.space] = wlm_logprob
            log_y[:, self.eos] = wlm_logprob
        else:
            # consecutive word boundaries are forbidden
            log_y[:, self.space] = self.logzero
            log_y[:, self.eos] = self.logzero

        return (
            (clm_state, wlm_state, wlm_logprobs, new_node, log_y, float(clm_logprob)),
            log_y,
        )

    def final(self, state):
        """Return the word-level <eos> log-probability for a final state."""
        clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
        if node is not None and node[1] >= 0:  # check if the node is word end
            w = to_device(wlm_logprobs, torch.LongTensor([node[1]]))
        else:  # this node is not a word end, which means <unk>
            w = self.var_word_unk
        wlm_state, z_wlm = self.wordlm(wlm_state, w)
        return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
# Definition of a look-ahead word language model
class LookAheadWordLM(nn.Module):
    """Look-ahead word language model.

    Scores subword transitions with word-level look-ahead probabilities:
    each lexical-tree node stores the range of word IDs reachable through
    it, so a cumulative-sum over the word LM's softmax gives the total mass
    of all words consistent with the current prefix.
    """

    # score used to forbid a label (log-domain "zero")
    logzero = -10000000000.0
    zero = 1.0e-10

    def __init__(self, wordlm, word_dict, subword_dict, oov_penalty=0.0001, open_vocab=True):
        """Initialize the look-ahead word LM.

        :param wordlm: word-level LM callable as ``lm(state, x)`` -> (state, logits)
        :param dict word_dict: word label -> word ID (must contain <eos>, <unk>)
        :param dict subword_dict: subword label -> ID (must contain <space>, <eos>)
        :param float oov_penalty: multiplicative penalty on the <unk> mass
        :param bool open_vocab: allow paths outside the lexical tree
        """
        super(LookAheadWordLM, self).__init__()
        self.wordlm = wordlm
        self.word_eos = word_dict["<eos>"]
        self.word_unk = word_dict["<unk>"]
        self.var_word_eos = torch.LongTensor([self.word_eos])
        self.var_word_unk = torch.LongTensor([self.word_unk])
        self.space = subword_dict["<space>"]
        self.eos = subword_dict["<eos>"]
        self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
        self.oov_penalty = oov_penalty
        self.open_vocab = open_vocab
        self.subword_dict_size = len(subword_dict)
        self.zero_tensor = torch.FloatTensor([self.zero])
        self.normalized = True

    def forward(self, state, x):
        # update state with input label x
        if state is None:  # make initial states and cumlative probability vector
            self.var_word_eos = to_device(x, self.var_word_eos)
            # BUG FIX: the original assigned to_device(x, self.var_word_eos)
            # here, silently replacing the <unk> tensor with the <eos> tensor.
            self.var_word_unk = to_device(x, self.var_word_unk)
            self.zero_tensor = to_device(x, self.zero_tensor)
            wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
            cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
            new_node = self.lexroot
            xi = self.space
        else:
            wlm_state, cumsum_probs, node = state
            xi = int(x)
            if xi == self.space:  # inter-word transition
                if node is not None and node[1] >= 0:
                    # check if the node is word end
                    w = to_device(x, torch.LongTensor([node[1]]))
                else:  # this node is not a word end, which means <unk>
                    w = self.var_word_unk
                # update wordlm state and cumlative probability vector
                wlm_state, z_wlm = self.wordlm(wlm_state, w)
                cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
                new_node = self.lexroot  # move to the tree root
            elif node is not None and xi in node[0]:  # intra-word transition
                new_node = node[0][xi]
            elif self.open_vocab:
                # if no path in the tree, enter open-vocabulary mode
                new_node = None
            else:
                # if open_vocab flag is disabled, return 0 probabilities
                log_y = to_device(
                    x, torch.full((1, self.subword_dict_size), self.logzero)
                )
                return (wlm_state, None, None), log_y

        if new_node is not None:
            succ, wid, wids = new_node
            # compute parent node probability: total word mass reachable
            # through this node (difference of the cumulative sums)
            sum_prob = (
                (cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]])
                if wids is not None
                else 1.0
            )
            if sum_prob < self.zero:
                # no measurable mass left under this node
                log_y = to_device(
                    x, torch.full((1, self.subword_dict_size), self.logzero)
                )
                return (wlm_state, cumsum_probs, new_node), log_y
            # set <unk> probability as a default value
            # NOTE(review): assumes word_unk > 0 so word_unk - 1 is valid
            unk_prob = (
                cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
            )
            y = to_device(
                x,
                torch.full(
                    (1, self.subword_dict_size), float(unk_prob) * self.oov_penalty
                ),
            )
            # compute transition probabilities to child nodes
            for cid, nd in succ.items():
                y[:, cid] = (
                    cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]
                ) / sum_prob
            # apply word-level probabilies for <space> and <eos> labels
            if wid >= 0:
                wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
                y[:, self.space] = wlm_prob
                y[:, self.eos] = wlm_prob
            elif xi == self.space:
                y[:, self.space] = self.zero
                y[:, self.eos] = self.zero
            log_y = torch.log(torch.max(y, self.zero_tensor))  # clip to avoid log(0)
        else:
            # if no path in the tree, transition probability is one
            log_y = to_device(x, torch.zeros(1, self.subword_dict_size))
        return (wlm_state, cumsum_probs, new_node), log_y

    def final(self, state):
        """Return the word-level <eos> log-probability for a final state."""
        wlm_state, cumsum_probs, node = state
        if node is not None and node[1] >= 0:  # check if the node is word end
            w = to_device(cumsum_probs, torch.LongTensor([node[1]]))
        else:  # this node is not a word end, which means <unk>
            w = self.var_word_unk
        wlm_state, z_wlm = self.wordlm(wlm_state, w)
        return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
conformer/espnet-v.202304_20240621/build/lib/espnet/lm/pytorch_backend/lm.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
# This code is ported from the following implementation written in Torch.
# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py
"""LM training in pytorch."""
import
copy
import
json
import
logging
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
chainer
import
Chain
,
reporter
,
training
from
chainer.dataset
import
convert
from
chainer.training
import
extensions
from
torch.nn.parallel
import
data_parallel
from
espnet.asr.asr_utils
import
(
snapshot_object
,
torch_load
,
torch_resume
,
torch_snapshot
,
)
from
espnet.lm.lm_utils
import
(
MakeSymlinkToBestModel
,
ParallelSentenceIterator
,
count_tokens
,
load_dataset
,
read_tokens
,
)
from
espnet.nets.lm_interface
import
LMInterface
,
dynamic_import_lm
from
espnet.optimizer.factory
import
dynamic_import_optimizer
from
espnet.scheduler.pytorch
import
PyTorchScheduler
from
espnet.scheduler.scheduler
import
dynamic_import_scheduler
from
espnet.utils.deterministic_utils
import
set_deterministic_pytorch
from
espnet.utils.training.evaluator
import
BaseEvaluator
from
espnet.utils.training.iterators
import
ShufflingEnabler
from
espnet.utils.training.tensorboard_logger
import
TensorboardLogger
from
espnet.utils.training.train_utils
import
check_early_stop
,
set_early_stop
def compute_perplexity(result):
    """Compute and add the perplexity to the LogReport.

    :param dict result: The current observations
    """
    # Rewrite the LogReport result dict in place, adding perplexity values.
    result["perplexity"] = np.exp(result["main/nll"] / result["main/count"])
    if "validation/main/nll" in result:
        val_mean_nll = result["validation/main/nll"] / result["validation/main/count"]
        result["val_perplexity"] = np.exp(val_mean_nll)
class Reporter(Chain):
    """Dummy module to use chainer's trainer."""

    def report(self, loss):
        """Report nothing."""
        # Intentionally a no-op: reporting is done elsewhere via
        # chainer.reporter; this stub only satisfies the trainer's interface.
        pass
def concat_examples(batch, device=None, padding=None):
    """Concat examples in minibatch.

    :param np.ndarray batch: The batch to concatenate
    :param int device: The device to send to (GPU id; None/negative keeps CPU)
    :param Tuple[int,int] padding: The padding to use
    :return: (inputs, targets)
    :rtype (torch.Tensor, torch.Tensor)
    """
    # pad/stack with chainer's converter, then hand over to torch
    inputs, targets = convert.concat_examples(batch, padding=padding)
    inputs = torch.from_numpy(inputs)
    targets = torch.from_numpy(targets)
    if device is not None and device >= 0:
        inputs = inputs.cuda(device)
        targets = targets.cuda(device)
    return inputs, targets
class BPTTUpdater(training.StandardUpdater):
    """An updater for a pytorch LM."""

    def __init__(
        self,
        train_iter,
        model,
        optimizer,
        schedulers,
        device,
        gradclip=None,
        use_apex=False,
        accum_grad=1,
    ):
        """Initialize class.

        Args:
            train_iter (chainer.dataset.Iterator): The train iterator
            model (LMInterface) : The model to update
            optimizer (torch.optim.Optimizer): The optimizer for training
            schedulers (espnet.scheduler.scheduler.SchedulerInterface):
                The schedulers of `optimizer`
            device (list[int]): The device ids (first entry -1 means CPU)
            gradclip (float): The gradient clipping value to use
            use_apex (bool): The flag to use Apex in backprop.
            accum_grad (int): The number of gradient accumulation.
        """
        super(BPTTUpdater, self).__init__(train_iter, optimizer)
        self.model = model
        self.device = device
        self.gradclip = gradclip
        self.use_apex = use_apex
        self.scheduler = PyTorchScheduler(schedulers, optimizer)
        self.accum_grad = accum_grad

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Update the model."""
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")

        # Clear the parameter gradients before accumulating
        self.model.zero_grad()
        stats = {"loss": 0.0, "nll": 0.0, "count": 0}
        for _ in range(self.accum_grad):
            minibatch = train_iter.__next__()
            # Concatenate the token IDs to matrices and send them to the device
            # (chainer.dataset.concat_examples-style conversion)
            x, t = concat_examples(minibatch, device=self.device[0], padding=(0, -100))
            if self.device[0] == -1:
                loss, nll, count = self.model(x, t)
            else:
                # apex does not support torch.nn.DataParallel
                loss, nll, count = data_parallel(self.model, (x, t), self.device)

            # backward, scaling by accum_grad so the effective LR is unchanged
            loss = loss.mean() / self.accum_grad
            if self.use_apex:
                from apex import amp

                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()  # Backprop
            # accumulate stats
            stats["loss"] += float(loss)
            stats["nll"] += float(nll.sum())
            stats["count"] += int(count.sum())

        for name, value in stats.items():
            reporter.report({name: value}, optimizer.target)
        if self.gradclip is not None:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.gradclip)
        optimizer.step()  # Update the parameters
        self.scheduler.step(n_iter=self.iteration)
class LMEvaluator(BaseEvaluator):
    """A custom evaluator for a pytorch LM."""

    def __init__(self, val_iter, eval_model, reporter, device):
        """Initialize class.

        :param chainer.dataset.Iterator val_iter : The validation iterator
        :param LMInterface eval_model : The model to evaluate
        :param chainer.Reporter reporter : The observations reporter
        :param list[int] device : The device ids to use
        """
        # base evaluator is pinned to CPU (-1); the real devices are kept
        # on self.device and used directly in evaluate()
        super(LMEvaluator, self).__init__(val_iter, reporter, device=-1)
        self.model = eval_model
        self.device = device

    def evaluate(self):
        """Evaluate the model."""
        val_iter = self.get_iterator("main")
        total_loss = 0
        total_nll = 0
        total_count = 0
        self.model.eval()
        with torch.no_grad():
            # copy.copy so the underlying iterator position is not consumed
            for minibatch in copy.copy(val_iter):
                x, t = concat_examples(
                    minibatch, device=self.device[0], padding=(0, -100)
                )
                if self.device[0] == -1:
                    batch_loss, batch_nll, batch_count = self.model(x, t)
                else:
                    # apex does not support torch.nn.DataParallel
                    batch_loss, batch_nll, batch_count = data_parallel(
                        self.model, (x, t), self.device
                    )
                total_loss += float(batch_loss.sum())
                total_nll += float(batch_nll.sum())
                total_count += int(batch_count.sum())
        self.model.train()
        # report validation loss
        observation = {}
        with reporter.report_scope(observation):
            reporter.report({"loss": total_loss}, self.model.reporter)
            reporter.report({"nll": total_nll}, self.model.reporter)
            reporter.report({"count": total_count}, self.model.reporter)
        return observation
def
train
(
args
):
"""Train with the given args.
:param Namespace args: The program arguments
:param type model_class: LMInterface class for training
"""
model_class
=
dynamic_import_lm
(
args
.
model_module
,
args
.
backend
)
assert
issubclass
(
model_class
,
LMInterface
),
"model should implement LMInterface"
# display torch version
logging
.
info
(
"torch version = "
+
torch
.
__version__
)
set_deterministic_pytorch
(
args
)
# check cuda and cudnn availability
if
not
torch
.
cuda
.
is_available
():
logging
.
warning
(
"cuda is not available"
)
# get special label ids
unk
=
args
.
char_list_dict
[
"<unk>"
]
eos
=
args
.
char_list_dict
[
"<eos>"
]
# read tokens as a sequence of sentences
val
,
n_val_tokens
,
n_val_oovs
=
load_dataset
(
args
.
valid_label
,
args
.
char_list_dict
,
args
.
dump_hdf5_path
)
train
,
n_train_tokens
,
n_train_oovs
=
load_dataset
(
args
.
train_label
,
args
.
char_list_dict
,
args
.
dump_hdf5_path
)
logging
.
info
(
"#vocab = "
+
str
(
args
.
n_vocab
))
logging
.
info
(
"#sentences in the training data = "
+
str
(
len
(
train
)))
logging
.
info
(
"#tokens in the training data = "
+
str
(
n_train_tokens
))
logging
.
info
(
"oov rate in the training data = %.2f %%"
%
(
n_train_oovs
/
n_train_tokens
*
100
)
)
logging
.
info
(
"#sentences in the validation data = "
+
str
(
len
(
val
)))
logging
.
info
(
"#tokens in the validation data = "
+
str
(
n_val_tokens
))
logging
.
info
(
"oov rate in the validation data = %.2f %%"
%
(
n_val_oovs
/
n_val_tokens
*
100
)
)
use_sortagrad
=
args
.
sortagrad
==
-
1
or
args
.
sortagrad
>
0
# Create the dataset iterators
batch_size
=
args
.
batchsize
*
max
(
args
.
ngpu
,
1
)
if
batch_size
*
args
.
accum_grad
>
args
.
batchsize
:
logging
.
info
(
f
"batch size is automatically increased "
f
"(
{
args
.
batchsize
}
->
{
batch_size
*
args
.
accum_grad
}
)"
)
train_iter
=
ParallelSentenceIterator
(
train
,
batch_size
,
max_length
=
args
.
maxlen
,
sos
=
eos
,
eos
=
eos
,
shuffle
=
not
use_sortagrad
,
)
val_iter
=
ParallelSentenceIterator
(
val
,
batch_size
,
max_length
=
args
.
maxlen
,
sos
=
eos
,
eos
=
eos
,
repeat
=
False
)
epoch_iters
=
int
(
len
(
train_iter
.
batch_indices
)
/
args
.
accum_grad
)
logging
.
info
(
"#iterations per epoch = %d"
%
epoch_iters
)
logging
.
info
(
"#total iterations = "
+
str
(
args
.
epoch
*
epoch_iters
))
# Prepare an RNNLM model
if
args
.
train_dtype
in
(
"float16"
,
"float32"
,
"float64"
):
dtype
=
getattr
(
torch
,
args
.
train_dtype
)
else
:
dtype
=
torch
.
float32
model
=
model_class
(
args
.
n_vocab
,
args
).
to
(
dtype
=
dtype
)
if
args
.
ngpu
>
0
:
model
.
to
(
"cuda"
)
gpu_id
=
list
(
range
(
args
.
ngpu
))
else
:
gpu_id
=
[
-
1
]
# Save model conf to json
model_conf
=
args
.
outdir
+
"/model.json"
with
open
(
model_conf
,
"wb"
)
as
f
:
logging
.
info
(
"writing a model config file to "
+
model_conf
)
f
.
write
(
json
.
dumps
(
vars
(
args
),
indent
=
4
,
ensure_ascii
=
False
,
sort_keys
=
True
).
encode
(
"utf_8"
)
)
logging
.
warning
(
"num. model params: {:,} (num. trained: {:,} ({:.1f}%))"
.
format
(
sum
(
p
.
numel
()
for
p
in
model
.
parameters
()),
sum
(
p
.
numel
()
for
p
in
model
.
parameters
()
if
p
.
requires_grad
),
sum
(
p
.
numel
()
for
p
in
model
.
parameters
()
if
p
.
requires_grad
)
*
100.0
/
sum
(
p
.
numel
()
for
p
in
model
.
parameters
()),
)
)
# Set up an optimizer
opt_class
=
dynamic_import_optimizer
(
args
.
opt
,
args
.
backend
)
optimizer
=
opt_class
.
from_args
(
model
.
parameters
(),
args
)
if
args
.
schedulers
is
None
:
schedulers
=
[]
else
:
schedulers
=
[
dynamic_import_scheduler
(
v
)(
k
,
args
)
for
k
,
v
in
args
.
schedulers
]
# setup apex.amp
if
args
.
train_dtype
in
(
"O0"
,
"O1"
,
"O2"
,
"O3"
):
try
:
from
apex
import
amp
except
ImportError
as
e
:
logging
.
error
(
f
"You need to install apex for --train-dtype
{
args
.
train_dtype
}
. "
"See https://github.com/NVIDIA/apex#linux"
)
raise
e
model
,
optimizer
=
amp
.
initialize
(
model
,
optimizer
,
opt_level
=
args
.
train_dtype
)
use_apex
=
True
else
:
use_apex
=
False
# FIXME: TOO DIRTY HACK
reporter
=
Reporter
()
setattr
(
model
,
"reporter"
,
reporter
)
setattr
(
optimizer
,
"target"
,
reporter
)
setattr
(
optimizer
,
"serialize"
,
lambda
s
:
reporter
.
serialize
(
s
))
updater
=
BPTTUpdater
(
train_iter
,
model
,
optimizer
,
schedulers
,
gpu_id
,
gradclip
=
args
.
gradclip
,
use_apex
=
use_apex
,
accum_grad
=
args
.
accum_grad
,
)
trainer
=
training
.
Trainer
(
updater
,
(
args
.
epoch
,
"epoch"
),
out
=
args
.
outdir
)
trainer
.
extend
(
LMEvaluator
(
val_iter
,
model
,
reporter
,
device
=
gpu_id
))
trainer
.
extend
(
extensions
.
LogReport
(
postprocess
=
compute_perplexity
,
trigger
=
(
args
.
report_interval_iters
,
"iteration"
),
)
)
trainer
.
extend
(
extensions
.
PrintReport
(
[
"epoch"
,
"iteration"
,
"main/loss"
,
"perplexity"
,
"val_perplexity"
,
"elapsed_time"
,
]
),
trigger
=
(
args
.
report_interval_iters
,
"iteration"
),
)
trainer
.
extend
(
extensions
.
ProgressBar
(
update_interval
=
args
.
report_interval_iters
))
# Save best models
trainer
.
extend
(
torch_snapshot
(
filename
=
"snapshot.ep.{.updater.epoch}"
))
trainer
.
extend
(
snapshot_object
(
model
,
"rnnlm.model.{.updater.epoch}"
))
# T.Hori: MinValueTrigger should be used, but it fails when resuming
trainer
.
extend
(
MakeSymlinkToBestModel
(
"validation/main/loss"
,
"rnnlm.model"
))
if
use_sortagrad
:
trainer
.
extend
(
ShufflingEnabler
([
train_iter
]),
trigger
=
(
args
.
sortagrad
if
args
.
sortagrad
!=
-
1
else
args
.
epoch
,
"epoch"
),
)
if
args
.
resume
:
logging
.
info
(
"resumed from %s"
%
args
.
resume
)
torch_resume
(
args
.
resume
,
trainer
)
set_early_stop
(
trainer
,
args
,
is_lm
=
True
)
if
args
.
tensorboard_dir
is
not
None
and
args
.
tensorboard_dir
!=
""
:
from
torch.utils.tensorboard
import
SummaryWriter
writer
=
SummaryWriter
(
args
.
tensorboard_dir
)
trainer
.
extend
(
TensorboardLogger
(
writer
),
trigger
=
(
args
.
report_interval_iters
,
"iteration"
)
)
trainer
.
run
()
check_early_stop
(
trainer
,
args
.
epoch
)
# compute perplexity for test set
if
args
.
test_label
:
logging
.
info
(
"test the best model"
)
torch_load
(
args
.
outdir
+
"/rnnlm.model.best"
,
model
)
test
=
read_tokens
(
args
.
test_label
,
args
.
char_list_dict
)
n_test_tokens
,
n_test_oovs
=
count_tokens
(
test
,
unk
)
logging
.
info
(
"#sentences in the test data = "
+
str
(
len
(
test
)))
logging
.
info
(
"#tokens in the test data = "
+
str
(
n_test_tokens
))
logging
.
info
(
"oov rate in the test data = %.2f %%"
%
(
n_test_oovs
/
n_test_tokens
*
100
)
)
test_iter
=
ParallelSentenceIterator
(
test
,
batch_size
,
max_length
=
args
.
maxlen
,
sos
=
eos
,
eos
=
eos
,
repeat
=
False
)
evaluator
=
LMEvaluator
(
test_iter
,
model
,
reporter
,
device
=
gpu_id
)
result
=
evaluator
()
compute_perplexity
(
result
)
logging
.
info
(
f
"test perplexity:
{
result
[
'perplexity'
]
}
"
)
conformer/espnet-v.202304_20240621/build/lib/espnet/mt/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/mt/mt_utils.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Utility functions for the text translation task."""
import
logging
# * ------------------ recognition related ------------------ *
def parse_hypothesis(hyp, char_list):
    """Parse hypothesis.

    :param list hyp: recognition hypothesis
    :param list char_list: list of characters
    :return: recognition text string
    :return: recognition token string
    :return: recognition tokenid string
    """
    # drop the leading <sos> symbol, then map the remaining ids to characters
    token_ids = [int(t) for t in hyp["yseq"][1:]]
    tokens = [char_list[t] for t in token_ids]
    score = float(hyp["score"])

    # build the three string views expected by the caller
    tokenid = " ".join(str(t) for t in token_ids)
    token = " ".join(tokens)
    text = "".join(tokens).replace("<space>", " ")

    return text, token, tokenid, score
def add_results_to_json(js, nbest_hyps, char_list):
    """Add N-best results to json.

    :param dict js: groundtruth utterance dict
    :param list nbest_hyps: list of hypothesis
    :param list char_list: list of characters
    :return: N-best results added utterance dict
    """
    # start from a fresh dict, carrying over speaker info when present
    new_js = dict()
    if "utt2spk" in js.keys():
        new_js["utt2spk"] = js["utt2spk"]
    new_js["output"] = []

    for rank, hyp in enumerate(nbest_hyps, 1):
        # turn the raw hypothesis into its string representations
        rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)

        # base each entry on the first ground-truth output when available
        if len(js["output"]) > 0:
            entry = dict(js["output"][0].items())
        else:
            entry = {"name": ""}

        # tag the entry with its rank in the N-best list
        entry["name"] += "[%d]" % rank

        # recognition results
        entry["rec_text"] = rec_text
        entry["rec_token"] = rec_token
        entry["rec_tokenid"] = rec_tokenid
        entry["score"] = score

        # attach the source-side reference (second output stream)
        entry["text_src"] = js["output"][1]["text"]
        entry["token_src"] = js["output"][1]["token"]
        entry["tokenid_src"] = js["output"][1]["tokenid"]

        new_js["output"].append(entry)

        # log only the 1-best result
        if rank == 1:
            if "text" in entry.keys():
                logging.info("groundtruth: %s" % entry["text"])
            logging.info("prediction : %s" % entry["rec_text"])
            logging.info("source : %s" % entry["token_src"])

    return new_js
conformer/espnet-v.202304_20240621/build/lib/espnet/mt/pytorch_backend/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/mt/pytorch_backend/mt.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Training/decoding definition for the text translation task."""
import
itertools
import
json
import
logging
import
os
import
numpy
as
np
import
torch
from
chainer
import
training
from
chainer.training
import
extensions
from
espnet.asr.asr_utils
import
(
CompareValueTrigger
,
adadelta_eps_decay
,
adam_lr_decay
,
add_results_to_json
,
restore_snapshot
,
snapshot_object
,
torch_load
,
torch_resume
,
torch_snapshot
,
)
from
espnet.asr.pytorch_backend.asr
import
(
CustomEvaluator
,
CustomUpdater
,
load_trained_model
,
)
from
espnet.nets.mt_interface
import
MTInterface
from
espnet.nets.pytorch_backend.e2e_asr
import
pad_list
from
espnet.utils.dataset
import
ChainerDataLoader
,
TransformDataset
from
espnet.utils.deterministic_utils
import
set_deterministic_pytorch
from
espnet.utils.dynamic_import
import
dynamic_import
from
espnet.utils.io_utils
import
LoadInputsAndTargets
from
espnet.utils.training.batchfy
import
make_batchset
from
espnet.utils.training.iterators
import
ShufflingEnabler
from
espnet.utils.training.tensorboard_logger
import
TensorboardLogger
from
espnet.utils.training.train_utils
import
check_early_stop
,
set_early_stop
class CustomConverter(object):
    """Custom batch converter for Pytorch."""

    def __init__(self):
        """Construct a CustomConverter object."""
        self.ignore_id = -1
        self.pad = 0
        # NOTE: we reserve index:0 for <pad> although this is reserved for a blank class
        # in ASR. However,
        # blank labels are not used in NMT. To keep the vocabulary size,
        # we use index:0 for padding instead of adding one more class.

    def __call__(self, batch, device=torch.device("cpu")):
        """Transform a batch and send it to a device.

        Args:
            batch (list): The batch to transform.
            device (torch.device): The device to send to.

        Returns:
            tuple(torch.Tensor, torch.Tensor, torch.Tensor)

        """
        # batch should be located in list
        assert len(batch) == 1
        xs, ys = batch[0]

        # source-side lengths before padding
        source_lengths = np.array([x.shape[0] for x in xs])

        # pad sources with <pad> (=0) and targets with the ignore index,
        # then move everything to the requested device
        xs_pad = pad_list(
            [torch.from_numpy(x).long() for x in xs], self.pad
        ).to(device)
        ilens = torch.from_numpy(source_lengths).to(device)
        ys_pad = pad_list(
            [torch.from_numpy(y).long() for y in ys], self.ignore_id
        ).to(device)

        return xs_pad, ilens, ys_pad
def train(args):
    """Train with the given args.

    Builds an MT model from ``args.model_module``, wraps it in a chainer
    Trainer with evaluation/snapshot/reporting extensions, and runs training.

    Args:
        args (namespace): The program arguments.
    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    # NOTE(review): output[1] is the source stream and output[0] the target
    # stream in this MT json layout — confirm against the data prep scripts.
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["output"][1]["shape"][1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args)
    assert isinstance(model, MTInterface)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            # scale the per-step batch so each GPU keeps its configured share
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        # apex opt-levels ("O0".."O3") keep the model in float32 here;
        # amp.initialize below handles the mixed-precision cast
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad)
            * 100.0
            / sum(p.numel() for p in model.parameters()),
        )
    )

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model.parameters(),
            args.adim,
            args.transformer_warmup_steps,
            args.transformer_lr,
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise e
        if args.opt == "noam":
            # noam wraps the real torch optimizer; amp must see the inner one
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype
            )
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.train_dtype
            )
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    # chainer's serializer expects these attributes on the optimizer
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter()

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    # sortagrad == -1 means "shortest-first for all epochs"
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(mode="mt", load_output=True)
    load_cv = LoadInputsAndTargets(mode="mt", load_output=True)
    # hack to make batchsize argument as 1
    # actual bathsize is included in a list
    # default collate function converts numpy array to pytorch tensor
    # we used an empty collate function instead which returns list
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        False,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        # switch back to shuffled batches once the sortagrad epochs are done
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
        )

    # Save attention weight each epoch
    if args.num_save_attention > 0:
        # NOTE: sort it by output lengths
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        # "module" is present when the model is wrapped by DataParallel
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            ikey="output",
            iaxis=1,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(
            ["main/loss", "validation/main/loss"], "epoch", file_name="loss.png"
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/ppl", "validation/main/ppl"], "epoch", file_name="ppl.png"
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/bleu", "validation/main/bleu"], "epoch", file_name="bleu.png"
        )
    )

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
    )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # epsilon decay in the optimizer
    # (restore the best snapshot and decay eps/lr whenever validation degrades)
    if args.opt == "adadelta":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
    elif args.opt == "adam":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "validation/main/loss",
        "main/acc",
        "validation/main/acc",
        "main/ppl",
        "validation/main/ppl",
        "elapsed_time",
    ]
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "eps"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    elif args.opt in ["adam", "noam"]:
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "lr"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    if args.report_bleu:
        report_keys.append("main/bleu")
        report_keys.append("validation/main/bleu")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        from torch.utils.tensorboard import SummaryWriter

        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def trans(args):
    """Decode with the given args.

    Loads the trained MT model, filters out empty utterances, translates every
    utterance (one-by-one or in sorted batches depending on ``args.batchsize``),
    and writes the N-best results to ``args.result_label`` as json.

    Args:
        args (namespace): The program arguments.
    """
    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, MTInterface)
    # expose the decoding options to the model
    model.trans_args = args

    # gpu
    # NOTE(review): only single-GPU decoding is handled here; ngpu > 1 falls
    # through to CPU — confirm this is intended.
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info("gpu id: " + str(gpu_id))
        model.cuda()

    # read json data
    with open(args.trans_json, "rb") as f:
        js = json.load(f)["utts"]
    new_js = {}

    # remove empty utterances
    # (multilingual data keeps a language-id token, hence the > 1 threshold)
    if train_args.multilingual:
        js = {
            k: v
            for k, v in js.items()
            if v["output"][0]["shape"][0] > 1 and v["output"][1]["shape"][0] > 1
        }
    else:
        js = {
            k: v
            for k, v in js.items()
            if v["output"][0]["shape"][0] > 0 and v["output"][1]["shape"][0] > 0
        }

    if args.batchsize == 0:
        # utterance-by-utterance decoding
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
                # output[1] holds the source-token-id string for MT
                feat = [js[name]["output"][1]["tokenid"].split()]
                nbest_hyps = model.translate(feat, args, train_args.char_list)
                new_js[name] = add_results_to_json(
                    js[name], nbest_hyps, train_args.char_list
                )
    else:

        def grouper(n, iterable, fillvalue=None):
            # chunk `iterable` into tuples of length n, padding the last one
            kargs = [iter(iterable)] * n
            return itertools.zip_longest(*kargs, fillvalue=fillvalue)

        # sort data by descending source length so batches are length-homogeneous
        keys = list(js.keys())
        feat_lens = [js[key]["output"][1]["shape"][0] for key in keys]
        sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
        keys = [keys[i] for i in sorted_index]

        with torch.no_grad():
            for names in grouper(args.batchsize, keys, None):
                # drop the fill values of the (possibly short) last batch
                names = [name for name in names if name]
                feats = [
                    np.fromiter(
                        map(int, js[name]["output"][1]["tokenid"].split()),
                        dtype=np.int64,
                    )
                    for name in names
                ]
                nbest_hyps = model.translate_batch(
                    feats,
                    args,
                    train_args.char_list,
                )
                for i, nbest_hyp in enumerate(nbest_hyps):
                    name = names[i]
                    new_js[name] = add_results_to_json(
                        js[name], nbest_hyp, train_args.char_list
                    )

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/asr_interface.py
0 → 100644
View file @
60a2c57a
"""ASR Interface module."""
import
argparse
from
espnet.bin.asr_train
import
get_parser
from
espnet.utils.dynamic_import
import
dynamic_import
from
espnet.utils.fill_missing_args
import
fill_missing_args
class ASRInterface:
    """ASR Interface for ESPnet model implementation."""

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser and return it."""
        return parser

    @classmethod
    def build(cls, idim: int, odim: int, **kwargs):
        """Initialize this class with python-level args.

        Args:
            idim (int): The number of an input feature dim.
            odim (int): The number of output vocab.

        Returns:
            ASRinterface: A new instance of ASRInterface.

        """
        ns = argparse.Namespace(**kwargs)
        # fill defaults first from the training CLI, then from model options
        ns = fill_missing_args(ns, lambda parser: get_parser(parser, required=False))
        ns = fill_missing_args(ns, cls.add_arguments)
        return cls(idim, odim, ns)

    def forward(self, xs, ilens, ys):
        """Compute loss for training.

        :param xs:
            For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim)
            For chainer, list of source sequences chainer.Variable
        :param ilens: batch of lengths of source sequences (B)
            For pytorch, torch.Tensor
            For chainer, list of int
        :param ys:
            For pytorch, batch of padded source sequences torch.Tensor (B, Lmax)
            For chainer, list of source sequences chainer.Variable
        :return: loss value
        :rtype: torch.Tensor for pytorch, chainer.Variable for chainer
        """
        raise NotImplementedError("forward method is not implemented")

    def recognize(self, x, recog_args, char_list=None, rnnlm=None):
        """Recognize a single utterance for evaluation.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace recog_args: argument namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("recognize method is not implemented")

    def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        :param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace recog_args: argument namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("Batch decoding is not supported yet.")

    def calculate_all_attentions(self, xs, ilens, ys):
        """Calculate attention weights for visualization.

        :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...]
        :param ndarray ilens: batch of lengths of input sequences (B)
        :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
        :return: attention weights (B, Lmax, Tmax)
        :rtype: float ndarray
        """
        raise NotImplementedError("calculate_all_attentions method is not implemented")

    def calculate_all_ctc_probs(self, xs, ilens, ys):
        """Calculate CTC posteriors for visualization.

        :param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...]
        :param ndarray ilens: batch of lengths of input sequences (B)
        :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...]
        :return: CTC probabilities (B, Tmax, vocab)
        :rtype: float ndarray
        """
        raise NotImplementedError("calculate_all_ctc_probs method is not implemented")

    @property
    def attention_plot_class(self):
        """Get attention plot class."""
        from espnet.asr.asr_utils import PlotAttentionReport

        return PlotAttentionReport

    @property
    def ctc_plot_class(self):
        """Get CTC plot class."""
        from espnet.asr.asr_utils import PlotCTCReport

        return PlotCTCReport

    def get_total_subsampling_factor(self):
        """Get total subsampling factor."""
        raise NotImplementedError(
            "get_total_subsampling_factor method is not implemented"
        )

    def encode(self, feat):
        """Encode feature in `beam_search` (optional).

        Args:
            x (numpy.ndarray): input feature (T, D)
        Returns:
            torch.Tensor for pytorch, chainer.Variable for chainer:
                encoded feature (T, D)

        """
        raise NotImplementedError("encode method is not implemented")

    def scorers(self):
        """Get scorers for `beam_search` (optional).

        Returns:
            dict[str, ScorerInterface]: dict of `ScorerInterface` objects

        """
        raise NotImplementedError("decoders method is not implemented")
# Registry of shorthand model names per NN backend; each alias maps to a
# "module:class" import path and is resolved by ``dynamic_import_asr`` below.
predefined_asr = {
    "pytorch": {
        "rnn": "espnet.nets.pytorch_backend.e2e_asr:E2E",
        "transducer": "espnet.nets.pytorch_backend.e2e_asr_transducer:E2E",
        "transformer": "espnet.nets.pytorch_backend.e2e_asr_transformer:E2E",
        "conformer": "espnet.nets.pytorch_backend.e2e_asr_conformer:E2E",
    },
    "chainer": {
        "rnn": "espnet.nets.chainer_backend.e2e_asr:E2E",
        "transformer": "espnet.nets.chainer_backend.e2e_asr_transformer:E2E",
    },
}
def dynamic_import_asr(module, backend):
    """Import ASR models dynamically.

    Args:
        module (str): module_name:class_name or alias in `predefined_asr`
        backend (str): NN backend. e.g., pytorch, chainer

    Returns:
        type: ASR class

    """
    # unknown backends simply provide no aliases; `module` must then be a path
    aliases = predefined_asr.get(backend, dict())
    model_class = dynamic_import(module, aliases)
    assert issubclass(
        model_class, ASRInterface
    ), f"{module} does not implement ASRInterface"
    return model_class
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/batch_beam_search.py
0 → 100644
View file @
60a2c57a
"""Parallel beam search module."""
import
logging
from
typing
import
Any
,
Dict
,
List
,
NamedTuple
,
Tuple
import
torch
from
packaging.version
import
parse
as
V
from
torch.nn.utils.rnn
import
pad_sequence
from
espnet.nets.beam_search
import
BeamSearch
,
Hypothesis
# True when the installed torch is >= 1.9; used below to pick the
# ``torch.div(..., rounding_mode="trunc")`` path over the deprecated ``//``.
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
class BatchHypothesis(NamedTuple):
    """Batchfied/Vectorized hypothesis data type."""

    # NOTE(review): the dict() defaults below are evaluated once and shared by
    # every instance constructed without explicit values — do not mutate them.
    yseq: torch.Tensor = torch.tensor([])  # (batch, maxlen)
    score: torch.Tensor = torch.tensor([])  # (batch,)
    length: torch.Tensor = torch.tensor([])  # (batch,)
    scores: Dict[str, torch.Tensor] = dict()  # values: (batch,)
    states: Dict[str, Dict] = dict()  # per-scorer decoder states, one per item

    def __len__(self) -> int:
        """Return a batch size."""
        return len(self.length)
class
BatchBeamSearch
(
BeamSearch
):
"""Batch beam search implementation."""
def
batchfy
(
self
,
hyps
:
List
[
Hypothesis
])
->
BatchHypothesis
:
"""Convert list to batch."""
if
len
(
hyps
)
==
0
:
return
BatchHypothesis
()
return
BatchHypothesis
(
yseq
=
pad_sequence
(
[
h
.
yseq
for
h
in
hyps
],
batch_first
=
True
,
padding_value
=
self
.
eos
),
length
=
torch
.
tensor
([
len
(
h
.
yseq
)
for
h
in
hyps
],
dtype
=
torch
.
int64
),
score
=
torch
.
tensor
([
h
.
score
for
h
in
hyps
]),
scores
=
{
k
:
torch
.
tensor
([
h
.
scores
[
k
]
for
h
in
hyps
])
for
k
in
self
.
scorers
},
states
=
{
k
:
[
h
.
states
[
k
]
for
h
in
hyps
]
for
k
in
self
.
scorers
},
)
def
_batch_select
(
self
,
hyps
:
BatchHypothesis
,
ids
:
List
[
int
])
->
BatchHypothesis
:
return
BatchHypothesis
(
yseq
=
hyps
.
yseq
[
ids
],
score
=
hyps
.
score
[
ids
],
length
=
hyps
.
length
[
ids
],
scores
=
{
k
:
v
[
ids
]
for
k
,
v
in
hyps
.
scores
.
items
()},
states
=
{
k
:
[
self
.
scorers
[
k
].
select_state
(
v
,
i
)
for
i
in
ids
]
for
k
,
v
in
hyps
.
states
.
items
()
},
)
def
_select
(
self
,
hyps
:
BatchHypothesis
,
i
:
int
)
->
Hypothesis
:
return
Hypothesis
(
yseq
=
hyps
.
yseq
[
i
,
:
hyps
.
length
[
i
]],
score
=
hyps
.
score
[
i
],
scores
=
{
k
:
v
[
i
]
for
k
,
v
in
hyps
.
scores
.
items
()},
states
=
{
k
:
self
.
scorers
[
k
].
select_state
(
v
,
i
)
for
k
,
v
in
hyps
.
states
.
items
()
},
)
def
unbatchfy
(
self
,
batch_hyps
:
BatchHypothesis
)
->
List
[
Hypothesis
]:
"""Revert batch to list."""
return
[
Hypothesis
(
yseq
=
batch_hyps
.
yseq
[
i
][:
batch_hyps
.
length
[
i
]],
score
=
batch_hyps
.
score
[
i
],
scores
=
{
k
:
batch_hyps
.
scores
[
k
][
i
]
for
k
in
self
.
scorers
},
states
=
{
k
:
v
.
select_state
(
batch_hyps
.
states
[
k
],
i
)
for
k
,
v
in
self
.
scorers
.
items
()
},
)
for
i
in
range
(
len
(
batch_hyps
.
length
))
]
def
batch_beam
(
self
,
weighted_scores
:
torch
.
Tensor
,
ids
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Batch-compute topk full token ids and partial token ids.
Args:
weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
Its shape is `(n_beam, self.vocab_size)`.
ids (torch.Tensor): The partial token ids to compute topk.
Its shape is `(n_beam, self.pre_beam_size)`.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
The topk full (prev_hyp, new_token) ids
and partial (prev_hyp, new_token) ids.
Their shapes are all `(self.beam_size,)`
"""
top_ids
=
weighted_scores
.
view
(
-
1
).
topk
(
self
.
beam_size
)[
1
]
# Because of the flatten above, `top_ids` is organized as:
# [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK],
# where V is `self.n_vocab` and K is `self.beam_size`
if
is_torch_1_9_plus
:
prev_hyp_ids
=
torch
.
div
(
top_ids
,
self
.
n_vocab
,
rounding_mode
=
"trunc"
)
else
:
prev_hyp_ids
=
top_ids
//
self
.
n_vocab
new_token_ids
=
top_ids
%
self
.
n_vocab
return
prev_hyp_ids
,
new_token_ids
,
prev_hyp_ids
,
new_token_ids
def
init_hyp
(
self
,
x
:
torch
.
Tensor
)
->
BatchHypothesis
:
"""Get an initial hypothesis data.
Args:
x (torch.Tensor): The encoder output feature
Returns:
Hypothesis: The initial hypothesis.
"""
init_states
=
dict
()
init_scores
=
dict
()
for
k
,
d
in
self
.
scorers
.
items
():
init_states
[
k
]
=
d
.
batch_init_state
(
x
)
init_scores
[
k
]
=
0.0
# NOTE (Shih-Lun): added for OpenAI Whisper ASR
primer
=
[
self
.
sos
]
if
self
.
hyp_primer
is
None
else
self
.
hyp_primer
return
self
.
batchfy
(
[
Hypothesis
(
score
=
0.0
,
scores
=
init_scores
,
states
=
init_states
,
yseq
=
torch
.
tensor
(
primer
,
device
=
x
.
device
),
)
]
)
def
score_full
(
self
,
hyp
:
BatchHypothesis
,
x
:
torch
.
Tensor
)
->
Tuple
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
str
,
Any
]]:
"""Score new hypothesis by `self.full_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.full_scorers`
and tensor score values of shape: `(self.n_vocab,)`,
and state dict that has string keys
and state values of `self.full_scorers`
"""
scores
=
dict
()
states
=
dict
()
for
k
,
d
in
self
.
full_scorers
.
items
():
scores
[
k
],
states
[
k
]
=
d
.
batch_score
(
hyp
.
yseq
,
hyp
.
states
[
k
],
x
)
return
scores
,
states
def
score_partial
(
self
,
hyp
:
BatchHypothesis
,
ids
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
Tuple
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
str
,
Any
]]:
"""Score new hypothesis by `self.full_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
ids (torch.Tensor): 2D tensor of new partial tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.full_scorers`
and tensor score values of shape: `(self.n_vocab,)`,
and state dict that has string keys
and state values of `self.full_scorers`
"""
scores
=
dict
()
states
=
dict
()
for
k
,
d
in
self
.
part_scorers
.
items
():
scores
[
k
],
states
[
k
]
=
d
.
batch_score_partial
(
hyp
.
yseq
,
ids
,
hyp
.
states
[
k
],
x
)
return
scores
,
states
def
merge_states
(
self
,
states
:
Any
,
part_states
:
Any
,
part_idx
:
int
)
->
Any
:
"""Merge states for new hypothesis.
Args:
states: states of `self.full_scorers`
part_states: states of `self.part_scorers`
part_idx (int): The new token id for `part_scores`
Returns:
Dict[str, torch.Tensor]: The new score dict.
Its keys are names of `self.full_scorers` and `self.part_scorers`.
Its values are states of the scorers.
"""
new_states
=
dict
()
for
k
,
v
in
states
.
items
():
new_states
[
k
]
=
v
for
k
,
v
in
part_states
.
items
():
new_states
[
k
]
=
v
return
new_states
    def search(
        self, running_hyps: BatchHypothesis, x: torch.Tensor
    ) -> BatchHypothesis:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (BatchHypothesis): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)

        Returns:
            BatchHypothesis: Best sorted hypotheses
        """
        n_batch = len(running_hyps)
        part_ids = None  # no pre-beam
        # batch scoring: accumulate weighted scores of shape (n_batch, n_vocab)
        weighted_scores = torch.zeros(
            n_batch, self.n_vocab, dtype=x.dtype, device=x.device
        )
        # the encoder feature is expanded (broadcast view) to all hypotheses
        scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape))
        for k in self.full_scorers:
            weighted_scores += self.weights[k] * scores[k]
        # partial scoring: restrict partial scorers to the pre-beam token set
        if self.do_pre_beam:
            pre_beam_scores = (
                weighted_scores
                if self.pre_beam_score_key == "full"
                else scores[self.pre_beam_score_key]
            )
            part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1]
        # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns
        # full-size score matrices, which has non-zero scores for part_ids and zeros
        # for others.
        part_scores, part_states = self.score_partial(running_hyps, part_ids, x)
        for k in self.part_scorers:
            weighted_scores += self.weights[k] * part_scores[k]
        # add previous hyp scores so topk ranks total path scores
        weighted_scores += running_hyps.score.to(
            dtype=x.dtype, device=x.device
        ).unsqueeze(1)

        # TODO(karita): do not use list. use batch instead
        # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029
        # update hyps: rebuild the beam one Hypothesis at a time
        best_hyps = []
        prev_hyps = self.unbatchfy(running_hyps)
        for (
            full_prev_hyp_id,
            full_new_token_id,
            part_prev_hyp_id,
            part_new_token_id,
        ) in zip(*self.batch_beam(weighted_scores, part_ids)):
            prev_hyp = prev_hyps[full_prev_hyp_id]
            best_hyps.append(
                Hypothesis(
                    score=weighted_scores[full_prev_hyp_id, full_new_token_id],
                    yseq=self.append_token(prev_hyp.yseq, full_new_token_id),
                    scores=self.merge_scores(
                        prev_hyp.scores,
                        {k: v[full_prev_hyp_id] for k, v in scores.items()},
                        full_new_token_id,
                        {k: v[part_prev_hyp_id] for k, v in part_scores.items()},
                        part_new_token_id,
                    ),
                    states=self.merge_states(
                        {
                            k: self.full_scorers[k].select_state(v, full_prev_hyp_id)
                            for k, v in states.items()
                        },
                        {
                            k: self.part_scorers[k].select_state(
                                v, part_prev_hyp_id, part_new_token_id
                            )
                            for k, v in part_states.items()
                        },
                        part_new_token_id,
                    ),
                )
            )
        return self.batchfy(best_hyps)
    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: BatchHypothesis,
        ended_hyps: List[Hypothesis],
    ) -> BatchHypothesis:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (int): The maximum length ratio in beam search.
            running_hyps (BatchHypothesis): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            BatchHypothesis: The new running hypotheses.
        """
        n_batch = running_hyps.yseq.shape[0]
        logging.debug(f"the number of running hypothes: {n_batch}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join(
                    [
                        self.token_list[x]
                        for x in running_hyps.yseq[0, 1 : running_hyps.length[0]]
                    ]
                )
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            yseq_eos = torch.cat(
                (
                    running_hyps.yseq,
                    torch.full(
                        (n_batch, 1),
                        self.eos,
                        device=running_hyps.yseq.device,
                        dtype=torch.int64,
                    ),
                ),
                1,
            )
            # NOTE: yseq/length are mutated in place so `running_hyps` itself
            # carries the appended <eos> column.
            running_hyps.yseq.resize_as_(yseq_eos)
            running_hyps.yseq[:] = yseq_eos
            running_hyps.length[:] = yseq_eos.shape[1]

        # add ended hypotheses to a final list, and remove them from the current
        # hypotheses (this can be a problem: number of hyps < beam afterwards)
        is_eos = (
            running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1]
            == self.eos
        )
        for b in torch.nonzero(is_eos, as_tuple=False).view(-1):
            hyp = self._select(running_hyps, b)
            ended_hyps.append(hyp)
        remained_ids = torch.nonzero(is_eos == 0, as_tuple=False).view(-1).cpu()
        return self._batch_select(running_hyps, remained_ids)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/batch_beam_search_online.py
0 → 100644
View file @
60a2c57a
"""Parallel beam search module for online simulation."""
import
logging
from
typing
import
Any
# noqa: H301
from
typing
import
Dict
# noqa: H301
from
typing
import
List
# noqa: H301
from
typing
import
Tuple
# noqa: H301
import
torch
from
espnet.nets.batch_beam_search
import
BatchBeamSearch
# noqa: H301
from
espnet.nets.batch_beam_search
import
BatchHypothesis
# noqa: H301
from
espnet.nets.beam_search
import
Hypothesis
from
espnet.nets.e2e_asr_common
import
end_detect
class BatchBeamSearchOnline(BatchBeamSearch):
    """Online beam search implementation.

    This simulates streaming decoding.
    It requires encoded features of entire utterance and
    extracts block by block from it as it should be done
    in streaming processing.
    This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR
    WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
    (https://arxiv.org/abs/2006.14941).
    """

    def __init__(
        self,
        *args,
        block_size=40,
        hop_size=16,
        look_ahead=16,
        disable_repetition_detection=False,
        encoded_feat_length_limit=0,
        decoder_text_length_limit=0,
        **kwargs,
    ):
        """Initialize beam search.

        Args:
            block_size: Number of encoder frames per processing block.
            hop_size: Number of frames the block window advances each step.
            look_ahead: Number of future frames excluded from the first block.
            disable_repetition_detection: Skip the repetition check of
                Eq (11) in https://arxiv.org/abs/2006.14941.
            encoded_feat_length_limit: If > 0, cap the encoder feature length
                fed to scoring (keeps only the newest frames).
            decoder_text_length_limit: If > 0, cap the decoder prefix length
                used for scoring (keeps only the newest tokens).
        """
        super().__init__(*args, **kwargs)
        self.block_size = block_size
        self.hop_size = hop_size
        self.look_ahead = look_ahead
        self.disable_repetition_detection = disable_repetition_detection
        self.encoded_feat_length_limit = encoded_feat_length_limit
        self.decoder_text_length_limit = decoder_text_length_limit

        self.reset()

    def reset(self):
        """Reset parameters."""
        # state carried across successive `forward` calls of one utterance
        self.encbuffer = None
        self.running_hyps = None
        self.prev_hyps = []
        self.ended_hyps = []
        self.processed_block = 0
        self.process_idx = 0
        self.prev_output = None

    def score_full(
        self, hyp: BatchHypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`
        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            if (
                self.decoder_text_length_limit > 0
                and len(hyp.yseq) > 0
                and len(hyp.yseq[0]) > self.decoder_text_length_limit
            ):
                # truncate the decoder prefix to the newest tokens and
                # re-prime it with <sos>
                temp_yseq = hyp.yseq.narrow(
                    1, -self.decoder_text_length_limit, self.decoder_text_length_limit
                ).clone()
                temp_yseq[:, 0] = self.sos
                # NOTE(review): the "decoder" state is reset unconditionally for
                # every full scorer in this branch -- confirm this is intended
                # when multiple full scorers are configured.
                self.running_hyps.states["decoder"] = [
                    None for _ in self.running_hyps.states["decoder"]
                ]
                scores[k], states[k] = d.batch_score(temp_yseq, hyp.states[k], x)
            else:
                scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x)
        return scores, states

    def forward(
        self,
        x: torch.Tensor,
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        is_final: bool = True,
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results
        """
        # accumulate incoming encoder chunks into one growing buffer
        if self.encbuffer is None:
            self.encbuffer = x
        else:
            self.encbuffer = torch.cat([self.encbuffer, x], axis=0)

        x = self.encbuffer

        # set length bounds
        if maxlenratio == 0:
            maxlen = x.shape[0]
        else:
            maxlen = max(1, int(maxlenratio * x.size(0)))

        ret = None
        while True:
            # end frame of the current block within the accumulated buffer
            cur_end_frame = (
                self.block_size
                - self.look_ahead
                + self.hop_size * self.processed_block
            )
            if cur_end_frame < x.shape[0]:
                h = x.narrow(0, 0, cur_end_frame)
                block_is_final = False
            else:
                if is_final:
                    h = x
                    block_is_final = True
                else:
                    # not enough frames buffered yet; wait for the next call
                    break

            logging.debug("Start processing block: %d", self.processed_block)
            logging.debug(
                "  Feature length: {}, current position: {}".format(
                    h.shape[0], self.process_idx
                )
            )
            if (
                self.encoded_feat_length_limit > 0
                and h.shape[0] > self.encoded_feat_length_limit
            ):
                # keep only the newest frames for scoring
                h = h.narrow(
                    0,
                    h.shape[0] - self.encoded_feat_length_limit,
                    self.encoded_feat_length_limit,
                )

            if self.running_hyps is None:
                self.running_hyps = self.init_hyp(h)

            ret = self.process_one_block(h, block_is_final, maxlen, maxlenratio)
            logging.debug("Finished processing block: %d", self.processed_block)
            self.processed_block += 1
            if block_is_final:
                return ret
        if ret is None:
            if self.prev_output is None:
                return []
            else:
                return self.prev_output
        else:
            self.prev_output = ret
            # N-best results
            return ret

    def process_one_block(self, h, is_final, maxlen, maxlenratio):
        """Recognize one block."""
        # extend states for ctc
        self.extend(h, self.running_hyps)
        while self.process_idx < maxlen:
            logging.debug("position " + str(self.process_idx))
            best = self.search(self.running_hyps, h)

            if self.process_idx == maxlen - 1:
                # end decoding
                self.running_hyps = self.post_process(
                    self.process_idx, maxlen, maxlenratio, best, self.ended_hyps
                )
            n_batch = best.yseq.shape[0]
            local_ended_hyps = []
            is_local_eos = (
                best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
            )
            prev_repeat = False
            for i in range(is_local_eos.shape[0]):
                if is_local_eos[i]:
                    hyp = self._select(best, i)
                    local_ended_hyps.append(hyp)
                # NOTE(tsunoo): check repetitions here
                # This is a implicit implementation of
                # Eq (11) in https://arxiv.org/abs/2006.14941
                # A flag prev_repeat is used instead of using set
                # NOTE(fujihara): I made it possible to turned off
                # the below lines using disable_repetition_detection flag,
                # because this criteria is too sensitive that the beam
                # search starts only after the entire inputs are available.
                # Empirically, this flag didn't affect the performance.
                elif (
                    not self.disable_repetition_detection
                    and not prev_repeat
                    and best.yseq[i, -1] in best.yseq[i, :-1]
                    and not is_final
                ):
                    prev_repeat = True
            if prev_repeat:
                logging.info("Detected repetition.")
                break

            if (
                is_final
                and maxlenratio == 0.0
                and end_detect(
                    [lh.asdict() for lh in self.ended_hyps], self.process_idx
                )
            ):
                logging.info(f"end detected at {self.process_idx}")
                return self.assemble_hyps(self.ended_hyps)

            if len(local_ended_hyps) > 0 and not is_final:
                # an EOS inside a non-final block means we should move on to
                # the next block rather than commit the hypothesis
                logging.info("Detected hyp(s) reaching EOS in this block.")
                break

            self.prev_hyps = self.running_hyps
            self.running_hyps = self.post_process(
                self.process_idx, maxlen, maxlenratio, best, self.ended_hyps
            )

            if is_final:
                for hyp in local_ended_hyps:
                    self.ended_hyps.append(hyp)

            if len(self.running_hyps) == 0:
                logging.info("no hypothesis. Finish decoding.")
                return self.assemble_hyps(self.ended_hyps)
            else:
                logging.debug(f"remained hypotheses: {len(self.running_hyps)}")
            # increment number
            self.process_idx += 1

        if is_final:
            return self.assemble_hyps(self.ended_hyps)
        else:
            for hyp in self.ended_hyps:
                local_ended_hyps.append(hyp)
            rets = self.assemble_hyps(local_ended_hyps)

            # roll back one position so the next block rescores the last step
            if self.process_idx > 1 and len(self.prev_hyps) > 0:
                self.running_hyps = self.prev_hyps
                self.process_idx -= 1
                self.prev_hyps = []

            # N-best results
            return rets

    def assemble_hyps(self, ended_hyps):
        """Assemble the hypotheses."""
        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            return []

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps

    def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
        """Extend probabilities and states with more encoded chunks.

        Args:
            x (torch.Tensor): The extended encoder output feature
            hyps (Hypothesis): Current list of hypothesis

        Returns:
            Hypothesis: The extended hypothesis
        """
        for k, d in self.scorers.items():
            if hasattr(d, "extend_prob"):
                d.extend_prob(x)
            if hasattr(d, "extend_state"):
                hyps.states[k] = d.extend_state(hyps.states[k])
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/batch_beam_search_online_sim.py
0 → 100644
View file @
60a2c57a
"""Parallel beam search module for online simulation."""
import
logging
from
pathlib
import
Path
from
typing
import
List
import
torch
import
yaml
from
espnet.nets.batch_beam_search
import
BatchBeamSearch
from
espnet.nets.beam_search
import
Hypothesis
from
espnet.nets.e2e_asr_common
import
end_detect
class BatchBeamSearchOnlineSim(BatchBeamSearch):
    """Online beam search implementation.

    This simulates streaming decoding.
    It requires encoded features of entire utterance and
    extracts block by block from it as it should be done
    in streaming processing.
    This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR
    WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH"
    (https://arxiv.org/abs/2006.14941).
    """

    def set_streaming_config(self, asr_config: str):
        """Set config file for streaming decoding.

        Args:
            asr_config (str): The config file for asr training
        """
        train_config_file = Path(asr_config)
        self.block_size = None
        self.hop_size = None
        self.look_ahead = None
        config = None
        with train_config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
            if "encoder_conf" in args.keys():
                if "block_size" in args["encoder_conf"].keys():
                    self.block_size = args["encoder_conf"]["block_size"]
                if "hop_size" in args["encoder_conf"].keys():
                    self.hop_size = args["encoder_conf"]["hop_size"]
                if "look_ahead" in args["encoder_conf"].keys():
                    self.look_ahead = args["encoder_conf"]["look_ahead"]
            elif "config" in args.keys():
                # the training config may point at another config file
                config = args["config"]
                if config is None:
                    logging.info(
                        "Cannot find config file for streaming decoding: "
                        + "apply batch beam search instead."
                    )
                    return
        if (
            self.block_size is None or self.hop_size is None or self.look_ahead is None
        ) and config is not None:
            config_file = Path(config)
            with config_file.open("r", encoding="utf-8") as f:
                args = yaml.safe_load(f)
            if "encoder_conf" in args.keys():
                enc_args = args["encoder_conf"]
            # NOTE(review): `enc_args` is unbound if "encoder_conf" is missing
            # from the referenced config -- confirm upstream guarantees it.
            if enc_args and "block_size" in enc_args:
                self.block_size = enc_args["block_size"]
            if enc_args and "hop_size" in enc_args:
                self.hop_size = enc_args["hop_size"]
            if enc_args and "look_ahead" in enc_args:
                self.look_ahead = enc_args["look_ahead"]

    def set_block_size(self, block_size: int):
        """Set block size for streaming decoding.

        Args:
            block_size (int): The block size of encoder
        """
        self.block_size = block_size

    def set_hop_size(self, hop_size: int):
        """Set hop size for streaming decoding.

        Args:
            hop_size (int): The hop size of encoder
        """
        self.hop_size = hop_size

    def set_look_ahead(self, look_ahead: int):
        """Set look ahead size for streaming decoding.

        Args:
            look_ahead (int): The look ahead size of encoder
        """
        self.look_ahead = look_ahead

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results
        """
        self.conservative = True  # always true

        if self.block_size and self.hop_size and self.look_ahead:
            cur_end_frame = int(self.block_size - self.look_ahead)
        else:
            # no streaming config: decode over the full utterance at once
            cur_end_frame = x.shape[0]
        process_idx = 0
        if cur_end_frame < x.shape[0]:
            h = x.narrow(0, 0, cur_end_frame)
        else:
            h = x

        # set length bounds
        if maxlenratio == 0:
            maxlen = x.shape[0]
        else:
            maxlen = max(1, int(maxlenratio * x.size(0)))
        minlen = int(minlenratio * x.size(0))
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(h)
        prev_hyps = []
        ended_hyps = []
        prev_repeat = False

        continue_decode = True

        while continue_decode:
            move_to_next_block = False
            if cur_end_frame < x.shape[0]:
                h = x.narrow(0, 0, cur_end_frame)
            else:
                h = x

            # extend states for ctc
            self.extend(h, running_hyps)

            while process_idx < maxlen:
                logging.debug("position " + str(process_idx))
                best = self.search(running_hyps, h)

                if process_idx == maxlen - 1:
                    # end decoding
                    running_hyps = self.post_process(
                        process_idx, maxlen, maxlenratio, best, ended_hyps
                    )
                n_batch = best.yseq.shape[0]
                local_ended_hyps = []
                is_local_eos = (
                    best.yseq[torch.arange(n_batch), best.length - 1] == self.eos
                )
                for i in range(is_local_eos.shape[0]):
                    if is_local_eos[i]:
                        hyp = self._select(best, i)
                        local_ended_hyps.append(hyp)
                    # NOTE(tsunoo): check repetitions here
                    # This is a implicit implementation of
                    # Eq (11) in https://arxiv.org/abs/2006.14941
                    # A flag prev_repeat is used instead of using set
                    elif (
                        not prev_repeat
                        and best.yseq[i, -1] in best.yseq[i, :-1]
                        and cur_end_frame < x.shape[0]
                    ):
                        move_to_next_block = True
                        prev_repeat = True
                if maxlenratio == 0.0 and end_detect(
                    [lh.asdict() for lh in local_ended_hyps], process_idx
                ):
                    logging.info(f"end detected at {process_idx}")
                    continue_decode = False
                    break
                if len(local_ended_hyps) > 0 and cur_end_frame < x.shape[0]:
                    # an EOS before the last block just moves the window forward
                    move_to_next_block = True

                if move_to_next_block:
                    if (
                        self.hop_size
                        and cur_end_frame + int(self.hop_size) + int(self.look_ahead)
                        < x.shape[0]
                    ):
                        cur_end_frame += int(self.hop_size)
                    else:
                        cur_end_frame = x.shape[0]
                    logging.debug("Going to next block: %d", cur_end_frame)
                    # conservative roll-back: rescore the previous position
                    # with the larger block
                    if process_idx > 1 and len(prev_hyps) > 0 and self.conservative:
                        running_hyps = prev_hyps
                        process_idx -= 1
                        prev_hyps = []
                    break

                prev_repeat = False
                prev_hyps = running_hyps
                running_hyps = self.post_process(
                    process_idx, maxlen, maxlenratio, best, ended_hyps
                )

                if cur_end_frame >= x.shape[0]:
                    for hyp in local_ended_hyps:
                        ended_hyps.append(hyp)

                if len(running_hyps) == 0:
                    logging.info("no hypothesis. Finish decoding.")
                    continue_decode = False
                    break
                else:
                    logging.debug(f"remained hypotheses: {len(running_hyps)}")
                # increment number
                process_idx += 1

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            # retry with a relaxed minimum length, giving up below 0.1
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
        if best.yseq[1:-1].shape[0] == x.shape[0]:
            logging.warning(
                "best hypo length: {} == max output length: {}".format(
                    best.yseq[1:-1].shape[0], maxlen
                )
            )
            logging.warning(
                "decoding may be stopped by the max output length limitation, "
                + "please consider to increase the maxlenratio."
            )
        return nbest_hyps

    def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
        """Extend probabilities and states with more encoded chunks.

        Args:
            x (torch.Tensor): The extended encoder output feature
            hyps (Hypothesis): Current list of hypothesis

        Returns:
            Hypothesis: The extended hypothesis
        """
        for k, d in self.scorers.items():
            if hasattr(d, "extend_prob"):
                d.extend_prob(x)
            if hasattr(d, "extend_state"):
                hyps.states[k] = d.extend_state(hyps.states[k])
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/beam_search.py
0 → 100644
View file @
60a2c57a
"""Beam search module."""
import
logging
from
itertools
import
chain
from
typing
import
Any
,
Dict
,
List
,
NamedTuple
,
Tuple
,
Union
import
torch
from
espnet.nets.e2e_asr_common
import
end_detect
from
espnet.nets.scorer_interface
import
PartialScorerInterface
,
ScorerInterface
class Hypothesis(NamedTuple):
    """Hypothesis data type."""

    yseq: torch.Tensor
    score: Union[float, torch.Tensor] = 0
    scores: Dict[str, Union[float, torch.Tensor]] = dict()
    states: Dict[str, Any] = dict()

    def asdict(self) -> dict:
        """Convert data to JSON-friendly dict."""
        # Replace tensor fields with plain Python values, then dump to dict.
        json_ready = self._replace(
            yseq=self.yseq.tolist(),
            score=float(self.score),
            scores={name: float(val) for name, val in self.scores.items()},
        )
        return json_ready._asdict()
class
BeamSearch
(
torch
.
nn
.
Module
):
"""Beam search implementation."""
    def __init__(
        self,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        beam_size: int,
        vocab_size: int,
        sos: int,
        eos: int,
        token_list: List[str] = None,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
        hyp_primer: List[int] = None,
    ):
        """Initialize beam search.

        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules
                e.g., Decoder, CTCPrefixScorer, LM
                The scorer will be ignored if it is `None`
            weights (dict[str, float]): Dict of weights for each scorers
                The scorer will be ignored if its weight is 0
            beam_size (int): The number of hypotheses kept during search
            vocab_size (int): The number of vocabulary
            sos (int): Start of sequence id
            eos (int): End of sequence id
            token_list (list[str]): List of tokens for debug log
            pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): beam size in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`
        """
        super().__init__()
        # set scorers
        self.weights = weights
        self.scorers = dict()
        self.full_scorers = dict()
        self.part_scorers = dict()
        # this module dict is required for recursive cast
        # `self.to(device, dtype)` in `recog.py`
        self.nn_dict = torch.nn.ModuleDict()
        for k, v in scorers.items():
            # skip disabled scorers (missing/zero weight or None module)
            w = weights.get(k, 0)
            if w == 0 or v is None:
                continue
            assert isinstance(
                v, ScorerInterface
            ), f"{k} ({type(v)}) does not implement ScorerInterface"
            self.scorers[k] = v
            # partial scorers (e.g. CTC prefix scorer) are applied only to
            # the pre-beam token subset; full scorers score the whole vocab
            if isinstance(v, PartialScorerInterface):
                self.part_scorers[k] = v
            else:
                self.full_scorers[k] = v
            if isinstance(v, torch.nn.Module):
                self.nn_dict[k] = v
        # set configurations
        self.sos = sos
        self.eos = eos
        # added for OpenAI Whisper decoding
        self.hyp_primer = hyp_primer
        self.token_list = token_list
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
        self.n_vocab = vocab_size
        if (
            pre_beam_score_key is not None
            and pre_beam_score_key != "full"
            and pre_beam_score_key not in self.full_scorers
        ):
            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
        self.pre_beam_score_key = pre_beam_score_key
        # pre-beam is only useful when it actually prunes and there is at
        # least one partial scorer to restrict
        self.do_pre_beam = (
            self.pre_beam_score_key is not None
            and self.pre_beam_size < self.n_vocab
            and len(self.part_scorers) > 0
        )
def
set_hyp_primer
(
self
,
hyp_primer
:
List
[
int
]
=
None
)
->
None
:
"""Set the primer sequence for decoding.
Used for OpenAI Whisper models.
"""
self
.
hyp_primer
=
hyp_primer
def
init_hyp
(
self
,
x
:
torch
.
Tensor
)
->
List
[
Hypothesis
]:
"""Get an initial hypothesis data.
Args:
x (torch.Tensor): The encoder output feature
Returns:
Hypothesis: The initial hypothesis.
"""
init_states
=
dict
()
init_scores
=
dict
()
for
k
,
d
in
self
.
scorers
.
items
():
init_states
[
k
]
=
d
.
init_state
(
x
)
init_scores
[
k
]
=
0.0
# NOTE (Shih-Lun): added for OpenAI Whisper ASR
primer
=
[
self
.
sos
]
if
self
.
hyp_primer
is
None
else
self
.
hyp_primer
return
[
Hypothesis
(
score
=
0.0
,
scores
=
init_scores
,
states
=
init_states
,
yseq
=
torch
.
tensor
(
primer
,
device
=
x
.
device
),
)
]
@
staticmethod
def
append_token
(
xs
:
torch
.
Tensor
,
x
:
int
)
->
torch
.
Tensor
:
"""Append new token to prefix tokens.
Args:
xs (torch.Tensor): The prefix token
x (int): The new token to append
Returns:
torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
"""
x
=
torch
.
tensor
([
x
],
dtype
=
xs
.
dtype
,
device
=
xs
.
device
)
return
torch
.
cat
((
xs
,
x
))
def
score_full
(
self
,
hyp
:
Hypothesis
,
x
:
torch
.
Tensor
)
->
Tuple
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
str
,
Any
]]:
"""Score new hypothesis by `self.full_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.full_scorers`
and tensor score values of shape: `(self.n_vocab,)`,
and state dict that has string keys
and state values of `self.full_scorers`
"""
scores
=
dict
()
states
=
dict
()
for
k
,
d
in
self
.
full_scorers
.
items
():
scores
[
k
],
states
[
k
]
=
d
.
score
(
hyp
.
yseq
,
hyp
.
states
[
k
],
x
)
return
scores
,
states
def
score_partial
(
self
,
hyp
:
Hypothesis
,
ids
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
Tuple
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
str
,
Any
]]:
"""Score new hypothesis by `self.part_scorers`.
Args:
hyp (Hypothesis): Hypothesis with prefix tokens to score
ids (torch.Tensor): 1D tensor of new partial tokens to score
x (torch.Tensor): Corresponding input feature
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
score dict of `hyp` that has string keys of `self.part_scorers`
and tensor score values of shape: `(len(ids),)`,
and state dict that has string keys
and state values of `self.part_scorers`
"""
scores
=
dict
()
states
=
dict
()
for
k
,
d
in
self
.
part_scorers
.
items
():
scores
[
k
],
states
[
k
]
=
d
.
score_partial
(
hyp
.
yseq
,
ids
,
hyp
.
states
[
k
],
x
)
return
scores
,
states
    def beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute topk full token ids and partial token ids.

        NOTE: when pre-beam pruning is active (the sizes differ), this method
        mutates `weighted_scores` in place: entries outside `ids` are set to
        -inf before the top-k selection.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
                Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(self.beam_size,)`
        """
        # no pre beam performed
        if weighted_scores.size(0) == ids.size(0):
            top_ids = weighted_scores.topk(self.beam_size)[1]
            return top_ids, top_ids
        # mask pruned in pre-beam not to select in topk
        tmp = weighted_scores[ids]
        weighted_scores[:] = -float("inf")
        weighted_scores[ids] = tmp
        top_ids = weighted_scores.topk(self.beam_size)[1]
        # local_ids index into `ids` (positions within the pre-beam subset)
        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
        return top_ids, local_ids
@staticmethod
def merge_scores(
    prev_scores: Dict[str, float],
    next_full_scores: Dict[str, torch.Tensor],
    full_idx: int,
    next_part_scores: Dict[str, torch.Tensor],
    part_idx: int,
) -> Dict[str, torch.Tensor]:
    """Accumulate per-scorer scores for a newly extended hypothesis.

    Args:
        prev_scores (Dict[str, float]):
            The previous hypothesis scores keyed by scorer name.
        next_full_scores (Dict[str, torch.Tensor]): Scores produced by
            the full scorers over the whole vocabulary.
        full_idx (int): The next token id used to index `next_full_scores`.
        next_part_scores (Dict[str, torch.Tensor]):
            Scores of partial tokens produced by the partial scorers.
        part_idx (int): The new token id used to index `next_part_scores`.

    Returns:
        Dict[str, torch.Tensor]: The new score dict whose keys are the
            names of the full and partial scorers, and whose values are
            scalar tensors.
    """
    full = {
        name: prev_scores[name] + v[full_idx]
        for name, v in next_full_scores.items()
    }
    partial = {
        name: prev_scores[name] + v[part_idx]
        for name, v in next_part_scores.items()
    }
    # Partial-scorer entries overwrite full-scorer entries on key collision,
    # matching the sequential update order of the two loops it replaces.
    return {**full, **partial}
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
    """Combine full-scorer states with the selected partial-scorer states.

    Args:
        states: State dict produced by `self.full_scorers`.
        part_states: State dict produced by `self.part_scorers`.
        part_idx (int): The new token id within the partial candidates.

    Returns:
        Dict[str, torch.Tensor]: New state dict whose keys are the names
            of the full and partial scorers and whose values are the
            scorers' states.
    """
    merged = dict(states)
    merged.update(
        (name, scorer.select_state(part_states[name], part_idx))
        for name, scorer in self.part_scorers.items()
    )
    return merged
def search(self, running_hyps: List[Hypothesis], x: torch.Tensor) -> List[Hypothesis]:
    """Search new tokens for running hypotheses and encoded speech x.

    Args:
        running_hyps (List[Hypothesis]): Running hypotheses on beam
        x (torch.Tensor): Encoded speech feature (T, D)

    Returns:
        List[Hypotheses]: Best sorted hypotheses
    """
    best_hyps = []
    # Default partial-token candidates: the whole vocabulary (no pre-beam).
    part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
    for hyp in running_hyps:
        # scoring: accumulate the weighted full-vocabulary scores of each
        # full scorer for this hypothesis.
        weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
        scores, states = self.score_full(hyp, x)
        for k in self.full_scorers:
            weighted_scores += self.weights[k] * scores[k]
        # partial scoring: optionally restrict candidates with a pre-beam
        # before running the (more expensive) partial scorers.
        if self.do_pre_beam:
            pre_beam_scores = (
                weighted_scores
                if self.pre_beam_score_key == "full"
                else scores[self.pre_beam_score_key]
            )
            part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
        part_scores, part_states = self.score_partial(hyp, part_ids, x)
        for k in self.part_scorers:
            # Only the pre-beam candidates receive partial-scorer weight.
            weighted_scores[part_ids] += self.weights[k] * part_scores[k]
        # add previous hyp score so every candidate carries the full path score
        weighted_scores += hyp.score

        # update hyps: `self.beam` yields paired indices into the full
        # vocabulary (j) and into `part_ids` (part_j).
        for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
            # will be (2 x beam at most)
            best_hyps.append(
                Hypothesis(
                    score=weighted_scores[j],
                    yseq=self.append_token(hyp.yseq, j),
                    scores=self.merge_scores(
                        hyp.scores, scores, j, part_scores, part_j
                    ),
                    states=self.merge_states(states, part_states, part_j),
                )
            )

        # sort and prune 2 x beam -> beam (done per hypothesis to bound memory)
        best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
            : min(len(best_hyps), self.beam_size)
        ]
    return best_hyps
def forward(
    self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
) -> List[Hypothesis]:
    """Perform beam search.

    Args:
        x (torch.Tensor): Encoded speech feature (T, D)
        maxlenratio (float): Input length ratio to obtain max output length.
            If maxlenratio=0.0 (default), it uses a end-detect function
            to automatically find maximum hypothesis lengths
            If maxlenratio<0.0, its absolute value is interpreted
            as a constant max output length.
        minlenratio (float): Input length ratio to obtain min output length.
            If minlenratio<0.0, its absolute value is interpreted
            as a constant min output length.

    Returns:
        list[Hypothesis]: N-best decoding results
    """
    # set length bounds derived from the ratios (0 = input length,
    # negative = absolute constant, positive = ratio of input length)
    if maxlenratio == 0:
        maxlen = x.shape[0]
    elif maxlenratio < 0:
        maxlen = -1 * int(maxlenratio)
    else:
        maxlen = max(1, int(maxlenratio * x.size(0)))
    if minlenratio < 0:
        minlen = -1 * int(minlenratio)
    else:
        minlen = int(minlenratio * x.size(0))
    logging.info("decoder input length: " + str(x.shape[0]))
    logging.info("max output length: " + str(maxlen))
    logging.info("min output length: " + str(minlen))

    # main loop of prefix search
    running_hyps = self.init_hyp(x)
    ended_hyps = []
    for i in range(maxlen):
        logging.debug("position " + str(i))
        best = self.search(running_hyps, x)
        # post process of one iteration: move finished hyps to ended_hyps
        running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
        # end detection (only in automatic-length mode)
        if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
            logging.info(f"end detected at {i}")
            break
        if len(running_hyps) == 0:
            logging.info("no hypothesis. Finish decoding.")
            break
        else:
            logging.debug(f"remained hypotheses: {len(running_hyps)}")

    nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
    # check the number of hypotheses reaching to eos
    if len(nbest_hyps) == 0:
        logging.warning(
            "there is no N-best results, perform recognition "
            "again with smaller minlenratio."
        )
        # Retry recursively with a relaxed minimum length; give up at < 0.1.
        return (
            []
            if minlenratio < 0.1
            else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
        )

    # report the best result
    best = nbest_hyps[0]
    for k, v in best.scores.items():
        logging.info(
            f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
        )
    logging.info(f"total log probability: {best.score:.2f}")
    logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
    logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
    if self.token_list is not None:
        logging.info(
            "best hypo: "
            + "".join([self.token_list[x] for x in best.yseq[1:-1]])
            + "\n"
        )
    # Warn when the winning hypothesis hit the length cap (excluding sos/eos).
    if best.yseq[1:-1].shape[0] == maxlen:
        logging.warning(
            "best hypo length: {} == max output length: {}".format(
                best.yseq[1:-1].shape[0], maxlen
            )
        )
        logging.warning(
            "decoding may be stopped by the max output length limitation, "
            + "please consider to increase the maxlenratio."
        )
    return nbest_hyps
def post_process(
    self,
    i: int,
    maxlen: int,
    maxlenratio: float,
    running_hyps: List[Hypothesis],
    ended_hyps: List[Hypothesis],
) -> List[Hypothesis]:
    """Perform post-processing of beam search iterations.

    Args:
        i (int): The length of hypothesis tokens.
        maxlen (int): The maximum length of tokens in beam search.
        maxlenratio (int): The maximum length ratio in beam search.
        running_hyps (List[Hypothesis]): The running hypotheses in beam search.
        ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
            Appended to in place.

    Returns:
        List[Hypothesis]: The new running hypotheses.
    """
    logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
    if self.token_list is not None:
        logging.debug(
            "best hypo: "
            + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
        )
    # add eos in the final loop to avoid that there are no ended hyps
    if i == maxlen - 1:
        logging.info("adding <eos> in the last position in the loop")
        running_hyps = [
            h._replace(yseq=self.append_token(h.yseq, self.eos))
            for h in running_hyps
        ]

    # add ended hypotheses to a final list, and removed them from current hypotheses
    # (this will be a problem, number of hyps < beam)
    remained_hyps = []
    for hyp in running_hyps:
        if hyp.yseq[-1] == self.eos:
            # e.g., Word LM needs to add final <eos> score
            for k, d in chain(
                self.full_scorers.items(), self.part_scorers.items()
            ):
                s = d.final_score(hyp.states[k])
                # NOTE: hyp.scores is mutated in place while the total score
                # goes through _replace; both refer to the same hypothesis.
                hyp.scores[k] += s
                hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
            ended_hyps.append(hyp)
        else:
            remained_hyps.append(hyp)
    return remained_hyps
def beam_search(
    x: torch.Tensor,
    sos: int,
    eos: int,
    beam_size: int,
    vocab_size: int,
    scorers: Dict[str, ScorerInterface],
    weights: Dict[str, float],
    token_list: List[str] = None,
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
    pre_beam_ratio: float = 1.5,
    pre_beam_score_key: str = "full",
) -> list:
    """Run beam search over one encoded utterance with the given scorers.

    Args:
        x (torch.Tensor): Encoded speech feature (T, D)
        sos (int): Start of sequence id
        eos (int): End of sequence id
        beam_size (int): The number of hypotheses kept during search
        vocab_size (int): The number of vocabulary
        scorers (dict[str, ScorerInterface]): Dict of decoder modules
            e.g., Decoder, CTCPrefixScorer, LM.
            A scorer is ignored if it is `None`.
        weights (dict[str, float]): Dict of weights for each scorer.
            A scorer is ignored if its weight is 0.
        token_list (list[str]): List of tokens for debug log
        maxlenratio (float): Input length ratio to obtain max output length.
            If maxlenratio=0.0 (default), it uses a end-detect function
            to automatically find maximum hypothesis lengths
        minlenratio (float): Input length ratio to obtain min output length.
        pre_beam_score_key (str): key of scores to perform pre-beam search
        pre_beam_ratio (float): beam size in the pre-beam search
            will be `int(pre_beam_ratio * beam_size)`

    Returns:
        list: N-best decoding results as dicts.
    """
    searcher = BeamSearch(
        scorers,
        weights,
        beam_size=beam_size,
        vocab_size=vocab_size,
        pre_beam_ratio=pre_beam_ratio,
        pre_beam_score_key=pre_beam_score_key,
        sos=sos,
        eos=eos,
        token_list=token_list,
    )
    nbest = searcher.forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio)
    return [hyp.asdict() for hyp in nbest]
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/beam_search_timesync.py
0 → 100644
View file @
60a2c57a
"""
Time Synchronous One-Pass Beam Search.
Implements joint CTC/attention decoding where
hypotheses are expanded along the time (input) axis,
as described in https://arxiv.org/abs/2210.05200.
Supports CPU and GPU inference.
References: https://arxiv.org/abs/1408.2873 for CTC beam search
Author: Brian Yan
"""
import
logging
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Tuple
import
numpy
as
np
import
torch
from
espnet.nets.beam_search
import
Hypothesis
from
espnet.nets.scorer_interface
import
ScorerInterface
@dataclass
class CacheItem:
    """For caching attentional decoder and LM states."""

    # Scorer state at the cached prefix (opaque; owned by the scorer).
    state: Any
    # Token scores produced at the cached prefix (opaque; owned by the scorer).
    scores: Any
    # Accumulated log probability of the cached prefix.
    log_sum: float
class BeamSearchTimeSync(torch.nn.Module):
    """Time synchronous beam search algorithm."""

    def __init__(
        self,
        sos: int,
        beam_size: int,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        token_list=dict,  # NOTE(review): default is the `dict` type itself, not an instance — TODO confirm callers always pass a real token list
        pre_beam_ratio: float = 1.5,
        blank: int = 0,
        force_lid: bool = False,
        temp: float = 1.0,
    ):
        """Initialize beam search.

        Args:
            beam_size: num hyps
            sos: sos index
            ctc: CTC module
            pre_beam_ratio: pre_beam_ratio * beam_size = pre_beam
                pre_beam is used to select candidates from vocab to extend hypotheses
            decoder: decoder ScorerInterface
            ctc_weight: ctc_weight
            blank: blank index
        """
        super().__init__()
        # "ctc" and "decoder" scorers are mandatory; "lm" is optional.
        self.ctc = scorers["ctc"]
        self.decoder = scorers["decoder"]
        self.lm = scorers["lm"] if "lm" in scorers else None
        self.beam_size = beam_size
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.ctc_weight = weights["ctc"]
        self.lm_weight = weights["lm"]
        self.decoder_weight = weights["decoder"]
        self.penalty = weights["length_bonus"]
        self.sos = sos
        self.sos_th = torch.tensor([self.sos])
        self.blank = blank
        self.attn_cache = dict()  # cache for p_attn(Y|X)
        self.lm_cache = dict()  # cache for p_lm(Y)
        self.enc_output = None  # log p_ctc(Z|X)
        self.force_lid = force_lid
        self.temp = temp
        self.token_list = token_list

    def reset(self, enc_output: torch.Tensor):
        """Reset object for a new utterance."""
        # Drop all prefix caches and prime them with the <sos>-only prefix.
        self.attn_cache = dict()
        self.lm_cache = dict()
        self.enc_output = enc_output
        self.sos_th = self.sos_th.to(enc_output.device)
        if self.decoder is not None:
            init_decoder_state = self.decoder.init_state(enc_output)
            decoder_scores, decoder_state = self.decoder.score(
                self.sos_th, init_decoder_state, enc_output
            )
            self.attn_cache[(self.sos,)] = CacheItem(
                state=decoder_state,
                scores=decoder_scores,
                log_sum=0.0,
            )
        if self.lm is not None:
            init_lm_state = self.lm.init_state(enc_output)
            lm_scores, lm_state = self.lm.score(
                self.sos_th, init_lm_state, enc_output
            )
            self.lm_cache[(self.sos,)] = CacheItem(
                state=lm_state,
                scores=lm_scores,
                log_sum=0.0,
            )

    def cached_score(
        self, h: Tuple[int], cache: dict, scorer: ScorerInterface
    ) -> Any:
        """Retrieve decoder/LM scores which may be cached."""
        root = h[:-1]  # prefix
        if root in cache:
            root_scores = cache[root].scores
            root_state = cache[root].state
            root_log_sum = cache[root].log_sum
        else:
            # run decoder fwd one step and update cache
            # (the grandparent prefix must already be cached)
            root_root = root[:-1]
            root_root_state = cache[root_root].state
            root_scores, root_state = scorer.score(
                torch.tensor(root, device=self.enc_output.device).long(),
                root_root_state,
                self.enc_output,
            )
            root_log_sum = cache[root_root].log_sum + float(
                cache[root_root].scores[root[-1]]
            )
            cache[root] = CacheItem(
                state=root_state, scores=root_scores, log_sum=root_log_sum
            )
        cand_score = float(root_scores[h[-1]])
        score = root_log_sum + cand_score
        return score

    def joint_score(self, hyps: Any, ctc_score_dp: Any) -> Any:
        """Calculate joint score for hyps."""
        scores = dict()
        for h in hyps:
            # ctc score: total probability = logaddexp(non-blank, blank) paths
            score = self.ctc_weight * np.logaddexp(*ctc_score_dp[h])
            if len(h) > 1 and self.decoder_weight > 0 and self.decoder is not None:
                # attn score
                score += (
                    self.cached_score(h, self.attn_cache, self.decoder)
                    * self.decoder_weight
                )
            if len(h) > 1 and self.lm is not None and self.lm_weight > 0:
                # lm score
                score += (
                    self.cached_score(h, self.lm_cache, self.lm) * self.lm_weight
                )
            # penalty score (length bonus, excluding <sos>)
            score += self.penalty * (len(h) - 1)
            scores[h] = score
        return scores

    def time_step(self, p_ctc: Any, ctc_score_dp: Any, hyps: Any) -> Any:
        """Execute a single time step."""
        # Pre-beam: keep only tokens whose CTC log-prob clears the
        # pre_beam_size-th largest value at this frame.
        pre_beam_threshold = np.sort(p_ctc)[-self.pre_beam_size]
        cands = set(np.where(p_ctc >= pre_beam_threshold)[0])
        if len(cands) == 0:
            cands = {np.argmax(p_ctc)}
        new_hyps = set()
        # (p_nb, p_b): log-probs of ending in non-blank / blank per prefix.
        ctc_score_dp_next = defaultdict(
            lambda: (float("-inf"), float("-inf"))
        )
        tmp = []  # NOTE(review): appended to but never read in this method
        for hyp_l in hyps:
            p_prev_l = np.logaddexp(*ctc_score_dp[hyp_l])
            for c in cands:
                if c == self.blank:
                    logging.debug("blank cand, hypothesis is " + str(hyp_l))
                    p_nb, p_b = ctc_score_dp_next[hyp_l]
                    p_b = np.logaddexp(p_b, p_ctc[c] + p_prev_l)
                    ctc_score_dp_next[hyp_l] = (p_nb, p_b)
                    new_hyps.add(hyp_l)
                else:
                    l_plus = hyp_l + (int(c),)
                    logging.debug("non-blank cand, hypothesis is " + str(l_plus))
                    p_nb, p_b = ctc_score_dp_next[l_plus]
                    if c == hyp_l[-1]:
                        # Repeated label: extension is only valid through a
                        # blank; the repeat without blank stays on hyp_l.
                        logging.debug("repeat cand, hypothesis is " + str(hyp_l))
                        p_nb_prev, p_b_prev = ctc_score_dp[hyp_l]
                        p_nb = np.logaddexp(p_nb, p_ctc[c] + p_b_prev)
                        p_nb_l, p_b_l = ctc_score_dp_next[hyp_l]
                        p_nb_l = np.logaddexp(p_nb_l, p_ctc[c] + p_nb_prev)
                        ctc_score_dp_next[hyp_l] = (p_nb_l, p_b_l)
                    else:
                        p_nb = np.logaddexp(p_nb, p_ctc[c] + p_prev_l)
                        if l_plus not in hyps and l_plus in ctc_score_dp:
                            # l_plus was pruned earlier but has DP mass:
                            # merge its previous scores back in.
                            p_b = np.logaddexp(
                                p_b,
                                p_ctc[self.blank]
                                + np.logaddexp(*ctc_score_dp[l_plus]),
                            )
                            p_nb = np.logaddexp(
                                p_nb, p_ctc[c] + ctc_score_dp[l_plus][0]
                            )
                            tmp.append(l_plus)
                    ctc_score_dp_next[l_plus] = (p_nb, p_b)
                    new_hyps.add(l_plus)
        scores = self.joint_score(new_hyps, ctc_score_dp_next)
        hyps = sorted(new_hyps, key=lambda ll: scores[ll], reverse=True)[
            : self.beam_size
        ]
        ctc_score_dp = ctc_score_dp_next.copy()
        return ctc_score_dp, hyps, scores

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            enc_output (torch.Tensor)

        Return:
            list[Hypothesis]
        """
        logging.info("decoder input lengths: " + str(x.shape[0]))
        # CTC posteriors as a (T, V) numpy array of log-probabilities.
        lpz = self.ctc.log_softmax(x.unsqueeze(0))
        lpz = lpz.squeeze(0)
        lpz = lpz.cpu().detach().numpy()
        self.reset(x)
        hyps = [(self.sos,)]
        # (p_nb, p_b) - dp object tracking p_ctc
        ctc_score_dp = defaultdict(
            lambda: (float("-inf"), float("-inf"))
        )
        # The empty output has probability 1 of ending in blank.
        ctc_score_dp[(self.sos,)] = (float("-inf"), 0.0)
        for t in range(lpz.shape[0]):
            logging.debug("position " + str(t))
            ctc_score_dp, hyps, scores = self.time_step(
                lpz[t, :], ctc_score_dp, hyps
            )
        ret = [
            Hypothesis(yseq=torch.tensor(list(h) + [self.sos]), score=scores[h])
            for h in hyps
        ]
        best_hyp = "".join([self.token_list[x] for x in ret[0].yseq.tolist()])
        best_hyp_len = len(ret[0].yseq)
        best_score = ret[0].score
        logging.info(f"output length: {best_hyp_len}")
        logging.info(f"total log probability: {best_score:.2f}")
        logging.info(f"best hypo: {best_hyp}")
        return ret
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/beam_search_transducer.py
0 → 100644
View file @
60a2c57a
"""Search algorithms for Transducer models."""
import
logging
from
typing
import
List
,
Union
import
numpy
as
np
import
torch
from
espnet.nets.pytorch_backend.transducer.custom_decoder
import
CustomDecoder
from
espnet.nets.pytorch_backend.transducer.joint_network
import
JointNetwork
from
espnet.nets.pytorch_backend.transducer.rnn_decoder
import
RNNDecoder
from
espnet.nets.pytorch_backend.transducer.utils
import
(
create_lm_batch_states
,
init_lm_state
,
is_prefix
,
recombine_hyps
,
select_k_expansions
,
select_lm_state
,
subtract
,
)
from
espnet.nets.transducer_decoder_interface
import
ExtendedHypothesis
,
Hypothesis
class
BeamSearchTransducer
:
"""Beam search implementation for Transducer."""
def __init__(
    self,
    decoder: Union[RNNDecoder, CustomDecoder],
    joint_network: JointNetwork,
    beam_size: int,
    lm: torch.nn.Module = None,
    lm_weight: float = 0.1,
    search_type: str = "default",
    max_sym_exp: int = 2,
    u_max: int = 50,
    nstep: int = 1,
    prefix_alpha: int = 1,
    expansion_gamma: float = 2.3,
    expansion_beta: int = 2,
    score_norm: bool = True,
    softmax_temperature: float = 1.0,
    nbest: int = 1,
    quantization: bool = False,
):
    """Initialize Transducer search module.

    Args:
        decoder: Decoder module.
        joint_network: Joint network module.
        beam_size: Beam size.
        lm: LM class.
        lm_weight: LM weight for soft fusion.
        search_type: Search algorithm to use during inference.
        max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
        u_max: Maximum output sequence length. (ALSD)
        nstep: Number of maximum expansion steps at each time step. (NSC/mAES)
        prefix_alpha: Maximum prefix length in prefix search. (NSC/mAES)
        expansion_beta:
            Number of additional candidates for expanded hypotheses selection. (mAES)
        expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
        score_norm: Normalize final scores by length. ("default")
        softmax_temperature: Penalization term for softmax function.
        nbest: Number of final hypothesis.
        quantization: Whether dynamic quantization is used.
    """
    self.decoder = decoder
    self.joint_network = joint_network
    self.beam_size = beam_size
    self.hidden_size = decoder.dunits
    self.vocab_size = decoder.odim
    self.blank_id = decoder.blank_id
    # Select the search strategy; strategy-specific knobs are only stored
    # when the corresponding strategy is chosen.
    if self.beam_size <= 1:
        self.search_algorithm = self.greedy_search
    elif search_type == "default":
        self.search_algorithm = self.default_beam_search
    elif search_type == "tsd":
        self.max_sym_exp = max_sym_exp
        self.search_algorithm = self.time_sync_decoding
    elif search_type == "alsd":
        self.u_max = u_max
        self.search_algorithm = self.align_length_sync_decoding
    elif search_type == "nsc":
        self.nstep = nstep
        self.prefix_alpha = prefix_alpha
        self.search_algorithm = self.nsc_beam_search
    elif search_type == "maes":
        # mAES requires at least two expansion steps.
        self.nstep = nstep if nstep > 1 else 2
        self.prefix_alpha = prefix_alpha
        self.expansion_gamma = expansion_gamma
        assert self.vocab_size >= beam_size + expansion_beta, (
            "beam_size (%d) + expansion_beta (%d) "
            "should be smaller or equal to vocabulary size (%d)."
            % (beam_size, expansion_beta, self.vocab_size)
        )
        self.max_candidates = beam_size + expansion_beta
        self.search_algorithm = self.modified_adaptive_expansion_search
    else:
        raise NotImplementedError
    if lm is not None:
        self.use_lm = True
        self.lm = lm
        # A word-level LM is detected by the presence of a `wordlm` predictor.
        self.is_wordlm = True if hasattr(lm.predictor, "wordlm") else False
        self.lm_predictor = lm.predictor.wordlm if self.is_wordlm else lm.predictor
        self.lm_layers = len(self.lm_predictor.rnn)
        self.lm_weight = lm_weight
    else:
        self.use_lm = False
    if softmax_temperature > 1.0 and lm is not None:
        logging.warning(
            "Softmax temperature is not supported with LM decoding."
            "Setting softmax-temperature value to 1.0."
        )
        self.softmax_temperature = 1.0
    else:
        self.softmax_temperature = softmax_temperature
    self.quantization = quantization
    self.score_norm = score_norm
    self.nbest = nbest
def __call__(
    self, enc_out: torch.Tensor
) -> Union[List[Hypothesis], List[ExtendedHypothesis]]:
    """Run the configured search algorithm on one encoder output.

    Args:
        enc_out: Encoder output sequence. (T, D_enc)

    Returns:
        nbest_hyps: N-best decoding results
    """
    # Keep the decoder on the same device as the encoder output.
    self.decoder.set_device(enc_out.device)
    return self.search_algorithm(enc_out)
def sort_nbest(
    self, hyps: Union[List[Hypothesis], List[ExtendedHypothesis]]
) -> Union[List[Hypothesis], List[ExtendedHypothesis]]:
    """Rank hypotheses best-first and keep the N best.

    Depending on `self.score_norm`, hypotheses are ranked either by raw
    score or by score normalized by output length.

    Args:
        hyps: Hypothesis (sorted in place).

    Return:
        hyps: The `self.nbest` highest-ranked hypotheses.
    """
    if self.score_norm:
        rank = lambda h: h.score / len(h.yseq)
    else:
        rank = lambda h: h.score
    hyps.sort(key=rank, reverse=True)
    return hyps[: self.nbest]
def prefix_search(
    self, hyps: List[ExtendedHypothesis], enc_out_t: torch.Tensor
) -> List[ExtendedHypothesis]:
    """Prefix search for NSC and mAES strategies.

    Based on https://arxiv.org/pdf/1211.3711.pdf

    For every pair where one hypothesis is a prefix of another (within
    `self.prefix_alpha` extra tokens), the prefix hypothesis' probability
    mass is folded into the longer hypothesis' score in place.
    """
    for j, hyp_j in enumerate(hyps[:-1]):
        for hyp_i in hyps[(j + 1) :]:
            curr_id = len(hyp_j.yseq)
            pref_id = len(hyp_i.yseq)
            if (
                is_prefix(hyp_j.yseq, hyp_i.yseq)
                and (curr_id - pref_id) <= self.prefix_alpha
            ):
                # Probability of emitting the first extra token from hyp_i.
                logp = torch.log_softmax(
                    self.joint_network(
                        enc_out_t,
                        hyp_i.dec_out[-1],
                        quantization=self.quantization,
                    )
                    / self.softmax_temperature,
                    dim=-1,
                )
                curr_score = hyp_i.score + float(logp[hyp_j.yseq[pref_id]])
                # Walk the remaining extra tokens of hyp_j, accumulating
                # their emission probabilities.
                for k in range(pref_id, (curr_id - 1)):
                    logp = torch.log_softmax(
                        self.joint_network(
                            enc_out_t,
                            hyp_j.dec_out[k],
                            quantization=self.quantization,
                        )
                        / self.softmax_temperature,
                        dim=-1,
                    )
                    curr_score += float(logp[hyp_j.yseq[k + 1]])
                # Merge the two paths' probabilities in log space.
                hyp_j.score = np.logaddexp(hyp_j.score, curr_score)
    return hyps
def greedy_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
    """Greedy search implementation.

    Args:
        enc_out: Encoder output sequence. (T, D_enc)

    Returns:
        hyp: 1-best hypotheses.
    """
    dec_state = self.decoder.init_state(1)

    # Single running hypothesis, seeded with the blank token.
    hyp = Hypothesis(score=0.0, yseq=[self.blank_id], dec_state=dec_state)
    cache = {}

    dec_out, state, _ = self.decoder.score(hyp, cache)

    for enc_out_t in enc_out:
        logp = torch.log_softmax(
            self.joint_network(enc_out_t, dec_out, quantization=self.quantization)
            / self.softmax_temperature,
            dim=-1,
        )
        top_logp, pred = torch.max(logp, dim=-1)

        # Only advance the decoder when a non-blank symbol is emitted;
        # blank frames keep the current decoder state.
        if pred != self.blank_id:
            hyp.yseq.append(int(pred))
            hyp.score += float(top_logp)

            hyp.dec_state = state

            dec_out, state, _ = self.decoder.score(hyp, cache)

    return [hyp]
def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
    """Beam search implementation.

    Modified from https://arxiv.org/pdf/1211.3711.pdf

    Args:
        enc_out: Encoder output sequence. (T, D)

    Returns:
        nbest_hyps: N-best hypothesis.
    """
    beam = min(self.beam_size, self.vocab_size)
    # Non-blank expansions: at most vocab_size - 1 (index 0 is blank).
    beam_k = min(beam, (self.vocab_size - 1))

    dec_state = self.decoder.init_state(1)

    kept_hyps = [Hypothesis(score=0.0, yseq=[self.blank_id], dec_state=dec_state)]
    cache = {}

    for enc_out_t in enc_out:
        hyps = kept_hyps
        kept_hyps = []

        while True:
            # Always expand the currently best unexpanded hypothesis.
            max_hyp = max(hyps, key=lambda x: x.score)
            hyps.remove(max_hyp)

            dec_out, state, lm_tokens = self.decoder.score(max_hyp, cache)

            logp = torch.log_softmax(
                self.joint_network(
                    enc_out_t, dec_out, quantization=self.quantization
                )
                / self.softmax_temperature,
                dim=-1,
            )
            top_k = logp[1:].topk(beam_k, dim=-1)

            # Blank emission: hypothesis survives to the next frame unchanged.
            kept_hyps.append(
                Hypothesis(
                    score=(max_hyp.score + float(logp[0:1])),
                    yseq=max_hyp.yseq[:],
                    dec_state=max_hyp.dec_state,
                    lm_state=max_hyp.lm_state,
                )
            )

            if self.use_lm:
                lm_state, lm_scores = self.lm.predict(max_hyp.lm_state, lm_tokens)
            else:
                lm_state = max_hyp.lm_state

            # Non-blank emissions: extend within the current frame.
            # (k is an index into logp[1:], hence the +1 shift below.)
            for logp, k in zip(*top_k):
                score = max_hyp.score + float(logp)

                if self.use_lm:
                    score += self.lm_weight * lm_scores[0][k + 1]

                hyps.append(
                    Hypothesis(
                        score=score,
                        yseq=max_hyp.yseq[:] + [int(k + 1)],
                        dec_state=state,
                        lm_state=lm_state,
                    )
                )

            # Stop expanding when `beam` kept hypotheses already beat every
            # remaining unexpanded hypothesis.
            hyps_max = float(max(hyps, key=lambda x: x.score).score)
            kept_most_prob = sorted(
                [hyp for hyp in kept_hyps if hyp.score > hyps_max],
                key=lambda x: x.score,
            )
            if len(kept_most_prob) >= beam:
                kept_hyps = kept_most_prob

                break

    return self.sort_nbest(kept_hyps)
def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
    """Time synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        enc_out: Encoder output sequence. (T, D)

    Returns:
        nbest_hyps: N-best hypothesis.
    """
    beam = min(self.beam_size, self.vocab_size)

    beam_state = self.decoder.init_state(beam)

    # B: beam surviving across time steps; seeded with the blank-only hyp.
    B = [
        Hypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]
    cache = {}

    if self.use_lm and not self.is_wordlm:
        B[0].lm_state = init_lm_state(self.lm_predictor)

    for enc_out_t in enc_out:
        A = []  # hypotheses after emitting blank at this frame
        C = B  # hypotheses to expand in the current symbol step

        enc_out_t = enc_out_t.unsqueeze(0)

        # Up to max_sym_exp symbol emissions are allowed per time frame.
        for v in range(self.max_sym_exp):
            D = []

            beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                C,
                beam_state,
                cache,
                self.use_lm,
            )

            beam_logp = torch.log_softmax(
                self.joint_network(enc_out_t, beam_dec_out)
                / self.softmax_temperature,
                dim=-1,
            )
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            seq_A = [h.yseq for h in A]

            for i, hyp in enumerate(C):
                if hyp.yseq not in seq_A:
                    # First time this label sequence reaches A: add it with
                    # the blank-emission score.
                    A.append(
                        Hypothesis(
                            score=(hyp.score + float(beam_logp[i, 0])),
                            yseq=hyp.yseq[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                        )
                    )
                else:
                    # Duplicate label sequence: merge probabilities.
                    dict_pos = seq_A.index(hyp.yseq)

                    A[dict_pos].score = np.logaddexp(
                        A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                    )

            if v < (self.max_sym_exp - 1):
                if self.use_lm:
                    beam_lm_states = create_lm_batch_states(
                        [c.lm_state for c in C], self.lm_layers, self.is_wordlm
                    )

                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(C)
                    )

                for i, hyp in enumerate(C):
                    # k is shifted by +1 because index 0 (blank) was sliced
                    # off before topk.
                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]

                            new_hyp.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )

                        D.append(new_hyp)

                C = sorted(D, key=lambda x: x.score, reverse=True)[:beam]

        B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]

    return self.sort_nbest(B)
def align_length_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
    """Alignment-length synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    Args:
        h: Encoder output sequences. (T, D)

    Returns:
        nbest_hyps: N-best hypothesis.
    """
    beam = min(self.beam_size, self.vocab_size)

    t_max = int(enc_out.size(0))
    u_max = min(self.u_max, (t_max - 1))

    beam_state = self.decoder.init_state(beam)

    B = [
        Hypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]
    final = []
    cache = {}

    if self.use_lm and not self.is_wordlm:
        B[0].lm_state = init_lm_state(self.lm_predictor)

    # i indexes alignment length = time index t + output length u.
    # NOTE(review): the inner `for i, hyp in enumerate(B_)` below shadows
    # this loop variable; harmless for loop progression but confusing.
    for i in range(t_max + u_max):
        A = []

        B_ = []
        B_enc_out = []
        for hyp in B:
            u = len(hyp.yseq) - 1
            t = i - u

            # Skip hypotheses whose implied time index is past the input.
            if t > (t_max - 1):
                continue

            B_.append(hyp)
            B_enc_out.append((t, enc_out[t]))

        if B_:
            beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                B_,
                beam_state,
                cache,
                self.use_lm,
            )

            beam_enc_out = torch.stack([x[1] for x in B_enc_out])

            beam_logp = torch.log_softmax(
                self.joint_network(beam_enc_out, beam_dec_out)
                / self.softmax_temperature,
                dim=-1,
            )
            beam_topk = beam_logp[:, 1:].topk(beam, dim=-1)

            if self.use_lm:
                beam_lm_states = create_lm_batch_states(
                    [b.lm_state for b in B_], self.lm_layers, self.is_wordlm
                )

                beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(B_)
                )

            for i, hyp in enumerate(B_):
                # Blank emission: same labels, advance in time.
                new_hyp = Hypothesis(
                    score=(hyp.score + float(beam_logp[i, 0])),
                    yseq=hyp.yseq[:],
                    dec_state=hyp.dec_state,
                    lm_state=hyp.lm_state,
                )

                A.append(new_hyp)

                # Hypotheses that consumed the last frame are complete.
                if B_enc_out[i][0] == (t_max - 1):
                    final.append(new_hyp)

                # Non-blank emissions (k shifted +1: blank sliced off).
                for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                    new_hyp = Hypothesis(
                        score=(hyp.score + float(logp)),
                        yseq=(hyp.yseq[:] + [int(k)]),
                        dec_state=self.decoder.select_state(beam_state, i),
                        lm_state=hyp.lm_state,
                    )

                    if self.use_lm:
                        new_hyp.score += self.lm_weight * beam_lm_scores[i, k]

                        new_hyp.lm_state = select_lm_state(
                            beam_lm_states, i, self.lm_layers, self.is_wordlm
                        )

                    A.append(new_hyp)

            B = sorted(A, key=lambda x: x.score, reverse=True)[:beam]
            B = recombine_hyps(B)

    if final:
        return self.sort_nbest(final)
    else:
        return B
def nsc_beam_search(self, enc_out: torch.Tensor) -> List[ExtendedHypothesis]:
    """N-step constrained beam search implementation.

    Based on/Modified from https://arxiv.org/pdf/2002.03577.pdf.
    Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet
    until further modifications.

    Args:
        enc_out: Encoder output sequence. (T, D_enc)

    Returns:
        nbest_hyps: N-best hypothesis.

    """
    # Effective beam width can never exceed the vocabulary size; beam_k is
    # the number of non-blank expansions per hypothesis (vocab minus blank).
    beam = min(self.beam_size, self.vocab_size)
    beam_k = min(beam, (self.vocab_size - 1))

    beam_state = self.decoder.init_state(beam)

    # Seed hypothesis: blank-only prefix with zero score.
    init_tokens = [
        ExtendedHypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]

    # Shared decoder-output cache keyed by label prefix (filled by batch_score).
    cache = {}

    beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
        init_tokens,
        beam_state,
        cache,
        self.use_lm,
    )
    state = self.decoder.select_state(beam_state, 0)

    if self.use_lm:
        # Prime the LM with the initial token batch; keep per-hypothesis
        # LM state and next-token score vector.
        beam_lm_states, beam_lm_scores = self.lm.buff_predict(
            None, beam_lm_tokens, 1
        )
        lm_state = select_lm_state(
            beam_lm_states, 0, self.lm_layers, self.is_wordlm
        )
        lm_scores = beam_lm_scores[0]
    else:
        lm_state = None
        lm_scores = None

    kept_hyps = [
        ExtendedHypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=state,
            dec_out=[beam_dec_out[0]],
            lm_state=lm_state,
            lm_scores=lm_scores,
        )
    ]

    # Process the encoder output frame by frame (time-synchronous).
    for enc_out_t in enc_out:
        # Longest hypotheses first so prefix_search can merge prefix scores.
        hyps = self.prefix_search(
            sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True),
            enc_out_t,
        )
        kept_hyps = []

        beam_enc_out = enc_out_t.unsqueeze(0)

        # S: hypotheses closed with blank at this frame.
        # V: candidate expansions accumulated over the n-step loop.
        S = []
        V = []
        for n in range(self.nstep):
            beam_dec_out = torch.stack([hyp.dec_out[-1] for hyp in hyps])

            beam_logp = torch.log_softmax(
                self.joint_network(beam_enc_out, beam_dec_out)
                / self.softmax_temperature,
                dim=-1,
            )
            # Top-k over non-blank labels only (index 0 is blank).
            beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1)

            for i, hyp in enumerate(hyps):
                # Blank transition: hypothesis ends at this frame, goes to S.
                S.append(
                    ExtendedHypothesis(
                        yseq=hyp.yseq[:],
                        score=hyp.score + float(beam_logp[i, 0:1]),
                        dec_out=hyp.dec_out[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                        lm_scores=hyp.lm_scores,
                    )
                )

                # Non-blank expansions: +1 restores the original label index
                # after the blank column was sliced off above.
                for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                    score = hyp.score + float(logp)

                    if self.use_lm:
                        score += self.lm_weight * float(hyp.lm_scores[k])

                    V.append(
                        ExtendedHypothesis(
                            yseq=hyp.yseq[:] + [int(k)],
                            score=score,
                            dec_out=hyp.dec_out[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                            lm_scores=hyp.lm_scores,
                        )
                    )

            # Keep the best expansions that are not duplicates of the
            # current hypotheses (subtract removes matching prefixes).
            V.sort(key=lambda x: x.score, reverse=True)
            V = subtract(V, hyps)[:beam]

            # Re-batch decoder state for the surviving expansions and score
            # them to obtain the decoder output used by the next sub-step.
            beam_state = self.decoder.create_batch_states(
                beam_state,
                [v.dec_state for v in V],
                [v.yseq for v in V],
            )
            beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                V,
                beam_state,
                cache,
                self.use_lm,
            )

            if self.use_lm:
                beam_lm_states = create_lm_batch_states(
                    [v.lm_state for v in V], self.lm_layers, self.is_wordlm
                )
                beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                    beam_lm_states, beam_lm_tokens, len(V)
                )

            if n < (self.nstep - 1):
                # Intermediate sub-step: update states in place and loop again.
                for i, v in enumerate(V):
                    v.dec_out.append(beam_dec_out[i])

                    v.dec_state = self.decoder.select_state(beam_state, i)

                    if self.use_lm:
                        v.lm_state = select_lm_state(
                            beam_lm_states, i, self.lm_layers, self.is_wordlm
                        )
                        v.lm_scores = beam_lm_scores[i]

                hyps = V[:]
            else:
                # Final sub-step: add the blank log-prob of the last expansion
                # (skipped when nstep == 1, where S already covers blank).
                beam_logp = torch.log_softmax(
                    self.joint_network(beam_enc_out, beam_dec_out)
                    / self.softmax_temperature,
                    dim=-1,
                )

                for i, v in enumerate(V):
                    if self.nstep != 1:
                        v.score += float(beam_logp[i, 0])

                    v.dec_out.append(beam_dec_out[i])

                    v.dec_state = self.decoder.select_state(beam_state, i)

                    if self.use_lm:
                        v.lm_state = select_lm_state(
                            beam_lm_states, i, self.lm_layers, self.is_wordlm
                        )
                        v.lm_scores = beam_lm_scores[i]

        # Merge blank-terminated and expanded hypotheses, keep the best beam.
        kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam]

    return self.sort_nbest(kept_hyps)
def modified_adaptive_expansion_search(
    self, enc_out: torch.Tensor
) -> List[ExtendedHypothesis]:
    """It's the modified Adaptive Expansion Search (mAES) implementation.

    Based on/modified from https://ieeexplore.ieee.org/document/9250505 and NSC.

    Args:
        enc_out: Encoder output sequence. (T, D_enc)

    Returns:
        nbest_hyps: N-best hypothesis.

    """
    # Effective beam width never exceeds the vocabulary size.
    beam = min(self.beam_size, self.vocab_size)

    beam_state = self.decoder.init_state(beam)

    # Seed hypothesis: blank-only prefix with zero score.
    init_tokens = [
        ExtendedHypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=self.decoder.select_state(beam_state, 0),
        )
    ]

    # Shared decoder-output cache keyed by label prefix (filled by batch_score).
    cache = {}

    beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
        init_tokens,
        beam_state,
        cache,
        self.use_lm,
    )
    state = self.decoder.select_state(beam_state, 0)

    if self.use_lm:
        # Prime the LM and keep per-hypothesis state/score vectors.
        beam_lm_states, beam_lm_scores = self.lm.buff_predict(
            None, beam_lm_tokens, 1
        )
        lm_state = select_lm_state(
            beam_lm_states, 0, self.lm_layers, self.is_wordlm
        )
        lm_scores = beam_lm_scores[0]
    else:
        lm_state = None
        lm_scores = None

    kept_hyps = [
        ExtendedHypothesis(
            yseq=[self.blank_id],
            score=0.0,
            dec_state=state,
            dec_out=[beam_dec_out[0]],
            lm_state=lm_state,
            lm_scores=lm_scores,
        )
    ]

    # Process the encoder output frame by frame (time-synchronous).
    for enc_out_t in enc_out:
        # Longest hypotheses first so prefix_search can merge prefix scores.
        hyps = self.prefix_search(
            sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True),
            enc_out_t,
        )
        kept_hyps = []

        beam_enc_out = enc_out_t.unsqueeze(0)

        # list_b collects blank-terminated hypotheses; duplication_check
        # prevents re-creating a label sequence that already exists.
        list_b = []
        duplication_check = [hyp.yseq for hyp in hyps]

        for n in range(self.nstep):
            beam_dec_out = torch.stack([h.dec_out[-1] for h in hyps])

            # Joint network scores, restricted to the top max_candidates
            # labels per hypothesis.
            beam_logp, beam_idx = torch.log_softmax(
                self.joint_network(beam_enc_out, beam_dec_out)
                / self.softmax_temperature,
                dim=-1,
            ).topk(self.max_candidates, dim=-1)

            # Adaptive pruning: keep only candidates within expansion_gamma
            # of each hypothesis' best candidate.
            k_expansions = select_k_expansions(
                hyps,
                beam_idx,
                beam_logp,
                self.expansion_gamma,
            )

            list_exp = []
            for i, hyp in enumerate(hyps):
                for k, new_score in k_expansions[i]:
                    new_hyp = ExtendedHypothesis(
                        yseq=hyp.yseq[:],
                        score=new_score,
                        dec_out=hyp.dec_out[:],
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                        lm_scores=hyp.lm_scores,
                    )

                    if k == 0:
                        # Blank candidate: hypothesis ends at this frame.
                        list_b.append(new_hyp)
                    else:
                        # Non-blank: extend only if the resulting label
                        # sequence was not already a live hypothesis.
                        if new_hyp.yseq + [int(k)] not in duplication_check:
                            new_hyp.yseq.append(int(k))

                            if self.use_lm:
                                new_hyp.score += self.lm_weight * float(
                                    hyp.lm_scores[k]
                                )

                            list_exp.append(new_hyp)

            if not list_exp:
                # No expansion survived pruning: fall back to the blank
                # hypotheses and move to the next frame.
                kept_hyps = sorted(list_b, key=lambda x: x.score, reverse=True)[
                    :beam
                ]

                break
            else:
                # Re-batch decoder state and score the surviving expansions.
                beam_state = self.decoder.create_batch_states(
                    beam_state,
                    [hyp.dec_state for hyp in list_exp],
                    [hyp.yseq for hyp in list_exp],
                )

                beam_dec_out, beam_state, beam_lm_tokens = self.decoder.batch_score(
                    list_exp,
                    beam_state,
                    cache,
                    self.use_lm,
                )

                if self.use_lm:
                    beam_lm_states = create_lm_batch_states(
                        [hyp.lm_state for hyp in list_exp],
                        self.lm_layers,
                        self.is_wordlm,
                    )
                    beam_lm_states, beam_lm_scores = self.lm.buff_predict(
                        beam_lm_states, beam_lm_tokens, len(list_exp)
                    )

                if n < (self.nstep - 1):
                    # Intermediate sub-step: update states in place, iterate.
                    for i, hyp in enumerate(list_exp):
                        hyp.dec_out.append(beam_dec_out[i])
                        hyp.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            hyp.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )
                            hyp.lm_scores = beam_lm_scores[i]

                    hyps = list_exp[:]
                else:
                    # Final sub-step: close each expansion with its blank
                    # log-probability before merging with list_b.
                    beam_logp = torch.log_softmax(
                        self.joint_network(beam_enc_out, beam_dec_out)
                        / self.softmax_temperature,
                        dim=-1,
                    )

                    for i, hyp in enumerate(list_exp):
                        hyp.score += float(beam_logp[i, 0])

                        hyp.dec_out.append(beam_dec_out[i])
                        hyp.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            hyp.lm_state = select_lm_state(
                                beam_lm_states, i, self.lm_layers, self.is_wordlm
                            )
                            hyp.lm_scores = beam_lm_scores[i]

                    kept_hyps = sorted(
                        list_b + list_exp, key=lambda x: x.score, reverse=True
                    )[:beam]

    return self.sort_nbest(kept_hyps)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/chainer_backend/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/chainer_backend/asr_interface.py
0 → 100644
View file @
60a2c57a
"""ASR Interface module."""
import
chainer
from
espnet.nets.asr_interface
import
ASRInterface
class ChainerASRInterface(ASRInterface, chainer.Chain):
    """ASR Interface for ESPnet model implementation.

    Chainer-backend specialization of :class:`ASRInterface`. Every hook
    declared here is a stub that concrete models must override; calling an
    unimplemented hook raises ``NotImplementedError``.
    """

    @staticmethod
    def custom_converter(*args, **kw):
        """Get custom_converter of the model (Chainer only).

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("custom converter method is not implemented")

    @staticmethod
    def custom_updater(*args, **kw):
        """Get custom_updater of the model (Chainer only).

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("custom updater method is not implemented")

    @staticmethod
    def custom_parallel_updater(*args, **kw):
        """Get custom_parallel_updater of the model (Chainer only).

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("custom parallel updater method is not implemented")

    def get_total_subsampling_factor(self):
        """Get total subsampling factor.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("get_total_subsampling_factor method is not implemented")
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/chainer_backend/ctc.py
0 → 100644
View file @
60a2c57a
import
logging
import
chainer
import
chainer.functions
as
F
import
chainer.links
as
L
import
numpy
as
np
class CTC(chainer.Chain):
    """Chainer implementation of ctc layer.

    Args:
        odim (int): The output dimension.
        eprojs (int | None): Dimension of input vectors from encoder.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, odim, eprojs, dropout_rate):
        """Initialize the linear projection used before the CTC loss."""
        super(CTC, self).__init__()
        self.dropout_rate = dropout_rate
        self.loss = None

        with self.init_scope():
            self.ctc_lo = L.Linear(eprojs, odim)

    def __call__(self, hs, ys):
        """CTC forward.

        Args:
            hs (list of chainer.Variable | N-dimension array):
                Input variable from encoder.
            ys (list of chainer.Variable | N-dimension array):
                Input variable of decoder.

        Returns:
            chainer.Variable: A variable holding a scalar value of the CTC loss.

        """
        self.loss = None
        in_lengths = [seq.shape[0] for seq in hs]
        out_lengths = [seq.shape[0] for seq in ys]

        # Zero-pad the encoder outputs, apply dropout, project to label space.
        logits = self.ctc_lo(
            F.dropout(F.pad_sequence(hs), ratio=self.dropout_rate), n_batch_axes=2
        )
        # Split along time: a length-T tuple of (batch, odim) frames,
        # which is the layout the CTC function expects.
        logits = F.separate(logits, axis=1)

        # Pad target sequences with -1 -> (batch, max_olen).
        padded_ys = F.pad_sequence(ys, padding=-1)

        in_len_var = chainer.Variable(self.xp.array(in_lengths, dtype=np.int32))
        out_len_var = chainer.Variable(self.xp.array(out_lengths, dtype=np.int32))
        logging.info(
            self.__class__.__name__ + " input lengths: " + str(in_len_var.data)
        )
        logging.info(
            self.__class__.__name__ + " output lengths: " + str(out_len_var.data)
        )

        # Blank label index is 0.
        self.loss = F.connectionist_temporal_classification(
            logits, padded_ys, 0, in_len_var, out_len_var
        )
        logging.info("ctc loss:" + str(self.loss.data))

        return self.loss

    def log_softmax(self, hs):
        """Log_softmax of frame activations.

        Args:
            hs (list of chainer.Variable | N-dimension array):
                Input variable from encoder.

        Returns:
            chainer.Variable: A n-dimension float array.

        """
        logits = self.ctc_lo(F.pad_sequence(hs), n_batch_axes=2)
        # Flatten to 2-D for the softmax, then restore the original shape.
        flat = F.log_softmax(logits.reshape(-1, logits.shape[-1]))
        return flat.reshape(logits.shape)
def ctc_for(args, odim):
    """Return the CTC layer corresponding to the args.

    Args:
        args (Namespace): The program arguments.
        odim (int): The output dimension.

    Returns:
        The CTC module.

    Raises:
        ValueError: If ``args.ctc_type`` is not ``"builtin"``.

    """
    ctc_type = args.ctc_type
    # Guard clause: the Chainer backend only ships the built-in CTC.
    if ctc_type != "builtin":
        raise ValueError('ctc_type must be "builtin": {}'.format(ctc_type))

    logging.info("Using chainer CTC implementation")
    return CTC(odim, args.eprojs, args.dropout_rate)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/chainer_backend/deterministic_embed_id.py
0 → 100644
View file @
60a2c57a
import
chainer
import
numpy
# from chainer.functions.connection import embed_id
from
chainer
import
cuda
,
function_node
,
link
,
variable
from
chainer.initializers
import
normal
from
chainer.utils
import
type_check
"""Deterministic EmbedID link and function
copied from chainer/links/connection/embed_id.py
and chainer/functions/connection/embed_id.py,
and modified not to use atomicAdd operation
"""
class EmbedIDFunction(function_node.FunctionNode):
    """Embedding lookup that routes its gradient through :class:`EmbedIDGrad`.

    Forward is a plain row lookup ``W[x]``; rows selected by ``ignore_label``
    are replaced with zeros.
    """

    def __init__(self, ignore_label=None):
        # Label value whose embedding is forced to zero (None disables this).
        self.ignore_label = ignore_label
        # Shape of W, recorded in forward() for the backward pass.
        self._w_shape = None

    def check_type_forward(self, in_types):
        """Validate that inputs are (integer ids, float32 2-D weights)."""
        type_check.expect(in_types.size() == 2)
        x_type, w_type = in_types
        type_check.expect(
            x_type.dtype.kind == "i",
            x_type.ndim >= 1,
        )

        type_check.expect(w_type.dtype == numpy.float32, w_type.ndim == 2)

    def forward(self, inputs):
        """Look up embedding rows for ids ``x`` in weight matrix ``W``."""
        # Retain only x (index 0); W is not needed for this node's backward.
        self.retain_inputs((0,))
        x, W = inputs
        self._w_shape = W.shape

        # Mixing a numpy array with a cupy array is a caller error.
        if not type_check.same_types(*inputs):
            raise ValueError(
                "numpy and cupy must not be used together\n"
                "type(W): {0}, type(x): {1}".format(type(W), type(x))
            )

        xp = cuda.get_array_module(*inputs)
        if chainer.is_debug():
            # In debug mode, verify every id is a valid row (or ignored).
            valid_x = xp.logical_and(0 <= x, x < len(W))
            if self.ignore_label is not None:
                valid_x = xp.logical_or(valid_x, x == self.ignore_label)
            if not valid_x.all():
                raise ValueError(
                    "Each not ignored `x` value need to satisfy"
                    "`0 <= x < len(W)`"
                )

        if self.ignore_label is not None:
            mask = x == self.ignore_label
            # Ignored positions get a zero vector; masked ids are redirected
            # to row 0 first so the lookup never goes out of bounds.
            return (xp.where(mask[..., None], 0, W[xp.where(mask, 0, x)]),)

        return (W[x],)

    def backward(self, indexes, grad_outputs):
        """Return grads (None for x, accumulated gW for W) via EmbedIDGrad."""
        inputs = self.get_retained_inputs()
        gW = EmbedIDGrad(self._w_shape, self.ignore_label).apply(
            inputs + grad_outputs
        )[0]
        # x is integer-valued, hence non-differentiable -> None.
        return None, gW
class EmbedIDGrad(function_node.FunctionNode):
    """Gradient of the embedding lookup w.r.t. the weight matrix.

    Deterministic variant: on GPU it avoids ``atomicAdd`` (whose accumulation
    order is nondeterministic) by building a one-hot matrix and using a
    matrix product instead.
    """

    def __init__(self, w_shape, ignore_label=None):
        # Shape of the weight matrix whose gradient is produced.
        self.w_shape = w_shape
        # Label value excluded from gradient accumulation (None disables).
        self.ignore_label = ignore_label
        # Shape of the incoming gradient, recorded for double-backward.
        self._gy_shape = None

    def forward(self, inputs):
        """Scatter-add gy rows into gW at positions given by ids x."""
        self.retain_inputs((0,))
        xp = cuda.get_array_module(*inputs)
        x, gy = inputs
        self._gy_shape = gy.shape
        gW = xp.zeros(self.w_shape, dtype=gy.dtype)

        if xp is numpy:
            # It is equivalent to `numpy.add.at(gW, x, gy)` but ufunc.at is
            # too slow.
            for ix, igy in zip(x.ravel(), gy.reshape(x.size, -1)):
                if ix == self.ignore_label:
                    continue
                gW[ix] += igy
        else:
            """
            # original code based on cuda elementwise method
            if self.ignore_label is None:
                cuda.elementwise(
                    'T gy, S x, S n_out', 'raw T gW',
                    'ptrdiff_t w_ind[] = {x, i % n_out};'
                    'atomicAdd(&gW[w_ind], gy)',
                    'embed_id_bwd')(
                        gy, xp.expand_dims(x, -1), gW.shape[1], gW)
            else:
                cuda.elementwise(
                    'T gy, S x, S n_out, S ignore', 'raw T gW',
                    '''
                    if (x != ignore) {
                        ptrdiff_t w_ind[] = {x, i % n_out};
                        atomicAdd(&gW[w_ind], gy);
                    }
                    ''',
                    'embed_id_bwd_ignore_label')(
                        gy, xp.expand_dims(x, -1), gW.shape[1],
                        self.ignore_label, gW)
            """
            # EmbedID gradient alternative without atomicAdd, which simply
            # creates a one-hot vector and applies dot product
            xi = xp.zeros((x.size, len(gW)), dtype=numpy.float32)
            # Flat indices of the "hot" entries: row r of xi gets a 1 in
            # column x[r].
            idx = xp.arange(x.size, dtype=numpy.int32) * len(gW) + x.ravel()
            xi.ravel()[idx] = 1.0
            if self.ignore_label is not None:
                # Zero out the ignored column so it accumulates nothing.
                xi[:, self.ignore_label] = 0.0
            gW = xi.T.dot(gy.reshape(x.size, -1)).astype(gW.dtype, copy=False)

        return (gW,)

    def backward(self, indexes, grads):
        """Double-backward: propagate ggW back through the scatter as a gather."""
        xp = cuda.get_array_module(*grads)
        x = self.get_retained_inputs()[0].data
        ggW = grads[0]

        if self.ignore_label is not None:
            mask = x == self.ignore_label
            # To prevent index out of bounds, we need to check if ignore_label
            # is inside of W.
            if not (0 <= self.ignore_label < self.w_shape[1]):
                x = xp.where(mask, 0, x)

        ggy = ggW[x]

        if self.ignore_label is not None:
            # Zero the gathered gradient at ignored positions.
            mask, zero, _ = xp.broadcast_arrays(
                mask[..., None], xp.zeros((), "f"), ggy.data
            )
            ggy = chainer.functions.where(mask, zero, ggy)

        return None, ggy
def embed_id(x, W, ignore_label=None):
    r"""Efficient linear function for one-hot input.

    This function implements so called *word embeddings*. It takes two
    arguments: a set of IDs (words) ``x`` in :math:`B` dimensional integer
    vector, and a set of all ID (word) embeddings ``W`` in :math:`V \\times d`
    float32 matrix. It outputs :math:`B \\times d` matrix whose ``i``-th
    column is the ``x[i]``-th column of ``W``.

    This function is only differentiable on the input ``W``.

    Args:
        x (chainer.Variable | np.ndarray): Batch vectors of IDs. Each
            element must be signed integer.
        W (chainer.Variable | np.ndarray): Distributed representation
            of each ID (a.k.a. word embeddings).
        ignore_label (int): If ignore_label is an int value, i-th column
            of return value is filled with 0.

    Returns:
        chainer.Variable: Embedded variable.

    .. rubric:: :class:`~chainer.links.EmbedID`

    Examples:
        >>> x = np.array([2, 1]).astype('i')
        >>> x
        array([2, 1], dtype=int32)
        >>> W = np.array([[0, 0, 0],
        ...               [1, 1, 1],
        ...               [2, 2, 2]]).astype('f')
        >>> W
        array([[ 0.,  0.,  0.],
               [ 1.,  1.,  1.],
               [ 2.,  2.,  2.]], dtype=float32)
        >>> F.embed_id(x, W).data
        array([[ 2.,  2.,  2.],
               [ 1.,  1.,  1.]], dtype=float32)
        >>> F.embed_id(x, W, ignore_label=1).data
        array([[ 2.,  2.,  2.],
               [ 0.,  0.,  0.]], dtype=float32)

    """
    # Delegate to the deterministic function node; apply() returns a
    # one-element tuple of output variables.
    lookup = EmbedIDFunction(ignore_label=ignore_label)
    (out,) = lookup.apply((x, W))
    return out
class EmbedID(link.Link):
    """Efficient linear layer for one-hot input.

    This is a link that wraps the :func:`~chainer.functions.embed_id` function.
    This link holds the ID (word) embedding matrix ``W`` as a parameter.

    Args:
        in_size (int): Number of different identifiers (a.k.a. vocabulary size).
        out_size (int): Output dimension.
        initialW (Initializer): Initializer to initialize the weight.
        ignore_label (int): If `ignore_label` is an int value, i-th column of
            return value is filled with 0.

    .. rubric:: :func:`~chainer.functions.embed_id`

    Attributes:
        W (~chainer.Variable): Embedding parameter matrix.

    Examples:
        >>> W = np.array([[0, 0, 0],
        ...               [1, 1, 1],
        ...               [2, 2, 2]]).astype('f')
        >>> W
        array([[ 0.,  0.,  0.],
               [ 1.,  1.,  1.],
               [ 2.,  2.,  2.]], dtype=float32)
        >>> l = L.EmbedID(W.shape[0], W.shape[1], initialW=W)
        >>> x = np.array([2, 1]).astype('i')
        >>> x
        array([2, 1], dtype=int32)
        >>> y = l(x)
        >>> y.data
        array([[ 2.,  2.,  2.],
               [ 1.,  1.,  1.]], dtype=float32)

    """

    # Class-level default; shadowed per instance in __init__.
    ignore_label = None

    def __init__(self, in_size, out_size, initialW=None, ignore_label=None):
        """Register the embedding matrix W as a (in_size, out_size) parameter."""
        super(EmbedID, self).__init__()
        self.ignore_label = ignore_label

        with self.init_scope():
            # Fall back to a standard-normal initializer when none is given.
            init = normal.Normal(1.0) if initialW is None else initialW
            self.W = variable.Parameter(init, (in_size, out_size))

    def __call__(self, x):
        """Extracts the word embedding of given IDs.

        Args:
            x (chainer.Variable): Batch vectors of IDs.

        Returns:
            chainer.Variable: Batch of corresponding embeddings.

        """
        return embed_id(x, self.W, ignore_label=self.ignore_label)
Prev
1
2
3
4
5
6
7
8
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment