Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
sunzhq2
yidong-infer
Commits
0941998c
Commit
0941998c
authored
Feb 27, 2026
by
sunzhq2
Committed by
xuxo
Feb 27, 2026
Browse files
conformer add post and ana
parent
fde49a28
Changes
52
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1287 additions
and
44224 deletions
+1287
-44224
conformer/torch-infer/exp/lm_train_lm_transformer_char_batch_bins2000000/images/train_time.png
..._transformer_char_batch_bins2000000/images/train_time.png
+0
-0
conformer/torch-infer/exp/lm_train_lm_transformer_char_batch_bins2000000/perplexity_test/ppl
...lm_transformer_char_batch_bins2000000/perplexity_test/ppl
+1
-0
conformer/torch-infer/infer-compile.py
conformer/torch-infer/infer-compile.py
+0
-583
conformer/torch-infer/infer-torch.py
conformer/torch-infer/infer-torch.py
+7
-1
conformer/torch-infer/infer.py
conformer/torch-infer/infer.py
+344
-445
conformer/torch-infer/infer.sh
conformer/torch-infer/infer.sh
+72
-0
conformer/torch-infer/infer_io.py
conformer/torch-infer/infer_io.py
+103
-132
conformer/torch-infer/logs/wer_lm_rescoring_0
conformer/torch-infer/logs/wer_lm_rescoring_0
+0
-43063
conformer/torch-infer/meta.yaml
conformer/torch-infer/meta.yaml
+10
-0
conformer/torch-infer/post.sh
conformer/torch-infer/post.sh
+1
-0
espnet_model.py
espnet_model.py
+739
-0
meta.yaml
meta.yaml
+10
-0
No files found.
conformer/torch-infer/exp/lm_train_lm_transformer_char_batch_bins2000000/images/train_time.png
0 → 100644
View file @
0941998c
31.7 KB
conformer/torch-infer/exp/lm_train_lm_transformer_char_batch_bins2000000/perplexity_test/ppl
0 → 100644
View file @
0941998c
51.1541598159927
conformer/torch-infer/infer-compile.py
deleted
100644 → 0
View file @
fde49a28
import argparse
import multiprocessing
import os
import sys  # BUG FIX: used in the ImportError handler below but was never imported
import time
from typing import Dict, Optional, Tuple

import numpy as np
import soundfile
import torch
from torch.utils.data import DataLoader, Dataset

from espnet2.bin.asr_inference import Speech2Text
from espnet2.torch_utils.device_funcs import to_device

# Keep CPU-side torch single-threaded; parallelism comes from the decoder pool.
torch.set_num_threads(1)

try:
    from swig_decoders import (map_batch, ctc_beam_search_decoder_batch,
                               TrieVector, PathTrie)
except ImportError:
    print('Please install ctc decoders first by refering to \n'
          + 'https://github.com/Slyne/ctc_decoder.git')
    # Without the `import sys` above this line raised NameError instead of
    # exiting cleanly with status 1.
    sys.exit(1)
def lm_batchify_nll(
    lm_scorer,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    batch_size: int = 100,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute negative log likelihood (nll) with the language-model scorer.

    To avoid OOM the input is split into chunks of at most ``batch_size``
    samples; each chunk is scored separately and the results concatenated.

    Args:
        lm_scorer: Language model scorer object.
        text: (Batch, Length)
        text_lengths: (Batch,)
        batch_size: samples per chunk when computing nll; lower it to avoid
            OOM, raise it for throughput.
    """
    total_num = text.size(0)
    if total_num <= batch_size:
        # Small enough to score in a single call.
        nll, x_lengths = _compute_nll_with_lm_scorer(lm_scorer, text, text_lengths)
    else:
        # Every chunk is trimmed to the same global maximum so the pieces
        # concatenate consistently.
        max_length = text_lengths.max()
        nll_chunks = []
        length_chunks = []
        for start in range(0, total_num, batch_size):
            stop = min(start + batch_size, total_num)
            # chunk_nll: [B * T]
            chunk_nll, chunk_lengths = _compute_nll_with_lm_scorer(
                lm_scorer,
                text[start:stop, :],
                text_lengths[start:stop],
                max_length=max_length,
            )
            nll_chunks.append(chunk_nll)
            length_chunks.append(chunk_lengths)
        nll = torch.cat(nll_chunks)
        x_lengths = torch.cat(length_chunks)
    assert nll.size(0) == total_num
    assert x_lengths.size(0) == total_num
    return nll, x_lengths
def _compute_nll_with_lm_scorer(
    lm_scorer,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    max_length: int = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute negative log likelihood using lm_scorer's ``score`` method.

    Simulates an ``nll`` method on top of the scorer's incremental
    ``score(token, state, ctx)`` interface: each sequence is scored token by
    token and the per-transition nll values are summed.
    """
    num_seqs = text.size(0)
    # For data parallel: trim the time axis to the relevant width.
    width = text_lengths.max() if max_length is None else max_length
    text = text[:, :width]
    # One accumulated nll per sequence.
    nll = torch.zeros(num_seqs, device=text.device)
    for row in range(num_seqs):
        # Drop padding beyond the true sequence length.
        tokens = text[row][: text_lengths[row]]
        state = None
        # Score each transition tokens[pos] -> tokens[pos + 1] sequentially.
        for pos in range(len(tokens) - 1):
            logp, state = lm_scorer.score(tokens[pos].unsqueeze(0), state, None)
            # nll of the ground-truth next token under the LM.
            nll[row] += -logp[tokens[pos + 1]]
    # One transition fewer than tokens; clamp guards zero-length sequences.
    x_lengths = torch.clamp(text_lengths - 1, min=0)
    return nll, x_lengths
class CustomAishellDataset(Dataset):
    """Dataset over a Kaldi-style ``wav.scp`` (utterance-id, path) file and a
    ``text`` label file whose lines are ``utt-id token token ...``."""

    def __init__(self, wav_scp_file, text_file):
        with open(wav_scp_file, 'r') as wav_scp, open(text_file, 'r') as text:
            wav_entries = [line.split() for line in wav_scp.readlines()]
            text_entries = [line.split() for line in text.readlines()]
        # First column is the utterance name, second the audio path.
        self.wav_names = [entry[0] for entry in wav_entries]
        self.wav_paths = [entry[1] for entry in wav_entries]
        # Labels: everything after the utterance id, joined without spaces.
        self.labels = ["".join(entry[1:]) for entry in text_entries]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Decode the waveform; only 16 kHz audio is supported.
        speech, sr = soundfile.read(self.wav_paths[idx])
        assert sr == 16000, sr
        speech = np.array(speech, dtype=np.float32)
        return speech, speech.shape[0], self.labels[idx], self.wav_names[idx]
def collate_wrapper(batch):
    """Collate ``(speech, speech_len, label, name)`` samples into a padded batch.

    Args:
        batch: list of tuples produced by ``CustomAishellDataset.__getitem__``.

    Returns:
        Tuple of
        - speeches: float32 ndarray (B, max_len), zero-padded on the right;
        - lengths: int64 ndarray (B,) of true sample counts;
        - labels: list of label strings;
        - names: list of utterance names.

    Bug fix: the original pre-allocated a fixed 16000 * 30 sample buffer, so
    any utterance longer than 30 seconds crashed with a broadcast error (and
    an empty batch crashed on ``max``). The buffer is now sized from the batch
    itself; output is identical for every input the original handled.
    """
    if not batch:
        # Degenerate empty batch: keep the return shape contract.
        return (np.zeros((0, 0), dtype=np.float32),
                np.zeros(0, dtype=np.int64), [], [])
    # Size the pad buffer from the longest utterance actually present.
    max_len = max(item[1] for item in batch)
    speeches = np.zeros((len(batch), max_len), dtype=np.float32)
    lengths = np.zeros(len(batch), dtype=np.int64)
    labels = []
    names = []
    for i, (speech, speech_len, label, name) in enumerate(batch):
        speeches[i, :speech_len] = speech
        lengths[i] = speech_len
        labels.append(label)
        names.append(name)
    return speeches, lengths, labels, names
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).

    Returns:
        torch.Tensor: Boolean mask (B, max_len); True marks padded positions.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    # Fall back to the longest sequence when no explicit width is given.
    if max_len <= 0:
        max_len = lengths.max().item()
    # Column positions 0..max_len-1, shared by the whole batch.
    positions = torch.arange(0, max_len, dtype=torch.int64,
                             device=lengths.device)
    # Broadcast (1, max_len) against (B, 1): True wherever position >= length.
    return positions.unsqueeze(0) >= lengths.unsqueeze(-1)
def get_args():
    """Build and parse the command-line arguments for decoding."""
    parser = argparse.ArgumentParser(description='recognize with your model')
    # (flag, add_argument keyword options) — order matters for --help output.
    option_specs = [
        ('--config', dict(required=True, help='config file')),
        ('--lm_config', dict(required=True, help='config file')),
        ('--gpu', dict(type=int, default=0,
                       help='gpu id for this rank, -1 for cpu')),
        ('--wav_scp', dict(required=True, help='wav scp file')),
        ('--text', dict(required=True, help='ground truth text file')),
        ('--model_path', dict(required=True, help='torch pt model file')),
        ('--lm_path', dict(required=True, help='torch pt model file')),
        ('--result_file', dict(default='./predictions.txt',
                               help='asr result file')),
        ('--log_file', dict(default='./rtf.txt', help='asr decoding log')),
        ('--batch_size', dict(type=int, default=24, help='batch_size')),
        ('--beam_size', dict(type=int, default=10, help='beam_size')),
        ('--mode', dict(choices=['ctc_greedy_search',
                                 'ctc_prefix_beam_search',
                                 'attention_rescoring',
                                 'attention_lm_rescoring',
                                 'lm_rescoring'],
                        default='attention_lm_rescoring',
                        help='decoding mode')),
    ]
    for flag, opts in option_specs:
        parser.add_argument(flag, **opts)
    return parser.parse_args()
if __name__ == '__main__':
    args = get_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    dataset = CustomAishellDataset(args.wav_scp, args.text)
    test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
                                  collate_fn=collate_wrapper)
    speech2text = Speech2Text(args.config, args.model_path, None,
                              args.lm_config, args.lm_path, device="cuda")
    # Manually load the complete ESPnetLanguageModel object: Speech2Text only
    # stores the raw language model, but the full wrapper is needed for its
    # batchify_nll method.
    full_lm_model = None
    if args.lm_config is not None and args.lm_path is not None:
        from espnet2.tasks.lm import LMTask
        full_lm_model, _ = LMTask.build_model_from_file(args.lm_config,
                                                        args.lm_path, "cuda")
        full_lm_model.eval()
    # Optimize with torch.compile — only when this PyTorch build has it and a
    # GPU is available.
    if hasattr(torch, 'compile') and torch.cuda.is_available():
        print("启用torch.compile优化...")
        # Candidate backends, ordered from most compatible to fastest.
        backends_to_try = [
            ("aot_eager", {}),  # aot_eager does not accept the "mode" option
            ("eager", {"mode": "reduce-overhead"}),
            ("inductor", {"mode": "reduce-overhead", "dynamic": False, "fullgraph": False})
        ]
        for backend_name, backend_options in backends_to_try:
            try:
                print(f"尝试使用{backend_name}后端进行编译...")
                # Compile the hot ASR components.
                if hasattr(speech2text.asr_model, 'encode'):
                    speech2text.asr_model.encode = torch.compile(
                        speech2text.asr_model.encode,
                        backend=backend_name, **backend_options)
                if hasattr(speech2text.asr_model.ctc, 'ctc_lo'):
                    speech2text.asr_model.ctc.ctc_lo = torch.compile(
                        speech2text.asr_model.ctc.ctc_lo,
                        backend=backend_name, **backend_options)
                # Compile the language model as well (when present).
                if full_lm_model is not None and hasattr(full_lm_model, 'batchify_nll'):
                    full_lm_model.batchify_nll = torch.compile(
                        full_lm_model.batchify_nll,
                        backend=backend_name, **backend_options)
                # Compilation succeeded: also enable TensorFloat-32 matmuls.
                torch.set_float32_matmul_precision('high')
                print(f"✓ 使用{backend_name}后端编译成功")
                print("✓ TensorFloat-32加速已启用")
                break
            except Exception as e:
                print(f"⚠ {backend_name}后端编译失败: {e}")
                # Restore the original callables (torch.compile exposes the
                # wrapped function as _orig_mod when wrapping succeeded).
                if hasattr(speech2text.asr_model, 'encode'):
                    speech2text.asr_model.encode = speech2text.asr_model.encode._orig_mod if hasattr(speech2text.asr_model.encode, '_orig_mod') else speech2text.asr_model.encode
                if hasattr(speech2text.asr_model.ctc, 'ctc_lo'):
                    speech2text.asr_model.ctc.ctc_lo = speech2text.asr_model.ctc.ctc_lo._orig_mod if hasattr(speech2text.asr_model.ctc.ctc_lo, '_orig_mod') else speech2text.asr_model.ctc.ctc_lo
                if full_lm_model is not None and hasattr(full_lm_model, 'batchify_nll'):
                    full_lm_model.batchify_nll = full_lm_model.batchify_nll._orig_mod if hasattr(full_lm_model.batchify_nll, '_orig_mod') else full_lm_model.batchify_nll
                if backend_name == backends_to_try[-1][0]:
                    # Every backend failed: run uncompiled.
                    print("⚠ 所有编译后端都失败,将使用未编译模式运行")
                    torch.set_float32_matmul_precision('high')  # still enable TF32
                    print("✓ TensorFloat-32加速已启用(未编译模式)")
    audio_sample_len = 0
    total_inference_time = 0
    with torch.no_grad(), open(args.result_file, 'w') as fout:
        for _, batch in enumerate(test_data_loader):
            # Start timing this batch (torch.compile warm-up excluded).
            batch_start_time = time.perf_counter()
            speech, speech_lens, labels, names = batch
            audio_sample_len += np.sum(speech_lens) / 16000
            batch = {"speech": speech, "speech_lengths": speech_lens}
            if isinstance(batch["speech"], np.ndarray):
                batch["speech"] = torch.tensor(batch["speech"])
            if isinstance(batch["speech_lengths"], np.ndarray):
                batch["speech_lengths"] = torch.tensor(batch["speech_lengths"])
            # a. To device
            batch = to_device(batch, device='cuda')
            # b. Forward Encoder
            # enc: [N, T, C]
            ll = time.time()
            encoder_out, encoder_out_lens = speech2text.asr_model.encode(**batch)
            # ctc_log_probs: [N, T, C]
            ctc_logits = speech2text.asr_model.ctc.ctc_lo(encoder_out)
            ctc_log_probs = torch.nn.functional.log_softmax(ctc_logits, dim=2)
            # Top beam_size CTC candidates per frame.
            beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs,
                                                            args.beam_size,
                                                            dim=2)
            num_processes = min(multiprocessing.cpu_count(), args.batch_size)
            if args.mode == 'ctc_greedy_search':
                assert args.beam_size != 1
                log_probs_idx = beam_log_probs_idx[:, :, 0]
                batch_sents = []
                for idx, seq in enumerate(log_probs_idx):
                    batch_sents.append(seq[0:encoder_out_lens[idx]].tolist())
                hyps = map_batch(batch_sents, speech2text.asr_model.token_list,
                                 num_processes, True, 0)
            else:
                batch_log_probs_seq_list = beam_log_probs.tolist()
                batch_log_probs_idx_list = beam_log_probs_idx.tolist()
                batch_len_list = encoder_out_lens.tolist()
                batch_log_probs_seq = []
                batch_log_probs_ids = []
                batch_start = []  # only effective in streaming deployment
                batch_root = TrieVector()
                root_dict = {}
                for i in range(len(batch_len_list)):
                    num_sent = batch_len_list[i]
                    batch_log_probs_seq.append(batch_log_probs_seq_list[i][0:num_sent])
                    batch_log_probs_ids.append(batch_log_probs_idx_list[i][0:num_sent])
                    root_dict[i] = PathTrie()
                    batch_root.append(root_dict[i])
                    batch_start.append(True)
                score_hyps = ctc_beam_search_decoder_batch(batch_log_probs_seq,
                                                           batch_log_probs_ids,
                                                           batch_root,
                                                           batch_start,
                                                           args.beam_size,
                                                           num_processes,
                                                           0, -2, 0.99999)
                if args.mode == 'ctc_prefix_beam_search':
                    hyps = []
                    for cand_hyps in score_hyps:
                        hyps.append(cand_hyps[0][1])
                    hyps = map_batch(hyps, speech2text.asr_model.token_list,
                                     num_processes, False, 0)
                elif args.mode == 'attention_rescoring':
                    ctc_score, all_hyps = [], []
                    max_len = 0
                    # Pad each beam to beam_size entries with -inf dummies.
                    for hyps in score_hyps:
                        cur_len = len(hyps)
                        if len(hyps) < args.beam_size:
                            hyps += (args.beam_size - cur_len) * [(-float("INF"), (0,))]
                        cur_ctc_score = []
                        for hyp in hyps:
                            cur_ctc_score.append(hyp[0])
                            all_hyps.append(list(hyp[1]))
                            if len(hyp[1]) > max_len:
                                max_len = len(hyp[1])
                        ctc_score.append(cur_ctc_score)
                    ctc_score = torch.tensor(ctc_score, dtype=torch.float32)
                    hyps_pad_sos_eos = torch.ones(
                        (args.batch_size, args.beam_size, max_len + 2),
                        dtype=torch.int64) * speech2text.asr_model.ignore_id  # FIXME: ignore id
                    hyps_pad_sos = torch.ones(
                        (args.batch_size, args.beam_size, max_len + 1),
                        dtype=torch.int64) * speech2text.asr_model.eos  # FIXME: eos
                    hyps_pad_eos = torch.ones(
                        (args.batch_size, args.beam_size, max_len + 1),
                        dtype=torch.int64) * speech2text.asr_model.ignore_id  # FIXME: ignore id
                    hyps_lens_sos = torch.ones((args.batch_size, args.beam_size),
                                               dtype=torch.int32)
                    k = 0
                    for i in range(args.batch_size):
                        for j in range(args.beam_size):
                            cand = all_hyps[k]
                            l = len(cand) + 2
                            hyps_pad_sos_eos[i][j][0:l] = torch.tensor(
                                [speech2text.asr_model.sos] + cand + [speech2text.asr_model.eos])
                            hyps_pad_sos[i][j][0:l - 1] = torch.tensor(
                                [speech2text.asr_model.sos] + cand)
                            hyps_pad_eos[i][j][0:l - 1] = torch.tensor(
                                cand + [speech2text.asr_model.eos])
                            hyps_lens_sos[i][j] = len(cand) + 1
                            k += 1
                    bz = args.beam_size
                    B, T, F = encoder_out.shape
                    B2 = B * bz
                    # Repeat the encoder output once per beam candidate.
                    encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
                    encoder_out_lens = encoder_out_lens.repeat(bz)
                    hyps_pad = hyps_pad_sos_eos.view(B2, max_len + 2)
                    hyps_lens = hyps_lens_sos.view(B2,)
                    hyps_pad_sos = hyps_pad_sos.view(B2, max_len + 1)
                    hyps_pad_eos = hyps_pad_eos.view(B2, max_len + 1)
                    #hyps_pad_sos = hyps_pad[:, :-1]
                    #hyps_pad_eos = hyps_pad[:, 1:]
                    decoder_out, _ = speech2text.asr_model.decoder(
                        encoder_out, encoder_out_lens,
                        hyps_pad_sos.cuda(), hyps_lens.cuda())
                    decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
                    mask = ~make_pad_mask(hyps_lens, max_len + 1)  # B2 x T2
                    # mask index, remove ignore id
                    index = torch.unsqueeze(hyps_pad_eos * mask, 2)
                    score = decoder_out.cpu().gather(2, index).squeeze(2)  # B2 X T2
                    # mask padded part
                    score = score * mask
                    # decoder_out = decoder_out.view(B, bz, max_len+1, -1)
                    score = torch.sum(score, axis=1)
                    score = torch.reshape(score, (B, bz))
                    all_scores = ctc_score + 0.1 * score  # FIX ME need tuned
                    best_index = torch.argmax(all_scores, dim=1)
                    best_sents = []
                    k = 0
                    for idx in best_index:
                        cur_best_sent = all_hyps[k:k + args.beam_size][idx]
                        best_sents.append(cur_best_sent)
                        k += args.beam_size
                    hyps = map_batch(best_sents,
                                     speech2text.asr_model.token_list,
                                     num_processes)
                elif args.mode == 'attention_lm_rescoring':
                    ctc_score, all_hyps = [], []
                    max_len = 0
                    for hyps in score_hyps:
                        cur_len = len(hyps)
                        if len(hyps) < args.beam_size:
                            hyps += (args.beam_size - cur_len) * [(-float("INF"), (0,))]
                        cur_ctc_score = []
                        for hyp in hyps:
                            cur_ctc_score.append(hyp[0])
                            all_hyps.append(list(hyp[1]))
                            if len(hyp[1]) > max_len:
                                max_len = len(hyp[1])
                        ctc_score.append(cur_ctc_score)
                    ctc_score = torch.tensor(ctc_score, dtype=torch.float32)
                    # Optimization: build hyps_pad in bulk to avoid nested loops.
                    hyps_pad = torch.full(
                        (args.batch_size, args.beam_size, max_len),
                        speech2text.asr_model.ignore_id, dtype=torch.int64)
                    hyps_lens = torch.zeros((args.batch_size, args.beam_size),
                                            dtype=torch.int32)
                    # Fill in the batch.
                    for k, cand in enumerate(all_hyps):
                        i = k // args.beam_size
                        j = k % args.beam_size
                        l = len(cand)
                        hyps_pad[i, j, :l] = torch.tensor(cand, dtype=torch.int64)
                        hyps_lens[i, j] = l
                    bz = args.beam_size
                    B, T, F = encoder_out.shape
                    B2 = B * bz
                    encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
                    encoder_out_lens = encoder_out_lens.repeat(bz)
                    hyps_pad = hyps_pad.view(B2, max_len).cuda()
                    hyps_lens = hyps_lens.view(B2,).cuda()
                    decoder_scores = -speech2text.asr_model.batchify_nll(
                        encoder_out, encoder_out_lens, hyps_pad, hyps_lens, 320)
                    decoder_scores = torch.reshape(decoder_scores, (B, bz)).cpu()
                    # Score with the complete ESPnetLanguageModel object.
                    if full_lm_model is not None:
                        try:
                            # Clean the data first: replace ignore_id with 0
                            # (the language model's padding value).
                            hyps_pad_clean = hyps_pad.clone()
                            hyps_pad_clean[hyps_pad_clean == speech2text.asr_model.ignore_id] = 0
                            # Smaller chunk size here to avoid memory issues.
                            nnlm_nll, x_lengths = full_lm_model.batchify_nll(
                                hyps_pad_clean, hyps_lens, 64)
                        except Exception as e:
                            print(f"语言模型评分失败: {e}")
                            # On failure fall back to zero scores.
                            nnlm_nll = torch.zeros_like(hyps_pad)
                            x_lengths = hyps_lens
                    else:
                        # No language model: use neutral defaults.
                        nnlm_nll = torch.zeros_like(hyps_pad)
                        x_lengths = hyps_lens
                    nnlm_scores = -nnlm_nll.sum(dim=1)
                    nnlm_scores = torch.reshape(nnlm_scores, (B, bz)).cpu()
                    all_scores = ctc_score - 0.05 * decoder_scores + 1.0 * nnlm_scores  # FIX ME need tuned
                    best_index = torch.argmax(all_scores, dim=1)
                    best_sents = []
                    k = 0
                    for idx in best_index:
                        cur_best_sent = all_hyps[k:k + args.beam_size][idx]
                        best_sents.append(cur_best_sent)
                        k += args.beam_size
                    hyps = map_batch(best_sents,
                                     speech2text.asr_model.token_list,
                                     num_processes)
                elif args.mode == 'lm_rescoring':
                    # Optimization: preallocate instead of growing dynamically.
                    ctc_score = []
                    all_hyps = []
                    max_len = 0
                    # Precompute the maximum hypothesis length.
                    for hyps in score_hyps:
                        for hyp in hyps:
                            if len(hyp[1]) > max_len:
                                max_len = len(hyp[1])
                    # Batch processing.
                    for hyps in score_hyps:
                        cur_len = len(hyps)
                        if len(hyps) < args.beam_size:
                            hyps += (args.beam_size - cur_len) * [(-float("INF"), (0,))]
                        cur_ctc_score = []
                        for hyp in hyps:
                            cur_ctc_score.append(hyp[0])
                            all_hyps.append(list(hyp[1]))
                        ctc_score.append(cur_ctc_score)
                    ctc_score = torch.tensor(ctc_score, dtype=torch.float32)
                    hyps_pad = torch.ones(
                        (args.batch_size, args.beam_size, max_len),
                        dtype=torch.int64) * speech2text.asr_model.ignore_id  # FIXME: ignore id
                    hyps_lens = torch.ones((args.batch_size, args.beam_size),
                                           dtype=torch.int32)
                    k = 0
                    for i in range(args.batch_size):
                        for j in range(args.beam_size):
                            cand = all_hyps[k]
                            l = len(cand)
                            hyps_pad[i][j][0:l] = torch.tensor(cand)
                            hyps_lens[i][j] = len(cand)
                            k += 1
                    bz = args.beam_size
                    B, T, F = encoder_out.shape
                    B2 = B * bz
                    hyps_pad = hyps_pad.view(B2, max_len).cuda()
                    hyps_lens = hyps_lens.view(B2,).cuda()
                    hyps_pad[hyps_pad == speech2text.asr_model.ignore_id] = 0
                    nnlm_nll, x_lengths = full_lm_model.batchify_nll(
                        hyps_pad, hyps_lens, 320)
                    nnlm_scores = -nnlm_nll.sum(dim=1)
                    nnlm_scores = torch.reshape(nnlm_scores, (B, bz))
                    # Compute directly on GPU to avoid CPU-GPU transfers.
                    ctc_score_gpu = ctc_score.cuda()
                    all_scores = ctc_score_gpu + 0.9 * nnlm_scores  # FIX ME need tuned
                    best_index = torch.argmax(all_scores, dim=1)
                    best_index = best_index.cpu()  # only move to CPU at the end
                    best_sents = []
                    k = 0
                    for idx in best_index:
                        cur_best_sent = all_hyps[k:k + args.beam_size][idx]
                        best_sents.append(cur_best_sent)
                        k += args.beam_size
                    hyps = map_batch(best_sents,
                                     speech2text.asr_model.token_list,
                                     num_processes)
                else:
                    raise NotImplementedError
            print("耗时:", {time.time() - ll}, "fps:", {24 / (time.time() - ll)})
            for i, key in enumerate(names):
                content = hyps[i]
                # print('{} {}'.format(key, content))
                fout.write('{} {}\n'.format(key, content))
            # Record this batch's inference time (torch.compile excluded).
            batch_end_time = time.perf_counter()
            total_inference_time += batch_end_time - batch_start_time
    # Overall timing statistics (torch.compile time excluded); only rank/gpu 0
    # writes the RTF log.
    if str(args.gpu) == '0':
        with open(args.log_file, 'w') as log:
            log.write(f"Decoding audio {audio_sample_len} secs, cost {total_inference_time} secs (不包含torch.compile时间), RTF: {total_inference_time / audio_sample_len}, process {audio_sample_len / total_inference_time} secs audio per second, decoding args: {args}")
conformer/torch-infer/
1
.py
→
conformer/torch-infer/
infer-torch
.py
View file @
0941998c
...
...
@@ -158,8 +158,14 @@ if __name__ == '__main__':
# b. Forward Encoder
# enc: [N, T, C]
feats
,
feats_lengths
=
speech2text
.
asr_model
.
pre_data
(
**
batch
)
feats_lengths_1
=
torch
.
ceil
(
feats_lengths
.
float
()
/
4
).
long
()
print
(
"feats_lengths_1:"
,
feats_lengths_1
)
# print("feats_lengths:",feats_lengths)
ll_time
=
time
.
time
()
encoder_out
,
encoder_out_lens
=
speech2text
.
asr_model
.
encode
(
**
batch
)
encoder_out
,
encoder_out_lens
=
speech2text
.
asr_model
.
encode
(
feats
,
feats_lengths
)
print
(
"encoder_out_lens:"
,
encoder_out_lens
)
# ctc_log_probs: [N, T, C]
ctc_log_probs
=
torch
.
nn
.
functional
.
log_softmax
(
speech2text
.
asr_model
.
ctc
.
ctc_lo
(
encoder_out
),
dim
=
2
...
...
conformer/torch-infer/infer.py
View file @
0941998c
#!/usr/bin/env python3
import
torch
from
torch.utils.data
import
DataLoader
,
Dataset
import
soundfile
...
...
@@ -59,94 +60,6 @@ def collate_wrapper(batch):
return
speeches
,
lengths
,
labels
,
names
# def collate_wrapper(batch):
# """
# 实现与ESPNet模型相同的特征处理流程:
# 1. 提取特征(相当于 self._extract_feats)
# 2. 跳过数据增强(仅在训练时使用)
# 3. 特征归一化(相当于 self.normalize)
# """
# speeches = np.zeros((len(batch), 16000 * 30), dtype=np.float32)
# lengths = np.zeros(len(batch), dtype=np.int64)
# labels = []
# names = []
# for i, (speech, speech_len, label, name) in enumerate(batch):
# speeches[i, :speech_len] = speech
# lengths[i] = speech_len
# labels.append(label)
# names.append(name)
# speeches = speeches[:, :max(lengths)]
# try:
# # === 1. 提取特征(相当于 self._extract_feats) ===
# import librosa
# batch_size = speeches.shape[0]
# features_list = []
# for i in range(batch_size):
# audio = speeches[i]
# # 提取梅尔特征(与ESPNet前端处理一致)
# audio = librosa.effects.trim(audio, top_db=20)[0] # 去除静音
# stft = librosa.stft(audio, n_fft=512, hop_length=128, win_length=512)
# spectrogram = np.abs(stft)
# mel_filter = librosa.filters.mel(sr=16000, n_fft=512, n_mels=80)
# mel_spectrogram = np.dot(mel_filter, spectrogram)
# log_mel_spectrogram = np.log(np.clip(mel_spectrogram, a_min=1e-10, a_max=None))
# log_mel_spectrogram = log_mel_spectrogram.T # [time, 80]
# features_list.append(log_mel_spectrogram)
# # 找到最大时间长度并填充
# max_time = max(feat.shape[0] for feat in features_list)
# features = np.zeros((batch_size, max_time, 80), dtype=np.float32)
# for i, feat in enumerate(features_list):
# features[i, :feat.shape[0], :] = feat
# feats_lengths = np.array([feat.shape[0] for feat in features_list], dtype=np.int64)
# # print(f"特征提取完成: 音频形状 {speeches.shape} -> 特征形状 {features.shape}")
# # === 2. 跳过数据增强(仅在训练时使用) ===
# # if self.specaug is not None and self.training: # 跳过
# # feats, feats_lengths = self.specaug(feats, feats_lengths)
# # === 3. 特征归一化(相当于 self.normalize) ===
# stats_file = "/home/sunzhq/workspace/yidong-infer/conformer/34e9cabc2c29fd0e3a2917ffa525d98b/exp/asr_stats_raw_sp/train/feats_stats.npz"
# # 导入GlobalMVN类
# from espnet2.layers.global_mvn import GlobalMVN
# # 创建GlobalMVN实例(与ESPNet配置相同)
# global_mvn = GlobalMVN(
# stats_file=stats_file,
# norm_means=True,
# norm_vars=True
# )
# # 转换为PyTorch张量并应用GlobalMVN
# features_tensor = torch.from_numpy(features).float()
# feats_lengths_tensor = torch.from_numpy(feats_lengths).long()
# # 应用GlobalMVN归一化
# normalized_features, normalized_lengths = global_mvn(features_tensor, feats_lengths_tensor)
# # 转换回numpy
# features = normalized_features.numpy()
# feats_lengths = normalized_lengths.numpy()
# # print(f"特征归一化完成: 使用GlobalMVN,统计文件 {stats_file}")
# # 返回处理后的特征
# return features, feats_lengths, labels, names
# except Exception as e:
# print(f"特征处理失败: {e}")
# print("将返回原始音频数据")
# return speeches, lengths, labels, names
def
make_pad_mask
(
lengths
:
torch
.
Tensor
,
max_len
:
int
=
0
)
->
torch
.
Tensor
:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
...
...
@@ -172,147 +85,33 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
mask
=
seq_range_expand
>=
seq_length_expand
return
mask
def
get_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'recognize with your model'
)
parser
.
add_argument
(
'--config'
,
required
=
True
,
help
=
'config file'
)
parser
.
add_argument
(
'--lm_config'
,
required
=
True
,
help
=
'config file'
)
parser
.
add_argument
(
'--gpu'
,
type
=
int
,
default
=
0
,
help
=
'gpu id for this rank, -1 for cpu'
)
parser
.
add_argument
(
'--wav_scp'
,
required
=
True
,
help
=
'wav scp file'
)
parser
.
add_argument
(
'--text'
,
required
=
True
,
help
=
'ground truth text file'
)
parser
.
add_argument
(
'--model_path'
,
required
=
True
,
help
=
'torch pt model file'
)
parser
.
add_argument
(
'--lm_path'
,
required
=
True
,
help
=
'torch pt model file'
)
parser
.
add_argument
(
'--result_file'
,
default
=
'./predictions.txt'
,
help
=
'asr result file'
)
parser
.
add_argument
(
'--log_file'
,
default
=
'./rtf.txt'
,
help
=
'asr decoding log'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
24
,
help
=
'batch_size'
)
parser
.
add_argument
(
'--beam_size'
,
type
=
int
,
default
=
10
,
help
=
'beam_size'
)
parser
.
add_argument
(
'--mode'
,
choices
=
[
'ctc_greedy_search'
,
'ctc_prefix_beam_search'
,
'attention_rescoring'
,
'attention_lm_rescoring'
,
'lm_rescoring'
],
default
=
'attention_lm_rescoring'
,
help
=
'decoding mode'
)
args
=
parser
.
parse_args
()
return
args
if
__name__
==
'__main__'
:
args
=
get_args
()
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
str
(
args
.
gpu
)
dataset
=
CustomAishellDataset
(
args
.
wav_scp
,
args
.
text
)
# test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
# collate_fn=collate_wrapper)
test_data_loader
=
DataLoader
(
dataset
,
batch_size
=
args
.
batch_size
,
collate_fn
=
collate_wrapper
)
speech2text
=
Speech2Text
(
args
.
config
,
args
.
model_path
,
None
,
args
.
lm_config
,
args
.
lm_path
,
device
=
"cuda"
)
# 手动加载完整的ESPnetLanguageModel对象
# 因为Speech2Text中只存储了原始语言模型,我们需要完整的对象来使用batchify_nll方法
full_lm_model
=
None
if
args
.
lm_config
is
not
None
and
args
.
lm_path
is
not
None
:
from
espnet2.tasks.lm
import
LMTask
full_lm_model
,
_
=
LMTask
.
build_model_from_file
(
args
.
lm_config
,
args
.
lm_path
,
"cuda"
)
full_lm_model
.
eval
()
import
onnxruntime
as
ort
sess_options
=
ort
.
SessionOptions
()
sess_options
.
graph_optimization_level
=
ort
.
GraphOptimizationLevel
.
ORT_ENABLE_ALL
sess_options
.
enable_cpu_mem_arena
=
False
sess_options
.
enable_mem_pattern
=
False
providers
=
[
'ROCMExecutionProvider'
]
encoder_path
=
"/home/sunzhq/workspace/yidong-infer/conformer/onnx_models_batch24/transformer_lm/full/default_encoder_fp16.onnx"
encoder_session
=
ort
.
InferenceSession
(
encoder_path
,
providers
=
providers
)
# encoder_session_io = encoder_session.io_binding()
output_names
=
[
"encoder_out"
,
"encode_out_lens"
]
time_start
=
time
.
perf_counter
()
audio_sample_len
=
0
encoder_times
=
[]
ctc_times
=
[]
decoder_times
=
[]
lm_times
=
[]
beam_search_times
=
[]
count_times
=
[]
with
torch
.
no_grad
(),
open
(
args
.
result_file
,
'w'
)
as
fout
:
for
_
,
batch
in
enumerate
(
test_data_loader
):
def
process_batch_data
(
batch
,
speech2text
):
"""Process batch data and prepare for inference"""
speech
,
speech_lens
,
labels
,
names
=
batch
audio_sample_len
+=
np
.
sum
(
speech_lens
)
/
16000
batch
=
{
"speech"
:
speech
,
"speech_lengths"
:
speech_lens
}
if
isinstance
(
batch
[
"speech"
],
np
.
ndarray
):
batch
[
"speech"
]
=
torch
.
tensor
(
batch
[
"speech"
])
if
isinstance
(
batch
[
"speech_lengths"
],
np
.
ndarray
):
batch
[
"speech_lengths"
]
=
torch
.
tensor
(
batch
[
"speech_lengths"
])
audio_sample_len
=
np
.
sum
(
speech_lens
)
/
16000
batch_data
=
{
"speech"
:
speech
,
"speech_lengths"
:
speech_lens
}
# encoder_out_lens = np.array([np.sum(np.any(np.array(batch["speech"]) != 0, axis=1)) for i in range(np.array(batch["speech"]).shape[0])])
# encoder_inputs = {
# 'feats': np.array(batch["speech"]).astype(np.float32)}
if
isinstance
(
batch_data
[
"speech"
],
np
.
ndarray
):
batch_data
[
"speech"
]
=
torch
.
tensor
(
batch_data
[
"speech"
])
if
isinstance
(
batch_data
[
"speech_lengths"
],
np
.
ndarray
):
batch_data
[
"speech_lengths"
]
=
torch
.
tensor
(
batch_data
[
"speech_lengths"
])
batch
=
to_device
(
batch
,
device
=
'cuda'
)
feats
,
encoder_out_lens
=
speech2text
.
asr_model
.
encode
(
**
batch
)
batch_data
=
to_device
(
batch_data
,
device
=
'cuda'
)
feats
,
encoder_out_lens
=
speech2text
.
asr_model
.
pre_data
(
**
batch_data
)
encoder_out_lens
=
torch
.
ceil
(
encoder_out_lens
.
float
()
/
4
).
long
(
)
encoder_inputs
=
{
'feats'
:
feats
.
cpu
().
numpy
().
astype
(
np
.
float32
)}
return
encoder_inputs
,
encoder_out_lens
,
labels
,
names
,
audio_sample_len
ll_time
=
time
.
time
()
# encoder_time = time.time()
def
inference_step
(
encoder_inputs
,
encoder_out_lens
,
speech2text
,
full_lm_model
,
args
,
encoder_session
):
"""Perform inference on prepared data"""
# Run encoder inference
encoder_outputs
=
encoder_session
.
run
(
None
,
encoder_inputs
)
# encoder_out_1, encoder_out_lens_1 = encoder_session_io.get_outputs()
encoder_out_numpy
=
encoder_outputs
[
0
]
# encoder_out_lens = np.array(encoder_session_io.copy_outputs_to_cpu()[1])
encoder_out
=
torch
.
from_numpy
(
encoder_out_numpy
).
float
().
cuda
()
# encoder_out_lens = torch.from_numpy(encoder_out_lens_numpy).float().cuda()
# encoder_count = time.time() - encoder_time
# print("encode 耗时:", encoder_count)
# encoder_times.append(encoder_count)
# # ctc_log_probs: [N, T, C]
# ctc_time = time.time()
# # print("encoder_out:",encoder_out.size())
# # a. To device
# batch = to_device(batch, device='cuda')
# # b. Forward Encoder
# # enc: [N, T, C]
# # print(batch)
# encoder_time = time.time()
# encoder_out, encoder_out_lens = speech2text.asr_model.encode(**batch)
# encoder_count = time.time() - encoder_time
# print("encoder_out_lens:", encoder_out_lens, encoder_out_lens.size())
# print("encoder_out:", encoder_out.size())
# print("encode 耗时:", encoder_count)
# # **************************************************
# # encoder_out_lens: tensor([129, 105, 180, 171, 153, 199, 299, 211, 247, 222, 141, 277, 83, 197,
# # 179, 154, 148, 165, 178, 165, 179, 241, 288, 137], device='cuda:0') torch.Size([24])
# # encoder_out: torch.Size([24, 299, 256])
# encoder_times.append(encoder_count)
# #ctc_log_probs: [N, T, C]
# ctc_time = time.time()
ctc_log_probs
=
torch
.
nn
.
functional
.
log_softmax
(
speech2text
.
asr_model
.
ctc
.
ctc_lo
(
encoder_out
),
dim
=
2
)
...
...
@@ -320,9 +119,6 @@ if __name__ == '__main__':
beam_log_probs
,
beam_log_probs_idx
=
torch
.
topk
(
ctc_log_probs
,
args
.
beam_size
,
dim
=
2
)
# ctc_count = time.time() - ctc_time
# print("ctc 耗时:", ctc_count)
# ctc_times.append(ctc_count)
num_processes
=
min
(
multiprocessing
.
cpu_count
(),
args
.
batch_size
)
if
args
.
mode
==
'ctc_greedy_search'
:
...
...
@@ -334,20 +130,15 @@ if __name__ == '__main__':
hyps
=
map_batch
(
batch_sents
,
speech2text
.
asr_model
.
token_list
,
num_processes
,
True
,
0
)
else
:
# beam_search_time = time.time()
batch_log_probs_seq_list
=
beam_log_probs
.
tolist
()
batch_log_probs_idx_list
=
beam_log_probs_idx
.
tolist
()
batch_len_list
=
encoder_out_lens
.
tolist
()
# batch_len_list = encoder_out_lens
batch_log_probs_seq
=
[]
batch_log_probs_ids
=
[]
batch_start
=
[]
# only effective in streaming deployment
batch_root
=
TrieVector
()
root_dict
=
{}
for
i
in
range
(
len
(
batch_len_list
)):
# print(batch_len_list)
# num_sent = batch_len_list[i]
num_sent
=
encoder_out
.
size
()[
1
]
batch_log_probs_seq
.
append
(
batch_log_probs_seq_list
[
i
][
0
:
num_sent
])
...
...
@@ -364,12 +155,6 @@ if __name__ == '__main__':
num_processes
,
0
,
-
2
,
0.99999
)
# beam_search_count = time.time() - beam_search_time
# print("beam_search 耗时:", beam_search_count)
# beam_search_times.append(beam_search_count)
# beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs,
# args.beam_size, dim=2)
if
args
.
mode
==
'ctc_prefix_beam_search'
:
hyps
=
[]
for
cand_hyps
in
score_hyps
:
...
...
@@ -420,8 +205,6 @@ if __name__ == '__main__':
hyps_lens
=
hyps_lens_sos
.
view
(
B2
,)
hyps_pad_sos
=
hyps_pad_sos
.
view
(
B2
,
max_len
+
1
)
hyps_pad_eos
=
hyps_pad_eos
.
view
(
B2
,
max_len
+
1
)
#hyps_pad_sos = hyps_pad[:, :-1]
#hyps_pad_eos = hyps_pad[:, 1:]
decoder_out
,
_
=
speech2text
.
asr_model
.
decoder
(
encoder_out
,
encoder_out_lens
,
hyps_pad_sos
.
cuda
(),
hyps_lens
.
cuda
())
...
...
@@ -511,9 +294,7 @@ if __name__ == '__main__':
k
+=
args
.
beam_size
hyps
=
map_batch
(
best_sents
,
speech2text
.
asr_model
.
token_list
,
num_processes
)
elif
args
.
mode
==
'lm_rescoring'
:
# lm_time = time.time()
ctc_score
,
all_hyps
=
[],
[]
max_len
=
0
...
...
@@ -566,40 +347,158 @@ if __name__ == '__main__':
k
+=
args
.
beam_size
hyps
=
map_batch
(
best_sents
,
speech2text
.
asr_model
.
token_list
,
num_processes
)
count_time
=
time
.
time
()
-
ll_time
count_times
.
append
(
count_time
)
# lm_count = time.time() - lm_time
# print("lm 耗时:", lm_count)
# lm_times.append(lm_count)
# print("*"*50)
else
:
raise
NotImplementedError
return
hyps
def get_args(argv=None):
    """Build and parse the command-line options for decoding.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case ``sys.argv[1:]`` is parsed — this keeps the
            original zero-argument call sites working unchanged.

    Returns:
        argparse.Namespace with the parsed decoding options.
    """
    parser = argparse.ArgumentParser(description='recognize with your model')
    # Fix: the original help strings for --lm_config/--lm_path duplicated the
    # ASR ones ("config file" / "torch pt model file"), which was misleading.
    parser.add_argument('--config', required=True,
                        help='ASR training config file')
    parser.add_argument('--lm_config', required=True,
                        help='language model config file')
    parser.add_argument('--gpu', type=int, default=0,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--wav_scp', required=True, help='wav scp file')
    parser.add_argument('--text', required=True,
                        help='ground truth text file')
    parser.add_argument('--model_path', required=True,
                        help='ASR torch pt model file')
    parser.add_argument('--lm_path', required=True,
                        help='language model torch pt model file')
    parser.add_argument('--result_file', default='./predictions.txt',
                        help='asr result file')
    parser.add_argument('--log_file', default='./rtf.txt',
                        help='asr decoding log')
    parser.add_argument('--batch_size', type=int, default=24,
                        help='batch_size')
    parser.add_argument('--beam_size', type=int, default=10,
                        help='beam_size')
    parser.add_argument('--mode',
                        choices=['ctc_greedy_search',
                                 'ctc_prefix_beam_search',
                                 'attention_rescoring',
                                 'attention_lm_rescoring',
                                 'lm_rescoring'],
                        default='attention_lm_rescoring',
                        help='decoding mode')
    args = parser.parse_args(argv)
    return args
if __name__ == '__main__':
    args = get_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # Dataset yields batches of (speech, speech_lens, labels, names) via the
    # custom collate function.
    dataset = CustomAishellDataset(args.wav_scp, args.text)
    test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
                                  collate_fn=collate_wrapper)

    # ASR model: CTC / decoder / rescoring run in torch, while the encoder
    # itself is executed through the ONNX Runtime session created below.
    speech2text = Speech2Text(args.config, args.model_path, None,
                              args.lm_config, args.lm_path, device="cuda")
    full_lm_model = None
    if args.lm_config is not None and args.lm_path is not None:
        from espnet2.tasks.lm import LMTask
        full_lm_model, _ = LMTask.build_model_from_file(
            args.lm_config, args.lm_path, "cuda")
        full_lm_model.eval()

    import onnxruntime as ort
    # NOTE(review): ROCm-only provider and a hard-coded fp16 encoder path —
    # consider promoting these to CLI options.
    providers = ['ROCMExecutionProvider']
    encoder_path = "/home/sunzhq/workspace/yidong-infer/conformer/onnx_models_batch24_1/transformer_lm/full/default_encoder_fp16.onnx"
    encoder_session = ort.InferenceSession(encoder_path, providers=providers)
    output_names = ["encoder_out", "encoder_out_lens"]

    # Warmup: run inference on the first batch only, to initialize models
    # and ONNX Runtime kernels/caches before timing starts.
    print("Starting warmup...")
    warmup_start = time.time()
    with torch.no_grad():
        for i, batch in enumerate(test_data_loader):
            if i >= 1:
                # Warmup with first batch only
                break
            # Process batch data
            encoder_inputs, encoder_out_lens, labels, names, audio_sample_len = process_batch_data(batch, speech2text)
            # Run inference
            hyps = inference_step(encoder_inputs, encoder_out_lens,
                                  speech2text, full_lm_model, args,
                                  encoder_session)
    print(f"Warmup completed in {time.time() - warmup_start:.2f} seconds")

    # Main inference loop
    time_start = time.perf_counter()
    audio_sample_len_total = 0
    infer_times = []        # pure inference time per batch
    total_infer_times = []  # data loading + inference time per batch
    total_start = time.time()
    # Open files for saving results in the required format
    with torch.no_grad(), open(args.result_file, 'w') as fout, \
            open('ref.trn', 'w') as ref_file, \
            open('hyp.trn', 'w') as hyp_file:
        for batch_idx, batch in enumerate(test_data_loader):
            # Process batch data (separated from inference so the two can be
            # timed independently)
            encoder_inputs, encoder_out_lens, labels, names, audio_sample_len = process_batch_data(batch, speech2text)
            audio_sample_len_total += audio_sample_len
            # Measure inference time
            infer_start = time.time()
            # Run inference
            hyps = inference_step(encoder_inputs, encoder_out_lens,
                                  speech2text, full_lm_model, args,
                                  encoder_session)
            infer_time = time.time() - infer_start
            infer_times.append(infer_time)
            # Save results
            for i, key in enumerate(names):
                content = hyps[i]
                fout.write('{} {}\n'.format(key, content))
                # Save to ref.trn and hyp.trn in the required format:
                # convert continuous Chinese text to space-separated chars.
                ref_text = ' '.join(labels[i])
                hyp_text = ' '.join(content)
                ref_file.write('{}\t({})\n'.format(ref_text, key))
                hyp_file.write('{}\t({})\n'.format(hyp_text, key))
            total_infer_times.append(time.time() - total_start)
            total_start = time.time()

    # Calculate and print statistics
    time_end = time.perf_counter() - time_start
    # NOTE(review): count_times is not initialized anywhere in this block —
    # presumably a module-level list appended to inside inference_step's
    # lm_rescoring branch; confirm it exists before this point.
    count_times = count_times[5:]  # drop the first batches as warm-up
    mean_count_time = np.mean(count_times)
    # NOTE(review): the literal 24 below assumes --batch_size 24; prefer
    # args.batch_size.
    print("平均 mean_count_time:", mean_count_time, " fps: ", 24 / mean_count_time)
    print(f"Total audio processed: {audio_sample_len_total:.1f} seconds")
    print(f"Total time: {time_end:.1f} seconds")
    print(f"Real-time factor (RTF): {time_end / audio_sample_len_total:.3f}")
    print("***************************")
    infer_time = sum(infer_times)
    avg_infer_fps = 24 * len(infer_times) / sum(infer_times)
    print(f"total_infer_time: {infer_time} s")
    print(f'avg_infer_fps: {avg_infer_fps} samples/s')
    load_data_infer_time = sum(total_infer_times)
    load_data_avg_infer_fps = len(total_infer_times) * 24 / sum(total_infer_times)
    print(f'load_data_total_infer_time: {load_data_infer_time} s')
    print(f'load_data_avg_total_Infer_fps: {load_data_avg_infer_fps} samples/s')
    print("******************************")
    with open(args.log_file, 'w') as log:
        # NOTE(review): the first line uses audio_sample_len, which at this
        # point holds only the LAST batch's duration — the second line with
        # audio_sample_len_total is the meaningful summary.
        log.write(f"Decoding audio {audio_sample_len} secs, cost {time_end} secs, RTF: {time_end/audio_sample_len}, process {audio_sample_len/time_end} secs audio per second, decoding args: {args}")
        log.write(f"Decoding audio {audio_sample_len_total} secs, cost {time_end} secs, RTF: {time_end/audio_sample_len_total}, process {audio_sample_len_total/time_end} secs audio per second, decoding args: {args}")
\ No newline at end of file
conformer/torch-infer/infer.sh
0 → 100644
View file @
0941998c
#!/usr/bin/bash
# Launch conformer inference in parallel, one process per GPU (0..3).
# Each process is pinned to the NUMA node matching its GPU via numactl
# and writes its own prediction/log files suffixed with the GPU id.
#
# Fix: the original script referenced $gpu_id without ever setting it, so
# all four background processes wrote to the SAME result/log files
# (predictions_${mode}_.txt), clobbering each other. The four copy-pasted
# stanzas are collapsed into one loop that sets gpu_id explicitly.

asr_train_config="/home/sunzhq/workspace/yidong-infer/conformer/torch-infer/exp/asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp/config.yaml"
asr_model_file="/home/sunzhq/workspace/yidong-infer/conformer/torch-infer/exp/asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp/valid.acc.ave_10best.pth"
lm_train_config=/home/sunzhq/workspace/yidong-infer/conformer/torch-infer/exp/lm_train_lm_transformer_char_batch_bins2000000/config.yaml
lm_path=/home/sunzhq/workspace/yidong-infer/conformer/torch-infer/exp/lm_train_lm_transformer_char_batch_bins2000000/valid.loss.ave_10best.pth
manifest="/home/sunzhq/workspace/yidong-infer/conformer/torch-infer/test"

mkdir -p logs

# mode='attention_rescoring'
mode='lm_rescoring'

for gpu_id in 0 1 2 3; do
    # Each process sees exactly one GPU (always device 0 inside the process,
    # hence --gpu 0) and is bound to the matching NUMA node.
    export HIP_VISIBLE_DEVICES=$gpu_id
    nohup numactl -N $gpu_id -m $gpu_id python3 infer.py \
        --config $asr_train_config \
        --model_path $asr_model_file \
        --lm_config $lm_train_config \
        --lm_path $lm_path \
        --gpu 0 \
        --wav_scp $manifest/wav.scp --text $manifest/text \
        --result_file ./logs/predictions_${mode}_${gpu_id}.txt \
        --log_file ./logs/log_${mode}_${gpu_id}.txt \
        --batch_size 24 --beam_size 10 \
        --mode $mode 2>&1 | tee result_${gpu_id}.log &
done
conformer/torch-infer/infer
.bak
.py
→
conformer/torch-infer/infer
_io
.py
View file @
0941998c
...
...
@@ -20,97 +20,6 @@ except ImportError:
'https://github.com/Slyne/ctc_decoder.git'
)
sys
.
exit
(
1
)
def lm_batchify_nll(
    lm_scorer,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    batch_size: int = 100,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute per-utterance negative log likelihood with an LM scorer.

    To avoid OOM on large inputs, the batch is split into chunks of at
    most ``batch_size`` rows; each chunk is scored independently and the
    results are concatenated back together.

    Args:
        lm_scorer: language model scorer object (must expose ``score``).
        text: (Batch, Length) padded token ids.
        text_lengths: (Batch,) true length of each row of ``text``.
        batch_size: maximum number of rows scored per chunk; tune this to
            trade memory for throughput.

    Returns:
        Tuple of (nll, x_lengths), each of shape (Batch,).
    """
    total_num = text.size(0)
    if total_num <= batch_size:
        # Small enough to score in one go.
        nll, x_lengths = _compute_nll_with_lm_scorer(
            lm_scorer, text, text_lengths)
    else:
        # Score fixed-size chunks; pad every chunk to the global maximum
        # length so the pieces concatenate cleanly.
        pad_to = text_lengths.max()
        nll_chunks = []
        len_chunks = []
        for chunk_start in range(0, total_num, batch_size):
            chunk_end = min(chunk_start + batch_size, total_num)
            chunk_nll, chunk_lens = _compute_nll_with_lm_scorer(
                lm_scorer,
                text[chunk_start:chunk_end, :],
                text_lengths[chunk_start:chunk_end],
                max_length=pad_to,
            )
            nll_chunks.append(chunk_nll)
            len_chunks.append(chunk_lens)
        nll = torch.cat(nll_chunks)
        x_lengths = torch.cat(len_chunks)
    assert nll.size(0) == total_num
    assert x_lengths.size(0) == total_num
    return nll, x_lengths
def
_compute_nll_with_lm_scorer
(
lm_scorer
,
text
:
torch
.
Tensor
,
text_lengths
:
torch
.
Tensor
,
max_length
:
int
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Compute negative log likelihood using lm_scorer's score method
This function simulates the nll method using the available score method
from the lm_scorer object.
"""
batch_size
=
text
.
size
(
0
)
# For data parallel
if
max_length
is
None
:
text
=
text
[:,
:
text_lengths
.
max
()]
else
:
text
=
text
[:,
:
max_length
]
# Initialize nll for each sequence
nll
=
torch
.
zeros
(
batch_size
,
device
=
text
.
device
)
# Process each sequence individually
for
batch_idx
in
range
(
batch_size
):
seq_text
=
text
[
batch_idx
]
seq_length
=
text_lengths
[
batch_idx
]
# Truncate to actual sequence length
seq_text
=
seq_text
[:
seq_length
]
# Initialize state for this sequence
state
=
None
# Process each token position sequentially
for
pos
in
range
(
len
(
seq_text
)
-
1
):
# Get current token
current_token
=
seq_text
[
pos
].
unsqueeze
(
0
)
# shape: (1,)
# Score the current token
logp
,
state
=
lm_scorer
.
score
(
current_token
,
state
,
None
)
# Get the ground truth next token
next_token
=
seq_text
[
pos
+
1
]
# Get the negative log likelihood for the correct next token
token_nll
=
-
logp
[
next_token
]
nll
[
batch_idx
]
+=
token_nll
# x_lengths is text_lengths - 1 (since we score transitions between tokens)
x_lengths
=
text_lengths
-
1
x_lengths
=
torch
.
clamp
(
x_lengths
,
min
=
0
)
# Ensure non-negative
return
nll
,
x_lengths
class
CustomAishellDataset
(
Dataset
):
def
__init__
(
self
,
wav_scp_file
,
text_file
):
...
...
@@ -210,9 +119,10 @@ if __name__ == '__main__':
args
=
get_args
()
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
str
(
args
.
gpu
)
dataset
=
CustomAishellDataset
(
args
.
wav_scp
,
args
.
text
)
# test_data_loader = DataLoader(dataset, batch_size=args.batch_size,
# collate_fn=collate_wrapper)
test_data_loader
=
DataLoader
(
dataset
,
batch_size
=
args
.
batch_size
,
collate_fn
=
collate_wrapper
)
speech2text
=
Speech2Text
(
args
.
config
,
args
.
model_path
,
...
...
@@ -231,6 +141,20 @@ if __name__ == '__main__':
args
.
lm_config
,
args
.
lm_path
,
"cuda"
)
full_lm_model
.
eval
()
import
onnxruntime
as
ort
sess_options
=
ort
.
SessionOptions
()
sess_options
.
graph_optimization_level
=
ort
.
GraphOptimizationLevel
.
ORT_ENABLE_ALL
sess_options
.
enable_cpu_mem_arena
=
False
sess_options
.
enable_mem_pattern
=
False
providers
=
[
'ROCMExecutionProvider'
]
encoder_path
=
"/home/sunzhq/workspace/yidong-infer/conformer/onnx_models_batch24_1/transformer_lm/full/default_encoder_fp16.onnx"
encoder_session
=
ort
.
InferenceSession
(
encoder_path
,
providers
=
providers
)
encoder_session_io
=
encoder_session
.
io_binding
()
output_names
=
[
"encoder_out"
,
"encoder_out_lens"
]
time_start
=
time
.
perf_counter
()
audio_sample_len
=
0
...
...
@@ -239,6 +163,7 @@ if __name__ == '__main__':
decoder_times
=
[]
lm_times
=
[]
beam_search_times
=
[]
count_times
=
[]
with
torch
.
no_grad
(),
open
(
args
.
result_file
,
'w'
)
as
fout
:
for
_
,
batch
in
enumerate
(
test_data_loader
):
speech
,
speech_lens
,
labels
,
names
=
batch
...
...
@@ -250,28 +175,67 @@ if __name__ == '__main__':
if
isinstance
(
batch
[
"speech_lengths"
],
np
.
ndarray
):
batch
[
"speech_lengths"
]
=
torch
.
tensor
(
batch
[
"speech_lengths"
])
# a. To device
batch
=
to_device
(
batch
,
device
=
'cuda'
)
feats
,
encoder_out_lens
=
speech2text
.
asr_model
.
pre_data
(
**
batch
)
encoder_out_lens
=
torch
.
ceil
(
encoder_out_lens
.
float
()
/
4
).
long
()
encoder_inputs
=
{
'feats'
:
feats
.
cpu
().
numpy
().
astype
(
np
.
float32
)}
inputData
=
{}
for
key
in
encoder_inputs
.
keys
():
inputData
[
key
]
=
ort
.
OrtValue
.
ortvalue_from_numpy
(
encoder_inputs
[
key
],
device_type
=
'cuda'
)
encoder_session_io
.
bind_input
(
name
=
key
,
device_type
=
inputData
[
key
].
device_name
(),
device_id
=
0
,
element_type
=
np
.
float32
,
shape
=
inputData
[
key
].
shape
(),
buffer_ptr
=
inputData
[
key
].
data_ptr
())
# for o_n in output_names:
# encoder_session_io.bind_output("encoder_out")
encoder_session_io
.
bind_output
(
name
=
"encoder_out"
,
device_type
=
"cuda"
,
device_id
=
0
)
ll_time
=
time
.
time
()
encoder_session
.
run_with_iobinding
(
encoder_session_io
)
outputs
=
encoder_session_io
.
get_outputs
()[
0
]
ptr
=
outputs
.
data_ptr
()
# GPU 内存地址
shape
=
outputs
.
shape
()
dtype
=
torch
.
float32
total_elements
=
np
.
prod
(
shape
)
element_size
=
4
total_bytes
=
total_elements
*
element_size
methods
=
[
m
for
m
in
dir
(
outputs
)
if
not
m
.
startswith
(
'_'
)]
print
(
methods
)
# encoder_out = torch.as_tensor(ptr, dtype=dtype, device='cuda').reshape(shape)
print
(
outputs
)
print
(
outputs
.
device_name
())
print
(
"Has to_dlpack:"
,
hasattr
(
outputs
,
'to_dlpack'
))
print
(
"Shape:"
,
outputs
.
shape
())
# encoder_out = torch.from_dlpack(outputs[0].to_dlpack())
# result = encoder_session_io.copy_outputs_to_cpu()
# encoder_out = torch.tensor(result[0]).float().cuda()
print
(
encoder_out
)
# # encoder_time = time.time()
# encoder_outputs = encoder_session.run(None, encoder_inputs)
# # encoder_out_1, encoder_out_lens_1 = encoder_session_io.get_outputs()
# encoder_out_numpy = encoder_outputs[0]
# # encoder_out_lens = np.array(encoder_session_io.copy_outputs_to_cpu()[1])
# encoder_out = torch.from_numpy(encoder_out_numpy).float().cuda()
# print(encoder_out.size())
# b. Forward Encoder
# enc: [N, T, C]
encoder_time
=
time
.
time
()
encoder_out
,
encoder_out_lens
=
speech2text
.
asr_model
.
encode
(
**
batch
)
encoder_count
=
time
.
time
()
-
encoder_time
print
(
"encode 耗时:"
,
encoder_count
)
encoder_times
.
append
(
encoder_count
)
# ctc_log_probs: [N, T, C]
ctc_time
=
time
.
time
()
ctc_log_probs
=
torch
.
nn
.
functional
.
log_softmax
(
speech2text
.
asr_model
.
ctc
.
ctc_lo
(
encoder_out
),
dim
=
2
)
ctc_count
=
time
.
time
()
-
ctc_time
print
(
"ctc 耗时:"
,
ctc_count
)
ctc_times
.
append
(
ctc_count
)
beam_log_probs
,
beam_log_probs_idx
=
torch
.
topk
(
ctc_log_probs
,
args
.
beam_size
,
dim
=
2
)
# ctc_count = time.time() - ctc_time
# print("ctc 耗时:", ctc_count)
# ctc_times.append(ctc_count)
num_processes
=
min
(
multiprocessing
.
cpu_count
(),
args
.
batch_size
)
if
args
.
mode
==
'ctc_greedy_search'
:
...
...
@@ -283,18 +247,21 @@ if __name__ == '__main__':
hyps
=
map_batch
(
batch_sents
,
speech2text
.
asr_model
.
token_list
,
num_processes
,
True
,
0
)
else
:
beam_search_time
=
time
.
time
()
#
beam_search_time = time.time()
batch_log_probs_seq_list
=
beam_log_probs
.
tolist
()
batch_log_probs_idx_list
=
beam_log_probs_idx
.
tolist
()
batch_len_list
=
encoder_out_lens
.
tolist
()
# batch_len_list = encoder_out_lens
batch_log_probs_seq
=
[]
batch_log_probs_ids
=
[]
batch_start
=
[]
# only effective in streaming deployment
batch_root
=
TrieVector
()
root_dict
=
{}
for
i
in
range
(
len
(
batch_len_list
)):
num_sent
=
batch_len_list
[
i
]
# print(batch_len_list)
# num_sent = batch_len_list[i]
num_sent
=
encoder_out
.
size
()[
1
]
batch_log_probs_seq
.
append
(
batch_log_probs_seq_list
[
i
][
0
:
num_sent
])
batch_log_probs_ids
.
append
(
...
...
@@ -310,9 +277,9 @@ if __name__ == '__main__':
num_processes
,
0
,
-
2
,
0.99999
)
beam_search_count
=
time
.
time
()
-
beam_search_time
print
(
"beam_search 耗时:"
,
beam_search_count
)
beam_search_times
.
append
(
beam_search_count
)
#
beam_search_count = time.time() - beam_search_time
#
print("beam_search 耗时:", beam_search_count)
#
beam_search_times.append(beam_search_count)
# beam_log_probs, beam_log_probs_idx = torch.topk(ctc_log_probs,
# args.beam_size, dim=2)
...
...
@@ -459,7 +426,7 @@ if __name__ == '__main__':
elif
args
.
mode
==
'lm_rescoring'
:
lm_time
=
time
.
time
()
#
lm_time = time.time()
ctc_score
,
all_hyps
=
[],
[]
max_len
=
0
...
...
@@ -512,11 +479,13 @@ if __name__ == '__main__':
k
+=
args
.
beam_size
hyps
=
map_batch
(
best_sents
,
speech2text
.
asr_model
.
token_list
,
num_processes
)
count_time
=
time
.
time
()
-
ll_time
count_times
.
append
(
count_time
)
lm_count
=
time
.
time
()
-
lm_time
print
(
"lm 耗时:"
,
lm_count
)
lm_times
.
append
(
lm_count
)
print
(
"*"
*
50
)
#
lm_count = time.time() - lm_time
#
print("lm 耗时:", lm_count)
#
lm_times.append(lm_count)
#
print("*"*50)
else
:
raise
NotImplementedError
...
...
@@ -528,20 +497,22 @@ if __name__ == '__main__':
fout
.
write
(
'{} {}
\n
'
.
format
(
key
,
content
))
time_end
=
time
.
perf_counter
()
-
time_start
encoder_times
=
encoder_times
[
5
:]
ctc_times
=
ctc_times
[
5
:]
beam_search_times
=
beam_search_times
[
5
:]
lm_times
=
lm_times
[
5
:]
mean_encoder
=
np
.
mean
(
encoder_times
)
mean_ctc
=
np
.
mean
(
ctc_times
)
mean_beam_search
=
np
.
mean
(
beam_search_times
)
mean_lm
=
np
.
mean
(
lm_times
)
print
(
"平均 encode time:"
,
mean_encoder
)
print
(
"平均 ctc time:"
,
mean_ctc
)
print
(
"平均 beam_search time:"
,
mean_beam_search
)
print
(
"平均 lm time:"
,
mean_lm
)
# encoder_times = encoder_times[5:]
# ctc_times = ctc_times[5:]
# beam_search_times = beam_search_times[5:]
# lm_times = lm_times[5:]
# mean_encoder = np.mean(encoder_times)
# mean_ctc = np.mean(ctc_times)
# mean_beam_search = np.mean(beam_search_times)
# mean_lm = np.mean(lm_times)
# print("平均 encode time:", mean_encoder)
# print("平均 ctc time:", mean_ctc)
# print("平均 beam_search time:", mean_beam_search)
# print("平均 lm time:", mean_lm)
count_times
=
count_times
[
5
:]
mean_count_time
=
np
.
mean
(
count_times
)
print
(
"平均 mean_count_time:"
,
mean_count_time
,
" fps: "
,
24
/
mean_count_time
)
# if str(args.gpu) == '0':
with
open
(
args
.
log_file
,
'w'
)
as
log
:
log
.
write
(
f
"Decoding audio
{
audio_sample_len
}
secs, cost
{
time_end
}
secs, RTF:
{
time_end
/
audio_sample_len
}
, process
{
audio_sample_len
/
time_end
}
secs audio per second, decoding args:
{
args
}
"
)
conformer/torch-infer/logs/wer_lm_rescoring_0
deleted
100644 → 0
View file @
fde49a28
This source diff could not be displayed because it is too large. You can
view the blob
instead.
conformer/torch-infer/meta.yaml
0 → 100644
View file @
0941998c
espnet
:
0.9.0
files
:
asr_model_file
:
exp/asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp/valid.acc.ave_10best.pth
lm_file
:
exp/lm_train_lm_transformer_char_batch_bins2000000/valid.loss.ave_10best.pth
python
:
"
3.7.3
(default,
Mar
27
2019,
22:11:17)
\n
[GCC
7.3.0]"
timestamp
:
1603088092.704853
torch
:
1.6.0
yaml_files
:
asr_train_config
:
exp/asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp/config.yaml
lm_train_config
:
exp/lm_train_lm_transformer_char_batch_bins2000000/config.yaml
conformer/torch-infer/post.sh
0 → 100644
View file @
0941998c
python3 conformer-compute-wer.py ./logs/ref.trn ./logs/hyp.trn
\ No newline at end of file
espnet_model.py
0 → 100644
View file @
0941998c
import
logging
from
contextlib
import
contextmanager
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
packaging.version
import
parse
as
V
from
typeguard
import
check_argument_types
from
espnet2.asr.ctc
import
CTC
from
espnet2.asr.decoder.abs_decoder
import
AbsDecoder
from
espnet2.asr.encoder.abs_encoder
import
AbsEncoder
from
espnet2.asr.frontend.abs_frontend
import
AbsFrontend
from
espnet2.asr.postencoder.abs_postencoder
import
AbsPostEncoder
from
espnet2.asr.preencoder.abs_preencoder
import
AbsPreEncoder
from
espnet2.asr.specaug.abs_specaug
import
AbsSpecAug
from
espnet2.asr.transducer.error_calculator
import
ErrorCalculatorTransducer
from
espnet2.asr_transducer.utils
import
get_transducer_task_io
from
espnet2.layers.abs_normalize
import
AbsNormalize
from
espnet2.torch_utils.device_funcs
import
force_gatherable
from
espnet2.train.abs_espnet_model
import
AbsESPnetModel
from
espnet.nets.e2e_asr_common
import
ErrorCalculator
from
espnet.nets.pytorch_backend.nets_utils
import
th_accuracy
from
espnet.nets.pytorch_backend.transformer.add_sos_eos
import
add_sos_eos
from
espnet.nets.pytorch_backend.transformer.label_smoothing_loss
import
(
# noqa: H301
LabelSmoothingLoss
,
)
# torch.cuda.amp.autocast exists only from torch 1.6.0 onwards; on older
# versions install a no-op stand-in so call sites can always write
# ``with autocast(enabled):`` regardless of the installed torch version.
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
class ESPnetASRModel(AbsESPnetModel):
    """CTC-attention hybrid Encoder-Decoder model"""

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        preencoder: Optional[AbsPreEncoder],
        encoder: AbsEncoder,
        postencoder: Optional[AbsPostEncoder],
        decoder: Optional[AbsDecoder],
        ctc: CTC,
        joint_network: Optional[torch.nn.Module],
        aux_ctc: dict = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        # FIX: was a mutable default argument (List = []); use None sentinel
        # and normalize below -- backward compatible for all callers.
        transducer_multi_blank_durations: Optional[List] = None,
        transducer_multi_blank_sigma: float = 0.05,
        # In a regular ESPnet recipe, <sos> and <eos> are both "<sos/eos>"
        # Pretrained HF Tokenizer needs custom sym_sos and sym_eos
        sym_sos: str = "<sos/eos>",
        sym_eos: str = "<sos/eos>",
        extract_feats_in_collect_stats: bool = True,
        lang_token_id: int = -1,
    ):
        """Build the hybrid CTC/attention (or transducer) ASR model.

        Args:
            vocab_size: Output vocabulary size (including special symbols).
            token_list: Token inventory; indices define the output ids.
            frontend/specaug/normalize/preencoder/encoder/postencoder/decoder:
                Pipeline modules; ``None`` entries are skipped at runtime.
            ctc: CTC loss module (dropped when ``ctc_weight == 0.0``).
            joint_network: If not None, the model runs in transducer mode.
            aux_ctc: Optional mapping of encoder layer index (as str) to the
                kwargs key holding auxiliary CTC targets.
            ctc_weight: Interpolation weight in [0, 1] between CTC and
                attention losses (1.0 = pure CTC, 0.0 = pure attention).
            interctc_weight: Weight in [0, 1) for intermediate-CTC losses.
            transducer_multi_blank_durations: Big-blank durations; a non-empty
                list selects the multi-blank RNNT loss.
            lang_token_id: If != -1, a language token prepended to targets.
        """
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight

        super().__init__()

        if transducer_multi_blank_durations is None:
            transducer_multi_blank_durations = []

        # NOTE (Shih-Lun): else case is for OpenAI Whisper ASR model,
        # which doesn't use <blank> token
        if sym_blank in token_list:
            self.blank_id = token_list.index(sym_blank)
        else:
            self.blank_id = 0

        if sym_sos in token_list:
            self.sos = token_list.index(sym_sos)
        else:
            self.sos = vocab_size - 1
        if sym_eos in token_list:
            self.eos = token_list.index(sym_eos)
        else:
            self.eos = vocab_size - 1

        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        self.aux_ctc = aux_ctc
        # Copy so later mutation of the caller's list cannot change the model.
        self.token_list = token_list.copy()

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder

        # Encoders without interctc support get the attribute defaulted off.
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )

        self.use_transducer_decoder = joint_network is not None
        self.error_calculator = None

        if self.use_transducer_decoder:
            self.decoder = decoder
            self.joint_network = joint_network

            if not transducer_multi_blank_durations:
                # Standard RNNT loss (lazy import: optional dependency).
                from warprnnt_pytorch import RNNTLoss

                self.criterion_transducer = RNNTLoss(
                    blank=self.blank_id,
                    fastemit_lambda=0.0,
                )
            else:
                # Multi-blank RNNT loss with the configured big-blank durations.
                from espnet2.asr.transducer.rnnt_multi_blank.rnnt_multi_blank import (
                    MultiblankRNNTLossNumba,
                )

                self.criterion_transducer = MultiblankRNNTLossNumba(
                    blank=self.blank_id,
                    big_blank_durations=transducer_multi_blank_durations,
                    sigma=transducer_multi_blank_sigma,
                    reduction="mean",
                    fastemit_lambda=0.0,
                )
            self.transducer_multi_blank_durations = transducer_multi_blank_durations

            if report_cer or report_wer:
                self.error_calculator_trans = ErrorCalculatorTransducer(
                    decoder,
                    joint_network,
                    token_list,
                    sym_space,
                    sym_blank,
                    report_cer=report_cer,
                    report_wer=report_wer,
                )
            else:
                self.error_calculator_trans = None

                if self.ctc_weight != 0:
                    self.error_calculator = ErrorCalculator(
                        token_list, sym_space, sym_blank, report_cer, report_wer
                    )
        else:
            # we set self.decoder = None in the CTC mode since
            # self.decoder parameters were never used and PyTorch complained
            # and threw an Exception in the multi-GPU experiment.
            # thanks Jeff Farris for pointing out the issue.
            if ctc_weight < 1.0:
                assert (
                    decoder is not None
                ), "decoder should not be None when attention is used"
            else:
                decoder = None
                logging.warning("Set decoder to none as ctc_weight==1.0")

            self.decoder = decoder

            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )

            if report_cer or report_wer:
                self.error_calculator = ErrorCalculator(
                    token_list, sym_space, sym_blank, report_cer, report_wer
                )

        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

        self.is_encoder_whisper = "Whisper" in type(self.encoder).__name__

        if self.is_encoder_whisper:
            assert (
                self.frontend is None
            ), "frontend should be None when using full Whisper model"

        if lang_token_id != -1:
            self.lang_token_id = torch.tensor([[lang_token_id]])
        else:
            self.lang_token_id = None
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Runs the encoder once, then the CTC branch, optional intermediate-CTC
        branches, and either the transducer or the attention decoder branch,
        interpolating the losses by ``self.ctc_weight`` / ``self.interctc_weight``.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
            kwargs: "utt_id" is among the input.

        Returns:
            (loss, stats dict, batch-size weight), made gatherable for
            DataParallel by ``force_gatherable``.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        batch_size = speech.shape[0]

        # Map raw -1 padding onto the configured ignore index.
        text[text == -1] = self.ignore_id

        # for data-parallel: trim padding beyond the longest target.
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        # A tuple result carries (final_out, [(layer_idx, layer_out), ...]).
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]

        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        loss_transducer, cer_transducer, wer_transducer = None, None, None
        stats = dict()

        # 1. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc

        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out

                # use auxillary ctc data if specified
                loss_ic = None
                if self.aux_ctc is not None:
                    idx_key = str(layer_idx)
                    if idx_key in self.aux_ctc:
                        aux_data_key = self.aux_ctc[idx_key]
                        aux_data_tensor = kwargs.get(aux_data_key, None)
                        aux_data_lengths = kwargs.get(aux_data_key + "_lengths", None)

                        if aux_data_tensor is not None and aux_data_lengths is not None:
                            loss_ic, cer_ic = self._calc_ctc_loss(
                                intermediate_out,
                                encoder_out_lens,
                                aux_data_tensor,
                                aux_data_lengths,
                            )
                        else:
                            raise Exception(
                                "Aux. CTC tasks were specified but no data was found"
                            )
                # Fall back to the main targets when no aux data applies.
                if loss_ic is None:
                    loss_ic, cer_ic = self._calc_ctc_loss(
                        intermediate_out, encoder_out_lens, text, text_lengths
                    )
                loss_interctc = loss_interctc + loss_ic

                # Collect Intermedaite CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic

            loss_interctc = loss_interctc / len(intermediate_outs)

            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc

        if self.use_transducer_decoder:
            # 2a. Transducer decoder branch
            (
                loss_transducer,
                cer_transducer,
                wer_transducer,
            ) = self._calc_transducer_loss(
                encoder_out,
                encoder_out_lens,
                text,
            )

            if loss_ctc is not None:
                loss = loss_transducer + (self.ctc_weight * loss_ctc)
            else:
                loss = loss_transducer

            # Collect Transducer branch stats
            stats["loss_transducer"] = (
                loss_transducer.detach() if loss_transducer is not None else None
            )
            stats["cer_transducer"] = cer_transducer
            stats["wer_transducer"] = wer_transducer

        else:
            # 2b. Attention decoder branch
            if self.ctc_weight != 1.0:
                loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                    encoder_out, encoder_out_lens, text, text_lengths
                )

            # 3. CTC-Att loss definition
            if self.ctc_weight == 0.0:
                loss = loss_att
            elif self.ctc_weight == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

            # Collect Attn branch stats
            stats["loss_att"] = loss_att.detach() if loss_att is not None else None
            stats["acc"] = acc_att
            stats["cer"] = cer_att
            stats["wer"] = wer_att

        # Collect total loss stats
        stats["loss"] = loss.detach()

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
def
collect_feats
(
self
,
speech
:
torch
.
Tensor
,
speech_lengths
:
torch
.
Tensor
,
text
:
torch
.
Tensor
,
text_lengths
:
torch
.
Tensor
,
**
kwargs
,
)
->
Dict
[
str
,
torch
.
Tensor
]:
feats
,
feats_lengths
=
self
.
_extract_feats
(
speech
,
speech_lengths
)
return
{
"feats"
:
feats
,
"feats_lengths"
:
feats_lengths
}
def
pre_data
(
self
,
speech
:
torch
.
Tensor
,
speech_lengths
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with
autocast
(
False
):
# 1. Extract feats
feats
,
feats_lengths
=
self
.
_extract_feats
(
speech
,
speech_lengths
)
# 2. Data augmentation
if
self
.
specaug
is
not
None
and
self
.
training
:
feats
,
feats_lengths
=
self
.
specaug
(
feats
,
feats_lengths
)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
#print("self.normalize:",self.normalize)
if
self
.
normalize
is
not
None
:
feats
,
feats_lengths
=
self
.
normalize
(
feats
,
feats_lengths
)
# Pre-encoder, e.g. used for raw input data
#if self.preencoder is not None:
# feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
#if self.encoder.interctc_use_conditioning:
# encoder_out, encoder_out_lens, _ = self.encoder(
# feats, feats_lengths, ctc=self.ctc
# )
#else:
# encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
#intermediate_outs = None
#if isinstance(encoder_out, tuple):
# intermediate_outs = encoder_out[1]
# encoder_out = encoder_out[0]
# Post-encoder, e.g. NLU
#if self.postencoder is not None:
# encoder_out, encoder_out_lens = self.postencoder(
# encoder_out, encoder_out_lens
# )
#assert encoder_out.size(0) == speech.size(0), (
# encoder_out.size(),
# speech.size(0),
#)
#if (
# getattr(self.encoder, "selfattention_layer_type", None) != "lf_selfattn"
# and not self.is_encoder_whisper
#):
# assert encoder_out.size(-2) <= encoder_out_lens.max(), (
# encoder_out.size(),
# encoder_out_lens.max(),
# )
#if intermediate_outs is not None:
# return (encoder_out, intermediate_outs), encoder_out_lens
return
feats
,
feats_lengths
def
encode
(
self
,
feats
:
torch
.
Tensor
,
feats_lengths
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
# with autocast(False):
# # 1. Extract feats
# feats, feats_lengths = self._extract_feats(speech, speech_lengths)
# # 2. Data augmentation
# if self.specaug is not None and self.training:
# feats, feats_lengths = self.specaug(feats, feats_lengths)
# # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
# #print("self.normalize:",self.normalize)
# if self.normalize is not None:
# feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
# if self.preencoder is not None:
# feats, feats_lengths = self.preencoder(feats, feats_lengths)
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
if
self
.
encoder
.
interctc_use_conditioning
:
encoder_out
,
encoder_out_lens
,
_
=
self
.
encoder
(
feats
,
feats_lengths
,
ctc
=
self
.
ctc
)
else
:
encoder_out
,
encoder_out_lens
,
_
=
self
.
encoder
(
feats
,
feats_lengths
)
intermediate_outs
=
None
if
isinstance
(
encoder_out
,
tuple
):
intermediate_outs
=
encoder_out
[
1
]
encoder_out
=
encoder_out
[
0
]
# Post-encoder, e.g. NLU
if
self
.
postencoder
is
not
None
:
encoder_out
,
encoder_out_lens
=
self
.
postencoder
(
encoder_out
,
encoder_out_lens
)
assert
encoder_out
.
size
(
0
)
==
feats
.
size
(
0
),
(
encoder_out
.
size
(),
feats
.
size
(
0
),
)
if
(
getattr
(
self
.
encoder
,
"selfattention_layer_type"
,
None
)
!=
"lf_selfattn"
and
not
self
.
is_encoder_whisper
):
assert
encoder_out
.
size
(
-
2
)
<=
encoder_out_lens
.
max
(),
(
encoder_out
.
size
(),
encoder_out_lens
.
max
(),
)
if
intermediate_outs
is
not
None
:
return
(
encoder_out
,
intermediate_outs
),
encoder_out_lens
return
encoder_out
,
encoder_out_lens
def
_extract_feats
(
self
,
speech
:
torch
.
Tensor
,
speech_lengths
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
speech_lengths
.
dim
()
==
1
,
speech_lengths
.
shape
# for data-parallel
speech
=
speech
[:,
:
speech_lengths
.
max
()]
if
self
.
frontend
is
not
None
:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats
,
feats_lengths
=
self
.
frontend
(
speech
,
speech_lengths
)
else
:
# No frontend and no feature extract
feats
,
feats_lengths
=
speech
,
speech_lengths
return
feats
,
feats_lengths
def
nll
(
self
,
encoder_out
:
torch
.
Tensor
,
encoder_out_lens
:
torch
.
Tensor
,
ys_pad
:
torch
.
Tensor
,
ys_pad_lens
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Compute negative log likelihood(nll) from transformer-decoder
Normally, this function is called in batchify_nll.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
"""
ys_in_pad
,
ys_out_pad
=
add_sos_eos
(
ys_pad
,
self
.
sos
,
self
.
eos
,
self
.
ignore_id
)
ys_in_lens
=
ys_pad_lens
+
1
# 1. Forward decoder
decoder_out
,
_
=
self
.
decoder
(
encoder_out
,
encoder_out_lens
,
ys_in_pad
,
ys_in_lens
)
# [batch, seqlen, dim]
batch_size
=
decoder_out
.
size
(
0
)
decoder_num_class
=
decoder_out
.
size
(
2
)
# nll: negative log-likelihood
nll
=
torch
.
nn
.
functional
.
cross_entropy
(
decoder_out
.
view
(
-
1
,
decoder_num_class
),
ys_out_pad
.
view
(
-
1
),
ignore_index
=
self
.
ignore_id
,
reduction
=
"none"
,
)
nll
=
nll
.
view
(
batch_size
,
-
1
)
nll
=
nll
.
sum
(
dim
=
1
)
assert
nll
.
size
(
0
)
==
batch_size
return
nll
def
batchify_nll
(
self
,
encoder_out
:
torch
.
Tensor
,
encoder_out_lens
:
torch
.
Tensor
,
ys_pad
:
torch
.
Tensor
,
ys_pad_lens
:
torch
.
Tensor
,
batch_size
:
int
=
100
,
):
"""Compute negative log likelihood(nll) from transformer-decoder
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
ys_pad: (Batch, Length)
ys_pad_lens: (Batch,)
batch_size: int, samples each batch contain when computing nll,
you may change this to avoid OOM or increase
GPU memory usage
"""
total_num
=
encoder_out
.
size
(
0
)
if
total_num
<=
batch_size
:
nll
=
self
.
nll
(
encoder_out
,
encoder_out_lens
,
ys_pad
,
ys_pad_lens
)
else
:
nll
=
[]
start_idx
=
0
while
True
:
end_idx
=
min
(
start_idx
+
batch_size
,
total_num
)
batch_encoder_out
=
encoder_out
[
start_idx
:
end_idx
,
:,
:]
batch_encoder_out_lens
=
encoder_out_lens
[
start_idx
:
end_idx
]
batch_ys_pad
=
ys_pad
[
start_idx
:
end_idx
,
:]
batch_ys_pad_lens
=
ys_pad_lens
[
start_idx
:
end_idx
]
batch_nll
=
self
.
nll
(
batch_encoder_out
,
batch_encoder_out_lens
,
batch_ys_pad
,
batch_ys_pad_lens
,
)
nll
.
append
(
batch_nll
)
start_idx
=
end_idx
if
start_idx
==
total_num
:
break
nll
=
torch
.
cat
(
nll
)
assert
nll
.
size
(
0
)
==
total_num
return
nll
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Attention-decoder branch: label-smoothed loss, accuracy, CER/WER.

        Returns:
            (loss_att, acc_att, cer_att, wer_att); CER/WER are None during
            training or when no error calculator is configured.
        """
        # Optionally prepend a language token to every target sequence.
        # NOTE: ys_pad_lens is updated IN PLACE (+= 1) to match.
        if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
            ys_pad = torch.cat(
                [
                    self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
                    ys_pad,
                ],
                dim=1,
            )
            ys_pad_lens += 1

        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1  # +1 for the prepended <sos>

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss (label smoothing; see criterion_att)
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder (evaluation only)
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att
def
_calc_ctc_loss
(
self
,
encoder_out
:
torch
.
Tensor
,
encoder_out_lens
:
torch
.
Tensor
,
ys_pad
:
torch
.
Tensor
,
ys_pad_lens
:
torch
.
Tensor
,
):
# Calc CTC loss
loss_ctc
=
self
.
ctc
(
encoder_out
,
encoder_out_lens
,
ys_pad
,
ys_pad_lens
)
# Calc CER using CTC
cer_ctc
=
None
if
not
self
.
training
and
self
.
error_calculator
is
not
None
:
ys_hat
=
self
.
ctc
.
argmax
(
encoder_out
).
data
cer_ctc
=
self
.
error_calculator
(
ys_hat
.
cpu
(),
ys_pad
.
cpu
(),
is_ctc
=
True
)
return
loss_ctc
,
cer_ctc
    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        labels: torch.Tensor,
    ):
        """Compute Transducer loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            encoder_out_lens: Encoder output sequences lengths. (B,)
            labels: Label ID sequences. (B, L)

        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer.
            wer_transducer: Word Error Rate for Transducer.
        """
        # Build (prediction-net input, target, time lengths, label lengths).
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            labels,
            encoder_out_lens,
            ignore_id=self.ignore_id,
            blank_id=self.blank_id,
        )

        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in)

        # Joint over every (time, label) pair: (B, T, 1, D) x (B, 1, U, D).
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )

        loss_transducer = self.criterion_transducer(
            joint_out,
            target,
            t_len,
            u_len,
        )

        # CER/WER only at evaluation time, and only if a calculator exists.
        cer_transducer, wer_transducer = None, None
        if not self.training and self.error_calculator_trans is not None:
            cer_transducer, wer_transducer = self.error_calculator_trans(
                encoder_out, target
            )

        return loss_transducer, cer_transducer, wer_transducer
def
_calc_batch_ctc_loss
(
self
,
speech
:
torch
.
Tensor
,
speech_lengths
:
torch
.
Tensor
,
text
:
torch
.
Tensor
,
text_lengths
:
torch
.
Tensor
,
):
if
self
.
ctc
is
None
:
return
assert
text_lengths
.
dim
()
==
1
,
text_lengths
.
shape
# Check that batch_size is unified
assert
(
speech
.
shape
[
0
]
==
speech_lengths
.
shape
[
0
]
==
text
.
shape
[
0
]
==
text_lengths
.
shape
[
0
]
),
(
speech
.
shape
,
speech_lengths
.
shape
,
text
.
shape
,
text_lengths
.
shape
)
# for data-parallel
text
=
text
[:,
:
text_lengths
.
max
()]
# 1. Encoder
encoder_out
,
encoder_out_lens
=
self
.
encode
(
speech
,
speech_lengths
)
if
isinstance
(
encoder_out
,
tuple
):
encoder_out
=
encoder_out
[
0
]
# Calc CTC loss
do_reduce
=
self
.
ctc
.
reduce
self
.
ctc
.
reduce
=
False
loss_ctc
=
self
.
ctc
(
encoder_out
,
encoder_out_lens
,
text
,
text_lengths
)
self
.
ctc
.
reduce
=
do_reduce
return
loss_ctc
meta.yaml
0 → 100644
View file @
0941998c
espnet
:
0.9.0
files
:
asr_model_file
:
exp/asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp/valid.acc.ave_10best.pth
lm_file
:
exp/lm_train_lm_transformer_char_batch_bins2000000/valid.loss.ave_10best.pth
python
:
"
3.7.3
(default,
Mar
27
2019,
22:11:17)
\n
[GCC
7.3.0]"
timestamp
:
1603088092.704853
torch
:
1.6.0
yaml_files
:
asr_train_config
:
exp/asr_train_asr_conformer3_raw_char_batch_bins4000000_accum_grad4_sp/config.yaml
lm_train_config
:
exp/lm_train_lm_transformer_char_batch_bins2000000/config.yaml
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment