ModelZoo / Conformer_pytorch

Commit a7785cc6, authored Mar 26, 2024 by Sugon_ldc
Parent: 9a2a05ca

    delete soft link

Changes: 162 (showing 20 changed files with 5661 additions and 1 deletion, +5661 / -1)
examples/aishell/s0/tools/websocket/performance-ws.py   (+166 / -0)
examples/aishell/s0/wenet   (+0 / -1)
examples/aishell/s0/wenet/bin/alignment.py   (+235 / -0)
examples/aishell/s0/wenet/bin/average_model.py   (+101 / -0)
examples/aishell/s0/wenet/bin/export_jit.py   (+70 / -0)
examples/aishell/s0/wenet/bin/export_onnx_bpu.py   (+1019 / -0)
examples/aishell/s0/wenet/bin/export_onnx_cpu.py   (+411 / -0)
examples/aishell/s0/wenet/bin/export_onnx_gpu.py   (+824 / -0)
examples/aishell/s0/wenet/bin/recognize.py   (+360 / -0)
examples/aishell/s0/wenet/bin/recognize_onnx_gpu.py   (+278 / -0)
examples/aishell/s0/wenet/bin/train.py   (+372 / -0)
examples/aishell/s0/wenet/dataset/__pycache__/dataset.cpython-38.pyc   (+0 / -0)
examples/aishell/s0/wenet/dataset/__pycache__/processor.cpython-38.pyc   (+0 / -0)
examples/aishell/s0/wenet/dataset/dataset.py   (+193 / -0)
examples/aishell/s0/wenet/dataset/kaldi_io.py   (+666 / -0)
examples/aishell/s0/wenet/dataset/processor.py   (+642 / -0)
examples/aishell/s0/wenet/dataset/wav_distortion.py   (+324 / -0)
examples/aishell/s0/wenet/efficient_conformer/__pycache__/attention.cpython-38.pyc   (+0 / -0)
examples/aishell/s0/wenet/efficient_conformer/__pycache__/convolution.cpython-38.pyc   (+0 / -0)
examples/aishell/s0/wenet/efficient_conformer/__pycache__/encoder.cpython-38.pyc   (+0 / -0)
examples/aishell/s0/tools/websocket/performance-ws.py   (new file, mode 0 → 100755)

#!/usr/bin/env python3
# coding:utf-8
# Copyright (c) 2022 SDCI Co. Ltd (author: veelion)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import time
import asyncio
import argparse
import websockets
import soundfile as sf
import statistics

WS_START = json.dumps({
    'signal': 'start',
    'nbest': 1,
    'continuous_decoding': False,
})
WS_END = json.dumps({'signal': 'end'})


async def ws_rec(data, ws_uri):
    begin = time.time()
    conn = await websockets.connect(ws_uri, ping_timeout=200)
    # step 1: send start
    await conn.send(WS_START)
    ret = await conn.recv()
    # step 2: send audio data
    await conn.send(data)
    # step 3: send end
    await conn.send(WS_END)
    # step 4: receive result
    texts = []
    while 1:
        ret = await conn.recv()
        ret = json.loads(ret)
        if ret['type'] == 'final_result':
            nbest = json.loads(ret['nbest'])
            text = nbest[0]['sentence']
            texts.append(text)
        elif ret['type'] == 'speech_end':
            break
    # step 5: close
    try:
        await conn.close()
    except Exception as e:
        # this except has no effect, just log as debug
        # it seems the server does not send close info, maybe
        print(e)
    time_cost = time.time() - begin
    return {
        'text': ''.join(texts),
        'time': time_cost,
    }


def get_args():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument(
        '-u', '--ws_uri', required=True,
        help="websocket_server_main's uri, e.g. ws://127.0.0.1:10086")
    parser.add_argument(
        '-w', '--wav_scp', required=True,
        help='path to wav_scp_file')
    parser.add_argument(
        '-t', '--trans', required=True,
        help='path to trans_text_file of wavs')
    parser.add_argument(
        '-s', '--save_to', required=True,
        help='path to save transcription')
    parser.add_argument(
        '-n', '--num_concurrence', type=int, required=True,
        help='num of concurrence for query')
    args = parser.parse_args()
    return args


def print_result(info):
    length = max([len(k) for k in info])
    for k, v in info.items():
        print(f'\t{k: >{length}} : {v}')


async def main(args):
    wav_scp = []
    total_duration = 0
    with open(args.wav_scp) as f:
        for line in f:
            zz = line.strip().split()
            assert len(zz) == 2
            data, sr = sf.read(zz[1], dtype='int16')
            assert sr == 16000
            duration = (len(data)) / 16000
            total_duration += duration
            wav_scp.append((zz[0], data.tobytes()))
    print(f'{len(wav_scp)=}, {total_duration=}')

    tasks = []
    failed = 0
    texts = []
    request_times = []
    begin = time.time()
    for i, (_uttid, data) in enumerate(wav_scp):
        task = asyncio.create_task(ws_rec(data, args.ws_uri))
        tasks.append((_uttid, task))
        if len(tasks) < args.num_concurrence:
            continue
        print((f'{i=}, start {args.num_concurrence} '
               f'queries @ {time.strftime("%m-%d %H:%M:%S")}'))
        for uttid, task in tasks:
            result = await task
            texts.append(f'{uttid}\t{result["text"]}\n')
            request_times.append(result['time'])
        tasks = []
        print(f'\tdone @ {time.strftime("%m-%d %H:%M:%S")}')
    if tasks:
        for uttid, task in tasks:
            result = await task
            texts.append(f'{uttid}\t{result["text"]}\n')
            request_times.append(result['time'])
    request_time = time.time() - begin
    rtf = request_time / total_duration
    print('For all concurrence:')
    print_result({
        'failed': failed,
        'total_duration': total_duration,
        'request_time': request_time,
        'RTF': rtf,
    })
    print('For one request:')
    print_result({
        'mean': statistics.mean(request_times),
        'median': statistics.median(request_times),
        'max_time': max(request_times),
        'min_time': min(request_times),
    })
    with open(args.save_to, 'w', encoding='utf8') as fsave:
        fsave.write(''.join(texts))
    # caculate CER
    cmd = (f'python ../compute-wer.py --char=1 --v=1 '
           f'{args.trans} {args.save_to} > '
           f'{args.save_to}-test-{args.num_concurrence}.cer.txt')
    print(cmd)
    os.system(cmd)
    print('done')


if __name__ == '__main__':
    args = get_args()
    asyncio.run(main(args))
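Note: performance-ws.py exercises a running websocket_server_main endpoint end-to-end. As a minimal single-utterance sketch (not part of this commit; the wav path and server URI are placeholders), ws_rec can also be driven directly:

# Illustrative only: assumes a server at ws://127.0.0.1:10086 and a 16 kHz
# mono wav file named 'test.wav'; both are placeholders.
import asyncio
import soundfile as sf

audio, sr = sf.read('test.wav', dtype='int16')
assert sr == 16000
result = asyncio.run(ws_rec(audio.tobytes(), 'ws://127.0.0.1:10086'))
print(result['text'], result['time'])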
examples/aishell/s0/wenet   (deleted, mode 120000 → 0)

-../../../wenet/
\ No newline at end of file

This is the "delete soft link" part of the commit: the examples/aishell/s0/wenet symlink (which pointed to ../../../wenet/) is removed, and the wenet/* files added elsewhere in this commit take its place as regular files.
examples/aishell/s0/wenet/bin/alignment.py   (new file, mode 0 → 100644)

# Copyright (c) 2021 Mobvoi Inc. (authors: Di Wu)
#               2022 Tinnove Inc (authors: Wei Ren)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import copy
import logging
import os
import sys

import torch
import yaml
from torch.utils.data import DataLoader
from textgrid import TextGrid, IntervalTier

from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.ctc_util import forced_align
from wenet.utils.common import get_subsample
from wenet.utils.init_model import init_model


def generator_textgrid(maxtime, lines, output):
    # Download Praat: https://www.fon.hum.uva.nl/praat/
    interval = maxtime / (len(lines) + 1)
    margin = 0.0001

    tg = TextGrid(maxTime=maxtime)
    linetier = IntervalTier(name="line", maxTime=maxtime)

    i = 0
    for l in lines:
        s, e, w = l.split()
        linetier.add(minTime=float(s) + margin, maxTime=float(e), mark=w)

    tg.append(linetier)
    print("successfully generator {}".format(output))
    tg.write(output)


def get_frames_timestamp(alignment):
    # convert alignment to a praat format, which is a doing phonetics
    # by computer and helps analyzing alignment
    timestamp = []
    # get frames level duration for each token
    start = 0
    end = 0
    while end < len(alignment):
        while end < len(alignment) and alignment[end] == 0:
            end += 1
        if end == len(alignment):
            timestamp[-1] += alignment[start:]
            break
        end += 1
        while end < len(alignment) and alignment[end - 1] == alignment[end]:
            end += 1
        timestamp.append(alignment[start:end])
        start = end
    return timestamp


def get_labformat(timestamp, subsample):
    begin = 0
    duration = 0
    labformat = []
    for idx, t in enumerate(timestamp):
        # 25ms frame_length, 10ms hop_length, 1/subsample
        subsample = get_subsample(configs)
        # time duration
        duration = len(t) * 0.01 * subsample
        if idx < len(timestamp) - 1:
            print("{:.2f} {:.2f} {}".format(begin, begin + duration,
                                            char_dict[t[-1]]))
            labformat.append("{:.2f} {:.2f} {}\n".format(
                begin, begin + duration, char_dict[t[-1]]))
        else:
            non_blank = 0
            for i in t:
                if i != 0:
                    token = i
                    break
            print("{:.2f} {:.2f} {}".format(begin, begin + duration,
                                            char_dict[token]))
            labformat.append("{:.2f} {:.2f} {}\n".format(
                begin, begin + duration, char_dict[token]))
        begin = begin + duration
    return labformat


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='use ctc to generate alignment')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--input_file', required=True,
                        help='format data file')
    parser.add_argument('--data_type',
                        default='raw',
                        choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--checkpoint', required=True,
                        help='checkpoint model')
    parser.add_argument('--dict', required=True, help='dict file')
    parser.add_argument('--non_lang_syms',
                        help="non-linguistic symbol file. One symbol per line.")
    parser.add_argument('--result_file', required=True,
                        help='alignment result file')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='batch size')
    parser.add_argument('--gen_praat',
                        action='store_true',
                        help='convert alignment to a praat format')
    parser.add_argument('--bpe_model',
                        default=None,
                        type=str,
                        help='bpe model for english part')
    args = parser.parse_args()
    print(args)

    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    if args.batch_size > 1:
        logging.fatal('alignment mode must be running with batch_size == 1')
        sys.exit(1)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    # Load dict
    char_dict = {}
    with open(args.dict, 'r') as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            char_dict[int(arr[1])] = arr[0]
    eos = len(char_dict) - 1

    symbol_table = read_symbol_table(args.dict)

    # Init dataset and data loader
    ali_conf = copy.deepcopy(configs['dataset_conf'])

    ali_conf['filter_conf']['max_length'] = 102400
    ali_conf['filter_conf']['min_length'] = 0
    ali_conf['filter_conf']['token_max_length'] = 102400
    ali_conf['filter_conf']['token_min_length'] = 0
    ali_conf['filter_conf']['max_output_input_ratio'] = 102400
    ali_conf['filter_conf']['min_output_input_ratio'] = 0
    ali_conf['speed_perturb'] = False
    ali_conf['spec_aug'] = False
    ali_conf['shuffle'] = False
    ali_conf['sort'] = False
    ali_conf['fbank_conf']['dither'] = 0.0
    ali_conf['batch_conf']['batch_type'] = "static"
    ali_conf['batch_conf']['batch_size'] = args.batch_size
    non_lang_syms = read_non_lang_symbols(args.non_lang_syms)

    ali_dataset = Dataset(args.data_type,
                          args.input_file,
                          symbol_table,
                          ali_conf,
                          args.bpe_model,
                          non_lang_syms,
                          partition=False)

    ali_data_loader = DataLoader(ali_dataset, batch_size=None, num_workers=0)

    # Init asr model from configs
    model = init_model(configs)

    load_checkpoint(model, args.checkpoint)
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model = model.to(device)

    model.eval()
    with torch.no_grad(), open(args.result_file, 'w',
                               encoding='utf-8') as fout:
        for batch_idx, batch in enumerate(ali_data_loader):
            print("#" * 80)
            key, feat, target, feats_length, target_length = batch
            print(key)

            feat = feat.to(device)
            target = target.to(device)
            feats_length = feats_length.to(device)
            target_length = target_length.to(device)
            # Let's assume B = batch_size and N = beam_size
            # 1. Encoder
            encoder_out, encoder_mask = model._forward_encoder(
                feat, feats_length)  # (B, maxlen, encoder_dim)
            maxlen = encoder_out.size(1)
            ctc_probs = model.ctc.log_softmax(
                encoder_out)  # (1, maxlen, vocab_size)
            # print(ctc_probs.size(1))
            ctc_probs = ctc_probs.squeeze(0)
            target = target.squeeze(0)
            alignment = forced_align(ctc_probs, target)
            print(alignment)
            fout.write('{} {}\n'.format(key[0], alignment))

            if args.gen_praat:
                timestamp = get_frames_timestamp(alignment)
                print(timestamp)
                subsample = get_subsample(configs)
                labformat = get_labformat(timestamp, subsample)

                lab_path = os.path.join(os.path.dirname(args.result_file),
                                        key[0] + ".lab")
                with open(lab_path, 'w', encoding='utf-8') as f:
                    f.writelines(labformat)

                textgrid_path = os.path.join(
                    os.path.dirname(args.result_file),
                    key[0] + ".TextGrid")
                generator_textgrid(maxtime=(len(alignment) + 1) * 0.01 *
                                   subsample,
                                   lines=labformat,
                                   output=textgrid_path)
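Note (illustrative, not part of this commit): get_frames_timestamp groups a frame-level CTC alignment into per-token spans, folding the blanks (id 0) that precede a token into that token's span and appending trailing blanks to the last span. A toy trace, with made-up token ids 5 and 7:

# Illustrative only: token ids 5 and 7 are made up; 0 is the CTC blank.
ali = [0, 0, 5, 5, 0, 0, 7, 0, 0]
spans = get_frames_timestamp(ali)
# spans == [[0, 0, 5, 5], [0, 0, 7, 0, 0]]
# With a 10 ms hop and subsample factor s, a span of length n lasts roughly
# n * 0.01 * s seconds, which is exactly what get_labformat prints per token.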
examples/aishell/s0/wenet/bin/average_model.py   (new file, mode 0 → 100644)

# Copyright (c) 2020 Mobvoi Inc (Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import glob

import yaml
import numpy as np
import torch


def get_args():
    parser = argparse.ArgumentParser(description='average model')
    parser.add_argument('--dst_model', required=True, help='averaged model')
    parser.add_argument('--src_path',
                        required=True,
                        help='src model path for average')
    parser.add_argument('--val_best',
                        action="store_true",
                        help='averaged model')
    parser.add_argument('--num',
                        default=5,
                        type=int,
                        help='nums for averaged model')
    parser.add_argument('--min_epoch',
                        default=0,
                        type=int,
                        help='min epoch used for averaging model')
    parser.add_argument('--max_epoch',
                        default=65536,
                        type=int,
                        help='max epoch used for averaging model')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    checkpoints = []
    val_scores = []
    if args.val_best:
        yamls = glob.glob('{}/[!train]*.yaml'.format(args.src_path))
        for y in yamls:
            with open(y, 'r') as f:
                dic_yaml = yaml.load(f, Loader=yaml.FullLoader)
                loss = dic_yaml['cv_loss']
                epoch = dic_yaml['epoch']
                if epoch >= args.min_epoch and epoch <= args.max_epoch:
                    val_scores += [[epoch, loss]]
        val_scores = np.array(val_scores)
        sort_idx = np.argsort(val_scores[:, -1])
        sorted_val_scores = val_scores[sort_idx][::1]
        print("best val scores = " + str(sorted_val_scores[:args.num, 1]))
        print("selected epochs = " +
              str(sorted_val_scores[:args.num, 0].astype(np.int64)))
        path_list = [
            args.src_path + '/{}.pt'.format(int(epoch))
            for epoch in sorted_val_scores[:args.num, 0]
        ]
    else:
        path_list = glob.glob('{}/[0-9]*.pt'.format(args.src_path))
        path_list = sorted(path_list, key=os.path.getmtime)
        path_list = path_list[-args.num:]
    print(path_list)
    avg = None
    num = args.num
    assert num == len(path_list)
    for path in path_list:
        print('Processing {}'.format(path))
        states = torch.load(path, map_location=torch.device('cpu'))
        if avg is None:
            avg = states
        else:
            for k in avg.keys():
                avg[k] += states[k]
    # average
    for k in avg.keys():
        if avg[k] is not None:
            # pytorch 1.6 use true_divide instead of /=
            avg[k] = torch.true_divide(avg[k], num)

    print('Saving to {}'.format(args.dst_model))
    torch.save(avg, args.dst_model)


if __name__ == '__main__':
    main()
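Note (illustrative, not part of this commit): the file written by average_model.py is a plain state_dict, so it can be loaded back like any other checkpoint; a minimal sketch with a placeholder path:

# Sketch: inspect and reuse the averaged checkpoint ('avg_5.pt' is a
# placeholder filename).
import torch

avg_states = torch.load('avg_5.pt', map_location='cpu')
print(sorted(avg_states.keys())[:5])   # first few parameter names
# model = init_model(configs)          # as in the other wenet/bin scripts
# model.load_state_dict(avg_states)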
examples/aishell/s0/wenet/bin/export_jit.py   (new file, mode 0 → 100644)

# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import os

import torch
import yaml

from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model


def get_args():
    parser = argparse.ArgumentParser(description='export your script model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--output_file', default=None, help='output file')
    parser.add_argument('--output_quant_file',
                        default=None,
                        help='output quantized model file')
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    # No need gpu for model export
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    model = init_model(configs)
    print(model)

    load_checkpoint(model, args.checkpoint)
    # Export jit torch script model
    if args.output_file:
        script_model = torch.jit.script(model)
        script_model.save(args.output_file)
        print('Export model successfully, see {}'.format(args.output_file))

    # Export quantized jit torch script model
    if args.output_quant_file:
        quantized_model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8)
        print(quantized_model)
        script_quant_model = torch.jit.script(quantized_model)
        script_quant_model.save(args.output_quant_file)
        print('Export quantized model successfully, '
              'see {}'.format(args.output_quant_file))


if __name__ == '__main__':
    main()
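Note (illustrative, not part of this commit): the TorchScript file saved by export_jit.py can later be reloaded for inference without the Python model definition; a minimal sketch with a placeholder filename:

# Sketch: reload the exported TorchScript model ('final.zip' is a
# placeholder output_file name) for CPU inference.
import torch

script_model = torch.jit.load('final.zip', map_location='cpu')
script_model.eval()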
examples/aishell/s0/wenet/bin/export_onnx_bpu.py
0 → 100644
View file @
a7785cc6
# Copyright (c) 2022, Horizon Inc. Xingchen Song (sxc19@tsinghua.org.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""NOTE(xcsong): Currently, we only support
1. specific conformer encoder architecture, see:
encoder: conformer
encoder_conf:
activation_type: **must be** relu
attention_heads: 2 or 4 or 8 or any number divisible by output_size
causal: **must be** true
cnn_module_kernel: 1 ~ 7
cnn_module_norm: **must be** batch_norm
input_layer: **must be** conv2d8
linear_units: 1 ~ 2048
normalize_before: **must be** true
num_blocks: 1 ~ 12
output_size: 1 ~ 512
pos_enc_layer_type: **must be** no_pos
selfattention_layer_type: **must be** selfattn
use_cnn_module: **must be** true
use_dynamic_chunk: **must be** true
use_dynamic_left_chunk: **must be** true
2. specific decoding method: ctc_greedy_search
"""
from
__future__
import
print_function
import
os
import
sys
import
copy
import
math
import
yaml
import
logging
from
typing
import
Tuple
import
torch
import
numpy
as
np
from
wenet.transformer.embedding
import
NoPositionalEncoding
from
wenet.utils.checkpoint
import
load_checkpoint
from
wenet.utils.init_model
import
init_model
from
wenet.bin.export_onnx_cpu
import
(
get_args
,
to_numpy
,
print_input_output_info
)
try
:
import
onnx
import
onnxruntime
except
ImportError
:
print
(
'Please install onnx and onnxruntime!'
)
sys
.
exit
(
1
)
logger
=
logging
.
getLogger
(
__file__
)
logger
.
setLevel
(
logging
.
INFO
)
class
BPULayerNorm
(
torch
.
nn
.
Module
):
"""Refactor torch.nn.LayerNorm to meet 4-D dataflow."""
def
__init__
(
self
,
module
,
chunk_size
=
8
,
run_on_bpu
=
False
):
super
().
__init__
()
original
=
copy
.
deepcopy
(
module
)
self
.
hidden
=
module
.
weight
.
size
(
0
)
self
.
chunk_size
=
chunk_size
self
.
run_on_bpu
=
run_on_bpu
if
self
.
run_on_bpu
:
self
.
weight
=
torch
.
nn
.
Parameter
(
module
.
weight
.
reshape
(
1
,
self
.
hidden
,
1
,
1
).
repeat
(
1
,
1
,
1
,
chunk_size
))
self
.
bias
=
torch
.
nn
.
Parameter
(
module
.
bias
.
reshape
(
1
,
self
.
hidden
,
1
,
1
).
repeat
(
1
,
1
,
1
,
chunk_size
))
self
.
negtive
=
torch
.
nn
.
Parameter
(
torch
.
ones
((
1
,
self
.
hidden
,
1
,
chunk_size
))
*
-
1.0
)
self
.
eps
=
torch
.
nn
.
Parameter
(
torch
.
zeros
((
1
,
self
.
hidden
,
1
,
chunk_size
))
+
module
.
eps
)
self
.
mean_conv_1
=
torch
.
nn
.
Conv2d
(
self
.
hidden
,
1
,
1
,
bias
=
False
)
self
.
mean_conv_1
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
self
.
hidden
,
self
.
hidden
,
1
,
1
)
/
(
1.0
*
self
.
hidden
))
self
.
mean_conv_2
=
torch
.
nn
.
Conv2d
(
self
.
hidden
,
1
,
1
,
bias
=
False
)
self
.
mean_conv_2
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
self
.
hidden
,
self
.
hidden
,
1
,
1
)
/
(
1.0
*
self
.
hidden
))
else
:
self
.
norm
=
module
self
.
check_equal
(
original
)
def
check_equal
(
self
,
module
):
random_data
=
torch
.
randn
(
1
,
self
.
chunk_size
,
self
.
hidden
)
orig_out
=
module
(
random_data
)
new_out
=
self
.
forward
(
random_data
.
transpose
(
1
,
2
).
unsqueeze
(
2
))
np
.
testing
.
assert_allclose
(
to_numpy
(
orig_out
),
to_numpy
(
new_out
.
squeeze
(
2
).
transpose
(
1
,
2
)),
rtol
=
1e-02
,
atol
=
1e-03
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
run_on_bpu
:
u
=
self
.
mean_conv_1
(
x
)
# (1, h, 1, c)
numerator
=
x
+
u
*
self
.
negtive
# (1, h, 1, c)
s
=
torch
.
pow
(
numerator
,
2
)
# (1, h, 1, c)
s
=
self
.
mean_conv_2
(
s
)
# (1, h, 1, c)
denominator
=
torch
.
sqrt
(
s
+
self
.
eps
)
# (1, h, 1, c)
x
=
torch
.
div
(
numerator
,
denominator
)
# (1, h, 1, c)
x
=
x
*
self
.
weight
+
self
.
bias
else
:
x
=
x
.
squeeze
(
2
).
transpose
(
1
,
2
).
contiguous
()
x
=
self
.
norm
(
x
)
x
=
x
.
transpose
(
1
,
2
).
contiguous
().
unsqueeze
(
2
)
return
x
class
BPUIdentity
(
torch
.
nn
.
Module
):
"""Refactor torch.nn.Identity().
For inserting BPU node whose input == output.
"""
def
__init__
(
self
,
channels
):
super
().
__init__
()
self
.
channels
=
channels
self
.
identity_conv
=
torch
.
nn
.
Conv2d
(
channels
,
channels
,
1
,
groups
=
channels
,
bias
=
False
)
torch
.
nn
.
init
.
dirac_
(
self
.
identity_conv
.
weight
.
data
,
groups
=
channels
)
self
.
check_equal
()
def
check_equal
(
self
):
random_data
=
torch
.
randn
(
1
,
self
.
channels
,
1
,
10
)
result
=
self
.
forward
(
random_data
)
np
.
testing
.
assert_allclose
(
to_numpy
(
random_data
),
to_numpy
(
result
),
rtol
=
1e-02
,
atol
=
1e-03
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Identity with 4-D dataflow, input == output.
Args:
x (torch.Tensor): (batch, in_channel, 1, time)
Returns:
(torch.Tensor): (batch, in_channel, 1, time).
"""
return
self
.
identity_conv
(
x
)
class
BPULinear
(
torch
.
nn
.
Module
):
"""Refactor torch.nn.Linear or pointwise_conv"""
def
__init__
(
self
,
module
,
is_pointwise_conv
=
False
):
super
().
__init__
()
# Unchanged submodules and attributes
original
=
copy
.
deepcopy
(
module
)
self
.
idim
=
module
.
weight
.
size
(
1
)
self
.
odim
=
module
.
weight
.
size
(
0
)
self
.
is_pointwise_conv
=
is_pointwise_conv
# Modify weight & bias
self
.
linear
=
torch
.
nn
.
Conv2d
(
self
.
idim
,
self
.
odim
,
1
,
1
)
if
is_pointwise_conv
:
# (odim, idim, kernel=1) -> (odim, idim, 1, 1)
self
.
linear
.
weight
=
torch
.
nn
.
Parameter
(
module
.
weight
.
unsqueeze
(
-
1
))
else
:
# (odim, idim) -> (odim, idim, 1, 1)
self
.
linear
.
weight
=
torch
.
nn
.
Parameter
(
module
.
weight
.
unsqueeze
(
2
).
unsqueeze
(
3
))
self
.
linear
.
bias
=
module
.
bias
self
.
check_equal
(
original
)
def
check_equal
(
self
,
module
):
random_data
=
torch
.
randn
(
1
,
8
,
self
.
idim
)
if
self
.
is_pointwise_conv
:
random_data
=
random_data
.
transpose
(
1
,
2
)
original_result
=
module
(
random_data
)
if
self
.
is_pointwise_conv
:
random_data
=
random_data
.
transpose
(
1
,
2
)
original_result
=
original_result
.
transpose
(
1
,
2
)
random_data
=
random_data
.
transpose
(
1
,
2
).
unsqueeze
(
2
)
new_result
=
self
.
forward
(
random_data
)
np
.
testing
.
assert_allclose
(
to_numpy
(
original_result
),
to_numpy
(
new_result
.
squeeze
(
2
).
transpose
(
1
,
2
)),
rtol
=
1e-02
,
atol
=
1e-03
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Linear with 4-D dataflow.
Args:
x (torch.Tensor): (batch, in_channel, 1, time)
Returns:
(torch.Tensor): (batch, out_channel, 1, time).
"""
return
self
.
linear
(
x
)
class
BPUGlobalCMVN
(
torch
.
nn
.
Module
):
"""Refactor wenet/transformer/cmvn.py::GlobalCMVN"""
def
__init__
(
self
,
module
):
super
().
__init__
()
# Unchanged submodules and attributes
self
.
norm_var
=
module
.
norm_var
# NOTE(xcsong): Expand to 4-D tensor, (mel_dim) -> (1, 1, mel_dim, 1)
self
.
mean
=
module
.
mean
.
unsqueeze
(
-
1
).
unsqueeze
(
0
).
unsqueeze
(
0
)
self
.
istd
=
module
.
istd
.
unsqueeze
(
-
1
).
unsqueeze
(
0
).
unsqueeze
(
0
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""CMVN with 4-D dataflow.
Args:
x (torch.Tensor): (batch, 1, mel_dim, time)
Returns:
(torch.Tensor): normalized feature with same shape.
"""
x
=
x
-
self
.
mean
if
self
.
norm_var
:
x
=
x
*
self
.
istd
return
x
class
BPUConv2dSubsampling8
(
torch
.
nn
.
Module
):
"""Refactor wenet/transformer/subsampling.py::Conv2dSubsampling8
NOTE(xcsong): Only support pos_enc_class == NoPositionalEncoding
"""
def
__init__
(
self
,
module
):
super
().
__init__
()
# Unchanged submodules and attributes
original
=
copy
.
deepcopy
(
module
)
self
.
right_context
=
module
.
right_context
self
.
subsampling_rate
=
module
.
subsampling_rate
assert
isinstance
(
module
.
pos_enc
,
NoPositionalEncoding
)
# 1. Modify self.conv
# NOTE(xcsong): We change input shape from (1, 1, frames, mel_dim)
# to (1, 1, mel_dim, frames) for more efficient computation.
self
.
conv
=
module
.
conv
for
idx
in
[
0
,
2
,
4
]:
self
.
conv
[
idx
].
weight
=
torch
.
nn
.
Parameter
(
module
.
conv
[
idx
].
weight
.
transpose
(
2
,
3
)
)
# 2. Modify self.linear
# NOTE(xcsong): Split final projection to meet the requirment of
# maximum kernel_size (7 for XJ3)
self
.
linear
=
torch
.
nn
.
ModuleList
()
odim
=
module
.
linear
.
weight
.
size
(
0
)
# 512, in this case
freq
=
module
.
linear
.
weight
.
size
(
1
)
//
odim
# 4608 // 512 == 9
self
.
odim
,
self
.
freq
=
odim
,
freq
weight
=
module
.
linear
.
weight
.
reshape
(
odim
,
odim
,
freq
,
1
)
# (odim, odim * freq) -> (odim, odim, freq, 1)
self
.
split_size
=
[]
num_split
=
(
freq
-
1
)
//
7
+
1
# XJ3 requires kernel_size <= 7
slice_begin
=
0
for
idx
in
range
(
num_split
):
kernel_size
=
min
(
freq
,
(
idx
+
1
)
*
7
)
-
idx
*
7
conv_ele
=
torch
.
nn
.
Conv2d
(
odim
,
odim
,
(
kernel_size
,
1
),
(
kernel_size
,
1
))
conv_ele
.
weight
=
torch
.
nn
.
Parameter
(
weight
[:,
:,
slice_begin
:
slice_begin
+
kernel_size
,
:]
)
conv_ele
.
bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros_like
(
conv_ele
.
bias
)
)
self
.
linear
.
append
(
conv_ele
)
self
.
split_size
.
append
(
kernel_size
)
slice_begin
+=
kernel_size
self
.
linear
[
0
].
bias
=
torch
.
nn
.
Parameter
(
module
.
linear
.
bias
)
self
.
check_equal
(
original
)
def
check_equal
(
self
,
module
):
random_data
=
torch
.
randn
(
1
,
67
,
80
)
mask
=
torch
.
zeros
(
1
,
1
,
67
)
original_result
,
_
,
_
=
module
(
random_data
,
mask
)
# (1, 8, 512)
random_data
=
random_data
.
transpose
(
1
,
2
).
unsqueeze
(
0
)
# (1, 1, 80, 67)
new_result
=
self
.
forward
(
random_data
)
# (1, 512, 1, 8)
np
.
testing
.
assert_allclose
(
to_numpy
(
original_result
),
to_numpy
(
new_result
.
squeeze
(
2
).
transpose
(
1
,
2
)),
rtol
=
1e-02
,
atol
=
1e-03
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Subsample x with 4-D dataflow.
Args:
x (torch.Tensor): Input tensor (#batch, 1, mel_dim, time).
Returns:
torch.Tensor: Subsampled tensor (#batch, odim, 1, time'),
where time' = time // 8.
"""
x
=
self
.
conv
(
x
)
# (1, odim, freq, time')
x_out
=
torch
.
zeros
(
x
.
size
(
0
),
self
.
odim
,
1
,
x
.
size
(
3
))
x
=
torch
.
split
(
x
,
self
.
split_size
,
dim
=
2
)
for
idx
,
(
x_part
,
layer
)
in
enumerate
(
zip
(
x
,
self
.
linear
)):
x_out
+=
layer
(
x_part
)
return
x_out
class
BPUMultiHeadedAttention
(
torch
.
nn
.
Module
):
"""Refactor wenet/transformer/attention.py::MultiHeadedAttention
NOTE(xcsong): Only support attention_class == MultiHeadedAttention,
we do not consider RelPositionMultiHeadedAttention currently.
"""
def
__init__
(
self
,
module
,
chunk_size
,
left_chunks
):
super
().
__init__
()
# Unchanged submodules and attributes
original
=
copy
.
deepcopy
(
module
)
self
.
d_k
=
module
.
d_k
self
.
h
=
module
.
h
n_feat
=
self
.
d_k
*
self
.
h
self
.
chunk_size
=
chunk_size
self
.
left_chunks
=
left_chunks
self
.
time
=
chunk_size
*
(
left_chunks
+
1
)
self
.
activation
=
torch
.
nn
.
Softmax
(
dim
=-
1
)
# 1. Modify self.linear_x
self
.
linear_q
=
BPULinear
(
module
.
linear_q
)
self
.
linear_k
=
BPULinear
(
module
.
linear_k
)
self
.
linear_v
=
BPULinear
(
module
.
linear_v
)
self
.
linear_out
=
BPULinear
(
module
.
linear_out
)
# 2. denom
self
.
register_buffer
(
"denom"
,
torch
.
full
((
1
,
self
.
h
,
1
,
1
),
1.0
/
math
.
sqrt
(
self
.
d_k
)))
self
.
check_equal
(
original
)
def
check_equal
(
self
,
module
):
random_data
=
torch
.
randn
(
1
,
self
.
chunk_size
,
self
.
d_k
*
self
.
h
)
mask
=
torch
.
ones
((
1
,
self
.
h
,
self
.
chunk_size
,
self
.
time
),
dtype
=
torch
.
bool
)
cache
=
torch
.
zeros
(
1
,
self
.
h
,
self
.
chunk_size
*
self
.
left_chunks
,
self
.
d_k
*
2
)
original_out
,
original_cache
=
module
(
random_data
,
random_data
,
random_data
,
mask
[:,
0
,
:,
:],
torch
.
empty
(
0
),
cache
)
random_data
=
random_data
.
transpose
(
1
,
2
).
unsqueeze
(
2
)
cache
=
cache
.
reshape
(
1
,
self
.
h
,
self
.
d_k
*
2
,
self
.
chunk_size
*
self
.
left_chunks
)
new_out
,
new_cache
=
self
.
forward
(
random_data
,
random_data
,
random_data
,
mask
,
cache
)
np
.
testing
.
assert_allclose
(
to_numpy
(
original_out
),
to_numpy
(
new_out
.
squeeze
(
2
).
transpose
(
1
,
2
)),
rtol
=
1e-02
,
atol
=
1e-03
)
np
.
testing
.
assert_allclose
(
to_numpy
(
original_cache
),
to_numpy
(
new_cache
.
transpose
(
2
,
3
)),
rtol
=
1e-02
,
atol
=
1e-03
)
def
forward
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
cache
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Compute scaled dot product attention.
Args:
q (torch.Tensor): Query tensor (#batch, size, 1, chunk_size).
k (torch.Tensor): Key tensor (#batch, size, 1, chunk_size).
v (torch.Tensor): Value tensor (#batch, size, 1, chunk_size).
mask (torch.Tensor): Mask tensor,
(#batch, head, chunk_size, cache_t + chunk_size).
cache (torch.Tensor): Cache tensor
(1, head, d_k * 2, cache_t),
where `cache_t == chunk_size * left_chunks`.
Returns:
torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
torch.Tensor: Cache tensor
(1, head, d_k * 2, cache_t + chunk_size)
where `cache_t == chunk_size * left_chunks`
"""
# 1. Forward QKV
q
=
self
.
linear_q
(
q
)
# (1, d, 1, c) d == size, c == chunk_size
k
=
self
.
linear_k
(
k
)
# (1, d, 1, c)
v
=
self
.
linear_v
(
v
)
# (1, d, 1, c)
q
=
q
.
view
(
1
,
self
.
h
,
self
.
d_k
,
self
.
chunk_size
)
k
=
k
.
view
(
1
,
self
.
h
,
self
.
d_k
,
self
.
chunk_size
)
v
=
v
.
view
(
1
,
self
.
h
,
self
.
d_k
,
self
.
chunk_size
)
q
=
q
.
transpose
(
2
,
3
)
# (batch, head, time1, d_k)
k_cache
,
v_cache
=
torch
.
split
(
cache
,
cache
.
size
(
2
)
//
2
,
dim
=
2
)
k
=
torch
.
cat
((
k_cache
,
k
),
dim
=
3
)
v
=
torch
.
cat
((
v_cache
,
v
),
dim
=
3
)
new_cache
=
torch
.
cat
((
k
,
v
),
dim
=
2
)
# 2. (Q^T)K
scores
=
torch
.
matmul
(
q
,
k
)
*
self
.
denom
# (#b, n_head, time1, time2)
# 3. Forward attention
mask
=
mask
.
eq
(
0
)
scores
=
scores
.
masked_fill
(
mask
,
-
float
(
'inf'
))
attn
=
self
.
activation
(
scores
).
masked_fill
(
mask
,
0.0
)
attn
=
attn
.
transpose
(
2
,
3
)
x
=
torch
.
matmul
(
v
,
attn
)
x
=
x
.
view
(
1
,
self
.
d_k
*
self
.
h
,
1
,
self
.
chunk_size
)
x_out
=
self
.
linear_out
(
x
)
return
x_out
,
new_cache
class
BPUConvolution
(
torch
.
nn
.
Module
):
"""Refactor wenet/transformer/convolution.py::ConvolutionModule
NOTE(xcsong): Only suport use_layer_norm == False
"""
def
__init__
(
self
,
module
):
super
().
__init__
()
# Unchanged submodules and attributes
original
=
copy
.
deepcopy
(
module
)
self
.
lorder
=
module
.
lorder
self
.
use_layer_norm
=
False
self
.
activation
=
module
.
activation
channels
=
module
.
pointwise_conv1
.
weight
.
size
(
1
)
self
.
channels
=
channels
kernel_size
=
module
.
depthwise_conv
.
weight
.
size
(
2
)
assert
module
.
use_layer_norm
is
False
# 1. Modify self.pointwise_conv1
self
.
pointwise_conv1
=
BPULinear
(
module
.
pointwise_conv1
,
True
)
# 2. Modify self.depthwise_conv
self
.
depthwise_conv
=
torch
.
nn
.
Conv2d
(
channels
,
channels
,
(
1
,
kernel_size
),
stride
=
1
,
groups
=
channels
)
self
.
depthwise_conv
.
weight
=
torch
.
nn
.
Parameter
(
module
.
depthwise_conv
.
weight
.
unsqueeze
(
-
2
))
self
.
depthwise_conv
.
bias
=
torch
.
nn
.
Parameter
(
module
.
depthwise_conv
.
bias
)
# 3. Modify self.norm, Only support batchnorm2d
self
.
norm
=
torch
.
nn
.
BatchNorm2d
(
channels
)
self
.
norm
.
training
=
False
self
.
norm
.
num_features
=
module
.
norm
.
num_features
self
.
norm
.
eps
=
module
.
norm
.
eps
self
.
norm
.
momentum
=
module
.
norm
.
momentum
self
.
norm
.
weight
=
torch
.
nn
.
Parameter
(
module
.
norm
.
weight
)
self
.
norm
.
bias
=
torch
.
nn
.
Parameter
(
module
.
norm
.
bias
)
self
.
norm
.
running_mean
=
module
.
norm
.
running_mean
self
.
norm
.
running_var
=
module
.
norm
.
running_var
# 4. Modify self.pointwise_conv2
self
.
pointwise_conv2
=
BPULinear
(
module
.
pointwise_conv2
,
True
)
# 5. Identity conv, for running `concat` on BPU
self
.
identity
=
BPUIdentity
(
channels
)
self
.
check_equal
(
original
)
def
check_equal
(
self
,
module
):
random_data
=
torch
.
randn
(
1
,
8
,
self
.
channels
)
cache
=
torch
.
zeros
((
1
,
self
.
channels
,
self
.
lorder
))
original_out
,
original_cache
=
module
(
random_data
,
cache
=
cache
)
random_data
=
random_data
.
transpose
(
1
,
2
).
unsqueeze
(
2
)
cache
=
cache
.
unsqueeze
(
2
)
new_out
,
new_cache
=
self
.
forward
(
random_data
,
cache
)
np
.
testing
.
assert_allclose
(
to_numpy
(
original_out
),
to_numpy
(
new_out
.
squeeze
(
2
).
transpose
(
1
,
2
)),
rtol
=
1e-02
,
atol
=
1e-03
)
np
.
testing
.
assert_allclose
(
to_numpy
(
original_cache
),
to_numpy
(
new_cache
.
squeeze
(
2
)),
rtol
=
1e-02
,
atol
=
1e-03
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
cache
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, channels, 1, chunk_size).
cache (torch.Tensor): left context cache, it is only
used in causal convolution (#batch, channels, 1, cache_t).
Returns:
torch.Tensor: Output tensor (#batch, channels, 1, chunk_size).
torch.Tensor: Cache tensor (#batch, channels, 1, cache_t).
"""
# Concat cache
x
=
torch
.
cat
((
self
.
identity
(
cache
),
self
.
identity
(
x
)),
dim
=
3
)
new_cache
=
x
[:,
:,
:,
-
self
.
lorder
:]
# GLU mechanism
x
=
self
.
pointwise_conv1
(
x
)
# (batch, 2*channel, 1, dim)
x
=
torch
.
nn
.
functional
.
glu
(
x
,
dim
=
1
)
# (b, channel, 1, dim)
# Depthwise Conv
x
=
self
.
depthwise_conv
(
x
)
x
=
self
.
activation
(
self
.
norm
(
x
))
x
=
self
.
pointwise_conv2
(
x
)
return
x
,
new_cache
class
BPUFFN
(
torch
.
nn
.
Module
):
"""Refactor wenet/transformer/positionwise_feed_forward.py::PositionwiseFeedForward
"""
def
__init__
(
self
,
module
):
super
().
__init__
()
# Unchanged submodules and attributes
original
=
copy
.
deepcopy
(
module
)
self
.
activation
=
module
.
activation
# 1. Modify self.w_x
self
.
w_1
=
BPULinear
(
module
.
w_1
)
self
.
w_2
=
BPULinear
(
module
.
w_2
)
self
.
check_equal
(
original
)
def
check_equal
(
self
,
module
):
random_data
=
torch
.
randn
(
1
,
8
,
self
.
w_1
.
idim
)
original_out
=
module
(
random_data
)
random_data
=
random_data
.
transpose
(
1
,
2
).
unsqueeze
(
2
)
new_out
=
self
.
forward
(
random_data
)
np
.
testing
.
assert_allclose
(
to_numpy
(
original_out
),
to_numpy
(
new_out
.
squeeze
(
2
).
transpose
(
1
,
2
)),
rtol
=
1e-02
,
atol
=
1e-03
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Forward function.
Args:
xs: input tensor (B, D, 1, L)
Returns:
output tensor, (B, D, 1, L)
"""
return
self
.
w_2
(
self
.
activation
(
self
.
w_1
(
x
)))
class
BPUConformerEncoderLayer
(
torch
.
nn
.
Module
):
"""Refactor wenet/transformer/encoder_layer.py::ConformerEncoderLayer
"""
def
__init__
(
self
,
module
,
chunk_size
,
left_chunks
,
ln_run_on_bpu
=
False
):
super
().
__init__
()
# Unchanged submodules and attributes
original
=
copy
.
deepcopy
(
module
)
self
.
size
=
module
.
size
assert
module
.
normalize_before
is
True
assert
module
.
concat_after
is
False
# 1. Modify submodules
self
.
feed_forward_macaron
=
BPUFFN
(
module
.
feed_forward_macaron
)
self
.
self_attn
=
BPUMultiHeadedAttention
(
module
.
self_attn
,
chunk_size
,
left_chunks
)
self
.
conv_module
=
BPUConvolution
(
module
.
conv_module
)
self
.
feed_forward
=
BPUFFN
(
module
.
feed_forward
)
# 2. Modify norms
self
.
norm_ff
=
BPULayerNorm
(
module
.
norm_ff
,
chunk_size
,
ln_run_on_bpu
)
self
.
norm_mha
=
BPULayerNorm
(
module
.
norm_mha
,
chunk_size
,
ln_run_on_bpu
)
self
.
norm_ff_macron
=
BPULayerNorm
(
module
.
norm_ff_macaron
,
chunk_size
,
ln_run_on_bpu
)
self
.
norm_conv
=
BPULayerNorm
(
module
.
norm_conv
,
chunk_size
,
ln_run_on_bpu
)
self
.
norm_final
=
BPULayerNorm
(
module
.
norm_final
,
chunk_size
,
ln_run_on_bpu
)
# 3. 4-D ff_scale
self
.
register_buffer
(
"ff_scale"
,
torch
.
full
((
1
,
self
.
size
,
1
,
1
),
module
.
ff_scale
))
self
.
check_equal
(
original
)
def
check_equal
(
self
,
module
):
time1
=
self
.
self_attn
.
chunk_size
time2
=
self
.
self_attn
.
time
h
,
d_k
=
self
.
self_attn
.
h
,
self
.
self_attn
.
d_k
random_x
=
torch
.
randn
(
1
,
time1
,
self
.
size
)
att_mask
=
torch
.
ones
(
1
,
h
,
time1
,
time2
)
att_cache
=
torch
.
zeros
(
1
,
h
,
time2
-
time1
,
d_k
*
2
)
cnn_cache
=
torch
.
zeros
(
1
,
self
.
size
,
self
.
conv_module
.
lorder
)
original_x
,
_
,
original_att_cache
,
original_cnn_cache
=
module
(
random_x
,
att_mask
[:,
0
,
:,
:],
torch
.
empty
(
0
),
att_cache
=
att_cache
,
cnn_cache
=
cnn_cache
)
random_x
=
random_x
.
transpose
(
1
,
2
).
unsqueeze
(
2
)
att_cache
=
att_cache
.
reshape
(
1
,
h
,
d_k
*
2
,
time2
-
time1
)
cnn_cache
=
cnn_cache
.
unsqueeze
(
2
)
new_x
,
new_att_cache
,
new_cnn_cache
=
self
.
forward
(
random_x
,
att_mask
,
att_cache
,
cnn_cache
)
np
.
testing
.
assert_allclose
(
to_numpy
(
original_att_cache
),
to_numpy
(
new_att_cache
.
transpose
(
2
,
3
)),
rtol
=
1e-02
,
atol
=
1e-03
)
np
.
testing
.
assert_allclose
(
to_numpy
(
original_x
),
to_numpy
(
new_x
.
squeeze
(
2
).
transpose
(
1
,
2
)),
rtol
=
1e-02
,
atol
=
1e-03
)
np
.
testing
.
assert_allclose
(
to_numpy
(
original_cnn_cache
),
to_numpy
(
new_cnn_cache
.
squeeze
(
2
)),
rtol
=
1e-02
,
atol
=
1e-03
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
att_mask
:
torch
.
Tensor
,
att_cache
:
torch
.
Tensor
,
cnn_cache
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Compute encoded features.
Args:
x (torch.Tensor): (#batch, size, 1, chunk_size)
att_mask (torch.Tensor): Mask tensor for the input
(#batch, head, chunk_size, cache_t1 + chunk_size),
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
(#batch=1, head, d_k * 2, cache_t1), head * d_k == size.
cnn_cache (torch.Tensor): Convolution cache in conformer layer
(#batch=1, size, 1, cache_t2)
Returns:
torch.Tensor: Output tensor (#batch, size, 1, chunk_size).
torch.Tensor: att_cache tensor,
(1, head, d_k * 2, cache_t1 + chunk_size).
torch.Tensor: cnn_cahce tensor (#batch, size, 1, cache_t2).
"""
# 1. ffn_macaron
residual
=
x
x
=
self
.
norm_ff_macron
(
x
)
x
=
residual
+
self
.
ff_scale
*
self
.
feed_forward_macaron
(
x
)
# 2. attention
residual
=
x
x
=
self
.
norm_mha
(
x
)
x_att
,
new_att_cache
=
self
.
self_attn
(
x
,
x
,
x
,
att_mask
,
att_cache
)
x
=
residual
+
x_att
# 3. convolution
residual
=
x
x
=
self
.
norm_conv
(
x
)
x
,
new_cnn_cache
=
self
.
conv_module
(
x
,
cnn_cache
)
x
=
residual
+
x
# 4. ffn
residual
=
x
x
=
self
.
norm_ff
(
x
)
x
=
residual
+
self
.
ff_scale
*
self
.
feed_forward
(
x
)
# 5. final post-norm
x
=
self
.
norm_final
(
x
)
return
x
,
new_att_cache
,
new_cnn_cache
class
BPUConformerEncoder
(
torch
.
nn
.
Module
):
"""Refactor wenet/transformer/encoder.py::ConformerEncoder
"""
def
__init__
(
self
,
module
,
chunk_size
,
left_chunks
,
ln_run_on_bpu
=
False
):
super
().
__init__
()
# Unchanged submodules and attributes
original
=
copy
.
deepcopy
(
module
)
output_size
=
module
.
output_size
()
self
.
_output_size
=
module
.
output_size
()
self
.
after_norm
=
module
.
after_norm
self
.
chunk_size
=
chunk_size
self
.
left_chunks
=
left_chunks
self
.
head
=
module
.
encoders
[
0
].
self_attn
.
h
self
.
layers
=
len
(
module
.
encoders
)
# 1. Modify submodules
self
.
global_cmvn
=
BPUGlobalCMVN
(
module
.
global_cmvn
)
self
.
embed
=
BPUConv2dSubsampling8
(
module
.
embed
)
self
.
encoders
=
torch
.
nn
.
ModuleList
()
for
layer
in
module
.
encoders
:
self
.
encoders
.
append
(
BPUConformerEncoderLayer
(
layer
,
chunk_size
,
left_chunks
,
ln_run_on_bpu
))
# 2. Auxiliary conv
self
.
identity_cnncache
=
BPUIdentity
(
output_size
)
self
.
check_equal
(
original
)
def
check_equal
(
self
,
module
):
time1
=
self
.
encoders
[
0
].
self_attn
.
chunk_size
time2
=
self
.
encoders
[
0
].
self_attn
.
time
layers
=
self
.
layers
h
,
d_k
=
self
.
head
,
self
.
encoders
[
0
].
self_attn
.
d_k
decoding_window
=
(
self
.
chunk_size
-
1
)
*
\
module
.
embed
.
subsampling_rate
+
\
module
.
embed
.
right_context
+
1
lorder
=
self
.
encoders
[
0
].
conv_module
.
lorder
random_x
=
torch
.
randn
(
1
,
decoding_window
,
80
)
att_mask
=
torch
.
ones
(
1
,
h
,
time1
,
time2
)
att_cache
=
torch
.
zeros
(
layers
,
h
,
time2
-
time1
,
d_k
*
2
)
cnn_cache
=
torch
.
zeros
(
layers
,
1
,
self
.
_output_size
,
lorder
)
orig_x
,
orig_att_cache
,
orig_cnn_cache
=
module
.
forward_chunk
(
random_x
,
0
,
time2
-
time1
,
att_mask
=
att_mask
[:,
0
,
:,
:],
att_cache
=
att_cache
,
cnn_cache
=
cnn_cache
)
random_x
=
random_x
.
unsqueeze
(
0
)
att_cache
=
att_cache
.
reshape
(
1
,
h
*
layers
,
d_k
*
2
,
time2
-
time1
)
cnn_cache
=
cnn_cache
.
reshape
(
1
,
self
.
_output_size
,
layers
,
lorder
)
new_x
,
new_att_cache
,
new_cnn_cache
=
self
.
forward
(
random_x
,
att_cache
,
cnn_cache
,
att_mask
)
caches
=
torch
.
split
(
new_att_cache
,
h
,
dim
=
1
)
caches
=
[
c
.
transpose
(
2
,
3
)
for
c
in
caches
]
np
.
testing
.
assert_allclose
(
to_numpy
(
orig_att_cache
),
to_numpy
(
torch
.
cat
(
caches
,
dim
=
0
)),
rtol
=
1e-02
,
atol
=
1e-03
)
np
.
testing
.
assert_allclose
(
to_numpy
(
orig_x
),
to_numpy
(
new_x
.
squeeze
(
2
).
transpose
(
1
,
2
)),
rtol
=
1e-02
,
atol
=
1e-03
)
np
.
testing
.
assert_allclose
(
to_numpy
(
orig_cnn_cache
),
to_numpy
(
new_cnn_cache
.
transpose
(
0
,
2
).
transpose
(
1
,
2
)),
rtol
=
1e-02
,
atol
=
1e-03
)
def
forward
(
self
,
xs
:
torch
.
Tensor
,
att_cache
:
torch
.
Tensor
,
cnn_cache
:
torch
.
Tensor
,
att_mask
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
""" Forward just one chunk
Args:
xs (torch.Tensor): chunk input, with shape (b=1, 1, time, mel-dim),
where `time == (chunk_size - 1) * subsample_rate +
\
subsample.right_context + 1`
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
transformer/conformer attention, with shape
(1, head * elayers, d_k * 2, cache_t1), where
`head * d_k == hidden-dim` and
`cache_t1 == chunk_size * left_chunks`.
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
(1, hidden-dim, elayers, cache_t2), where
`cache_t2 == cnn.lorder - 1`
att_mask (torch.Tensor): Mask tensor for the input
(#batch, head, chunk_size, cache_t1 + chunk_size),
Returns:
torch.Tensor: output of current input xs,
with shape (b=1, hidden-dim, 1, chunk_size).
torch.Tensor: new attention cache required for next chunk, with
same shape as the original att_cache.
torch.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
# xs: (B, 1, time, mel_dim) -> (B, 1, mel_dim, time)
xs
=
xs
.
transpose
(
2
,
3
)
xs
=
self
.
global_cmvn
(
xs
)
# xs: (B, 1, mel_dim, time) -> (B, hidden_dim, 1, chunk_size)
xs
=
self
.
embed
(
xs
)
att_cache
=
torch
.
split
(
att_cache
,
self
.
head
,
dim
=
1
)
cnn_cache
=
self
.
identity_cnncache
(
cnn_cache
)
cnn_cache
=
torch
.
split
(
cnn_cache
,
1
,
dim
=
2
)
r_att_cache
=
[]
r_cnn_cache
=
[]
for
i
,
layer
in
enumerate
(
self
.
encoders
):
xs
,
new_att_cache
,
new_cnn_cache
=
layer
(
xs
,
att_mask
,
att_cache
=
att_cache
[
i
],
cnn_cache
=
cnn_cache
[
i
])
r_att_cache
.
append
(
new_att_cache
[:,
:,
:,
self
.
chunk_size
:])
r_cnn_cache
.
append
(
new_cnn_cache
)
r_att_cache
=
torch
.
cat
(
r_att_cache
,
dim
=
1
)
r_cnn_cache
=
self
.
identity_cnncache
(
torch
.
cat
(
r_cnn_cache
,
dim
=
2
))
xs
=
xs
.
squeeze
(
2
).
transpose
(
1
,
2
).
contiguous
()
xs
=
self
.
after_norm
(
xs
)
# NOTE(xcsong): 4D in, 4D out to meet the requirment of CTC input.
xs
=
xs
.
transpose
(
1
,
2
).
contiguous
().
unsqueeze
(
2
)
# (B, C, 1, T)
return
(
xs
,
r_att_cache
,
r_cnn_cache
)
class
BPUCTC
(
torch
.
nn
.
Module
):
"""Refactor wenet/transformer/ctc.py::CTC
"""
def
__init__
(
self
,
module
):
super
().
__init__
()
# Unchanged submodules and attributes
original
=
copy
.
deepcopy
(
module
)
self
.
idim
=
module
.
ctc_lo
.
weight
.
size
(
1
)
num_class
=
module
.
ctc_lo
.
weight
.
size
(
0
)
# 1. Modify self.ctc_lo, Split final projection to meet the
# requirment of maximum in/out channels (2048 for XJ3)
self
.
ctc_lo
=
torch
.
nn
.
ModuleList
()
self
.
split_size
=
[]
num_split
=
(
num_class
-
1
)
//
2048
+
1
for
idx
in
range
(
num_split
):
out_channel
=
min
(
num_class
,
(
idx
+
1
)
*
2048
)
-
idx
*
2048
conv_ele
=
torch
.
nn
.
Conv2d
(
self
.
idim
,
out_channel
,
1
,
1
)
self
.
ctc_lo
.
append
(
conv_ele
)
self
.
split_size
.
append
(
out_channel
)
orig_weight
=
torch
.
split
(
module
.
ctc_lo
.
weight
,
self
.
split_size
,
dim
=
0
)
orig_bias
=
torch
.
split
(
module
.
ctc_lo
.
bias
,
self
.
split_size
,
dim
=
0
)
for
i
,
(
w
,
b
)
in
enumerate
(
zip
(
orig_weight
,
orig_bias
)):
w
=
w
.
unsqueeze
(
2
).
unsqueeze
(
3
)
self
.
ctc_lo
[
i
].
weight
=
torch
.
nn
.
Parameter
(
w
)
self
.
ctc_lo
[
i
].
bias
=
torch
.
nn
.
Parameter
(
b
)
self
.
check_equal
(
original
)
def
check_equal
(
self
,
module
):
random_data
=
torch
.
randn
(
1
,
100
,
self
.
idim
)
original_result
=
module
.
ctc_lo
(
random_data
)
random_data
=
random_data
.
transpose
(
1
,
2
).
unsqueeze
(
2
)
new_result
=
self
.
forward
(
random_data
)
np
.
testing
.
assert_allclose
(
to_numpy
(
original_result
),
to_numpy
(
new_result
.
squeeze
(
2
).
transpose
(
1
,
2
)),
rtol
=
1e-02
,
atol
=
1e-03
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""frame activations, without softmax.
Args:
Tensor x: 4d tensor (B, hidden_dim, 1, chunk_size)
Returns:
torch.Tensor: (B, num_class, 1, chunk_size)
"""
out
=
[]
for
i
,
layer
in
enumerate
(
self
.
ctc_lo
):
out
.
append
(
layer
(
x
))
out
=
torch
.
cat
(
out
,
dim
=
1
)
return
out
def
export_encoder
(
asr_model
,
args
):
logger
.
info
(
"Stage-1: export encoder"
)
decode_window
,
mel_dim
=
args
.
decoding_window
,
args
.
feature_size
encoder
=
BPUConformerEncoder
(
asr_model
.
encoder
,
args
.
chunk_size
,
args
.
num_decoding_left_chunks
,
args
.
ln_run_on_bpu
)
encoder
.
eval
()
encoder_outpath
=
os
.
path
.
join
(
args
.
output_dir
,
'encoder.onnx'
)
logger
.
info
(
"Stage-1.1: prepare inputs for encoder"
)
chunk
=
torch
.
randn
((
1
,
1
,
decode_window
,
mel_dim
))
required_cache_size
=
encoder
.
chunk_size
*
encoder
.
left_chunks
kv_time
=
required_cache_size
+
encoder
.
chunk_size
hidden
,
layers
=
encoder
.
_output_size
,
len
(
encoder
.
encoders
)
head
=
encoder
.
encoders
[
0
].
self_attn
.
h
d_k
=
hidden
//
head
lorder
=
encoder
.
encoders
[
0
].
conv_module
.
lorder
att_cache
=
torch
.
zeros
(
1
,
layers
*
head
,
d_k
*
2
,
required_cache_size
)
att_mask
=
torch
.
ones
((
1
,
head
,
encoder
.
chunk_size
,
kv_time
))
att_mask
[:,
:,
:,
:
required_cache_size
]
=
0
cnn_cache
=
torch
.
zeros
((
1
,
hidden
,
layers
,
lorder
))
inputs
=
(
chunk
,
att_cache
,
cnn_cache
,
att_mask
)
logger
.
info
(
"chunk.size(): {} att_cache.size(): {} "
"cnn_cache.size(): {} att_mask.size(): {}"
.
format
(
list
(
chunk
.
size
()),
list
(
att_cache
.
size
()),
list
(
cnn_cache
.
size
()),
list
(
att_mask
.
size
())))
logger
.
info
(
"Stage-1.2: torch.onnx.export"
)
# NOTE(xcsong): Below attributes will be used in
# onnx2horizonbin.py::generate_config()
attributes
=
{}
attributes
[
'input_name'
]
=
"chunk;att_cache;cnn_cache;att_mask"
attributes
[
'output_name'
]
=
"output;r_att_cache;r_cnn_cache"
attributes
[
'input_type'
]
=
"featuremap;featuremap;featuremap;featuremap"
attributes
[
'norm_type'
]
=
\
"no_preprocess;no_preprocess;no_preprocess;no_preprocess"
attributes
[
'input_layout_train'
]
=
"NCHW;NCHW;NCHW;NCHW"
attributes
[
'input_layout_rt'
]
=
"NCHW;NCHW;NCHW;NCHW"
attributes
[
'input_shape'
]
=
\
"{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{};{}x{}x{}x{}"
.
format
(
chunk
.
size
(
0
),
chunk
.
size
(
1
),
chunk
.
size
(
2
),
chunk
.
size
(
3
),
att_cache
.
size
(
0
),
att_cache
.
size
(
1
),
att_cache
.
size
(
2
),
att_cache
.
size
(
3
),
cnn_cache
.
size
(
0
),
cnn_cache
.
size
(
1
),
cnn_cache
.
size
(
2
),
cnn_cache
.
size
(
3
),
att_mask
.
size
(
0
),
att_mask
.
size
(
1
),
att_mask
.
size
(
2
),
att_mask
.
size
(
3
)
)
torch
.
onnx
.
export
(
# NOTE(xcsong): only support opset==11
encoder
,
inputs
,
encoder_outpath
,
opset_version
=
11
,
export_params
=
True
,
do_constant_folding
=
True
,
input_names
=
attributes
[
'input_name'
].
split
(
';'
),
output_names
=
attributes
[
'output_name'
].
split
(
';'
),
dynamic_axes
=
None
,
verbose
=
False
)
onnx_encoder
=
onnx
.
load
(
encoder_outpath
)
for
k
in
vars
(
args
):
meta
=
onnx_encoder
.
metadata_props
.
add
()
meta
.
key
,
meta
.
value
=
str
(
k
),
str
(
getattr
(
args
,
k
))
for
k
in
attributes
:
meta
=
onnx_encoder
.
metadata_props
.
add
()
meta
.
key
,
meta
.
value
=
str
(
k
),
str
(
attributes
[
k
])
onnx
.
checker
.
check_model
(
onnx_encoder
)
onnx
.
helper
.
printable_graph
(
onnx_encoder
.
graph
)
onnx
.
save
(
onnx_encoder
,
encoder_outpath
)
print_input_output_info
(
onnx_encoder
,
"onnx_encoder"
)
logger
.
info
(
'Export onnx_encoder, done! see {}'
.
format
(
encoder_outpath
))
logger
.
info
(
"Stage-1.3: check onnx_encoder and torch_encoder"
)
torch_output
=
[]
torch_chunk
,
torch_att_mask
=
copy
.
deepcopy
(
chunk
),
copy
.
deepcopy
(
att_mask
)
torch_att_cache
=
copy
.
deepcopy
(
att_cache
)
torch_cnn_cache
=
copy
.
deepcopy
(
cnn_cache
)
for
i
in
range
(
10
):
logger
.
info
(
"torch chunk-{}: {}, att_cache: {}, cnn_cache: {}"
", att_mask: {}"
.
format
(
i
,
list
(
torch_chunk
.
size
()),
list
(
torch_att_cache
.
size
()),
list
(
torch_cnn_cache
.
size
()),
list
(
torch_att_mask
.
size
())))
torch_att_mask
[:,
:,
:,
-
(
encoder
.
chunk_size
*
(
i
+
1
)):]
=
1
out
,
torch_att_cache
,
torch_cnn_cache
=
encoder
(
torch_chunk
,
torch_att_cache
,
torch_cnn_cache
,
torch_att_mask
)
torch_output
.
append
(
out
)
torch_output
=
torch
.
cat
(
torch_output
,
dim
=-
1
)
onnx_output
=
[]
onnx_chunk
,
onnx_att_mask
=
to_numpy
(
chunk
),
to_numpy
(
att_mask
)
onnx_att_cache
=
to_numpy
(
att_cache
)
onnx_cnn_cache
=
to_numpy
(
cnn_cache
)
ort_session
=
onnxruntime
.
InferenceSession
(
encoder_outpath
)
input_names
=
[
node
.
name
for
node
in
onnx_encoder
.
graph
.
input
]
for
i
in
range
(
10
):
logger
.
info
(
"onnx chunk-{}: {}, att_cache: {}, cnn_cache: {},"
" att_mask: {}"
.
format
(
i
,
onnx_chunk
.
shape
,
onnx_att_cache
.
shape
,
onnx_cnn_cache
.
shape
,
onnx_att_mask
.
shape
))
onnx_att_mask
[:,
:,
:,
-
(
encoder
.
chunk_size
*
(
i
+
1
)):]
=
1
ort_inputs
=
{
'chunk'
:
onnx_chunk
,
'att_cache'
:
onnx_att_cache
,
'cnn_cache'
:
onnx_cnn_cache
,
'att_mask'
:
onnx_att_mask
,
}
ort_outs
=
ort_session
.
run
(
None
,
ort_inputs
)
onnx_att_cache
,
onnx_cnn_cache
=
ort_outs
[
1
],
ort_outs
[
2
]
onnx_output
.
append
(
ort_outs
[
0
])
onnx_output
=
np
.
concatenate
(
onnx_output
,
axis
=-
1
)
np
.
testing
.
assert_allclose
(
to_numpy
(
torch_output
),
onnx_output
,
rtol
=
1e-03
,
atol
=
1e-04
)
meta
=
ort_session
.
get_modelmeta
()
logger
.
info
(
"custom_metadata_map={}"
.
format
(
meta
.
custom_metadata_map
))
logger
.
info
(
"Check onnx_encoder, pass!"
)
return
encoder
,
ort_session
def export_ctc(asr_model, args):
    logger.info("Stage-2: export ctc")
    ctc = BPUCTC(asr_model.ctc).eval()
    ctc_outpath = os.path.join(args.output_dir, 'ctc.onnx')

    logger.info("Stage-2.1: prepare inputs for ctc")
    hidden = torch.randn((1, args.output_size, 1, args.chunk_size))

    logger.info("Stage-2.2: torch.onnx.export")
    # NOTE(xcsong): Below attributes will be used in
    #   onnx2horizonbin.py::generate_config()
    attributes = {}
    attributes['input_name'], attributes['input_type'] = "hidden", "featuremap"
    attributes['norm_type'] = "no_preprocess"
    attributes['input_layout_train'] = "NCHW"
    attributes['input_layout_rt'] = "NCHW"
    attributes['input_shape'] = "{}x{}x{}x{}".format(
        hidden.size(0), hidden.size(1), hidden.size(2), hidden.size(3),
    )
    torch.onnx.export(
        ctc, hidden, ctc_outpath, opset_version=11,
        export_params=True, do_constant_folding=True,
        input_names=['hidden'], output_names=['probs'],
        dynamic_axes=None, verbose=False)
    onnx_ctc = onnx.load(ctc_outpath)
    for k in vars(args):
        meta = onnx_ctc.metadata_props.add()
        meta.key, meta.value = str(k), str(getattr(args, k))
    for k in attributes:
        meta = onnx_ctc.metadata_props.add()
        meta.key, meta.value = str(k), str(attributes[k])
    onnx.checker.check_model(onnx_ctc)
    onnx.helper.printable_graph(onnx_ctc.graph)
    onnx.save(onnx_ctc, ctc_outpath)
    print_input_output_info(onnx_ctc, "onnx_ctc")
    logger.info('Export onnx_ctc, done! see {}'.format(ctc_outpath))

    logger.info("Stage-2.3: check onnx_ctc and torch_ctc")
    torch_output = ctc(hidden)
    ort_session = onnxruntime.InferenceSession(ctc_outpath)
    onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)})

    np.testing.assert_allclose(to_numpy(torch_output), onnx_output[0],
                               rtol=1e-03, atol=1e-04)
    meta = ort_session.get_modelmeta()
    logger.info("custom_metadata_map={}".format(meta.custom_metadata_map))
    logger.info("Check onnx_ctc, pass!")
    return ctc, ort_session


def export_decoder(asr_model, args):
    logger.info("Currently, Decoder is not supported.")
if __name__ == '__main__':
    torch.manual_seed(777)
    args = get_args()
    args.ln_run_on_bpu = False
    # NOTE(xcsong): XJ3 BPU only support static shapes
    assert args.chunk_size > 0
    assert args.num_decoding_left_chunks > 0
    os.system("mkdir -p " + args.output_dir)
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    model = init_model(configs)
    load_checkpoint(model, args.checkpoint)
    model.eval()
    print(model)

    args.feature_size = configs['input_dim']
    args.output_size = model.encoder.output_size()
    args.decoding_window = (args.chunk_size - 1) * \
        model.encoder.embed.subsampling_rate + \
        model.encoder.embed.right_context + 1

    export_encoder(model, args)
    export_ctc(model, args)
    export_decoder(model, args)
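
# NOTE: a minimal usage sketch, with hypothetical paths; the flag names are assumed
# to mirror the CPU exporter below (--config/--checkpoint/--output_dir/--chunk_size/
# --num_decoding_left_chunks), since this main block reads exactly those attributes
# from get_args():
#   python wenet/bin/export_onnx_bpu.py --config train.yaml --checkpoint final.pt \
#       --output_dir onnx_bpu --chunk_size 16 --num_decoding_left_chunks 4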
examples/aishell/s0/wenet/bin/export_onnx_cpu.py
0 → 100644
# Copyright (c) 2022, Xingchen Song (sxc19@mails.tsinghua.edu.cn)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import argparse
import os
import copy
import sys

import torch
import yaml
import numpy as np

from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model

try:
    import onnx
    import onnxruntime
    from onnxruntime.quantization import quantize_dynamic, QuantType
except ImportError:
    print('Please install onnx and onnxruntime!')
    sys.exit(1)


def get_args():
    parser = argparse.ArgumentParser(description='export your script model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--output_dir', required=True, help='output directory')
    parser.add_argument('--chunk_size', required=True, type=int,
                        help='decoding chunk size')
    parser.add_argument('--num_decoding_left_chunks', required=True, type=int,
                        help='cache chunks')
    parser.add_argument('--reverse_weight', default=0.5, type=float,
                        help='reverse_weight in attention_rescoring')
    args = parser.parse_args()
    return args


def to_numpy(tensor):
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    else:
        return tensor.cpu().numpy()
def print_input_output_info(onnx_model, name, prefix="\t\t"):
    input_names = [node.name for node in onnx_model.graph.input]
    input_shapes = [[d.dim_value for d in node.type.tensor_type.shape.dim]
                    for node in onnx_model.graph.input]
    output_names = [node.name for node in onnx_model.graph.output]
    output_shapes = [[d.dim_value for d in node.type.tensor_type.shape.dim]
                     for node in onnx_model.graph.output]
    print("{}{} inputs : {}".format(prefix, name, input_names))
    print("{}{} input shapes : {}".format(prefix, name, input_shapes))
    print("{}{} outputs: {}".format(prefix, name, output_names))
    print("{}{} output shapes : {}".format(prefix, name, output_shapes))
def export_encoder(asr_model, args):
    print("Stage-1: export encoder")
    encoder = asr_model.encoder
    encoder.forward = encoder.forward_chunk
    encoder_outpath = os.path.join(args['output_dir'], 'encoder.onnx')

    print("\tStage-1.1: prepare inputs for encoder")
    chunk = torch.randn(
        (args['batch'], args['decoding_window'], args['feature_size']))
    offset = 0
    # NOTE(xcsong): The uncertainty of `next_cache_start` only appears
    #   in the first few chunks, this is caused by dynamic att_cache shape, i,e
    #   (0, 0, 0, 0) for 1st chunk and (elayers, head, ?, d_k*2) for subsequent
    #   chunks. One way to ease the ONNX export is to keep `next_cache_start`
    #   as a fixed value. To do this, for the **first** chunk, if
    #   left_chunks > 0, we feed real cache & real mask to the model, otherwise
    #   fake cache & fake mask. In this way, we get:
    #   1. 16/-1 mode: next_cache_start == 0 for all chunks
    #   2. 16/4 mode: next_cache_start == chunk_size for all chunks
    #   3. 16/0 mode: next_cache_start == chunk_size for all chunks
    #   4. -1/-1 mode: next_cache_start == 0 for all chunks
    #   NO MORE DYNAMIC CHANGES!!
    #
    # NOTE(Mddct): We retain the current design for the convenience of supporting some
    #   inference frameworks without dynamic shapes. If you're interested in all-in-one
    #   model that supports different chunks please see:
    #   https://github.com/wenet-e2e/wenet/pull/1174
    if args['left_chunks'] > 0:  # 16/4
        required_cache_size = args['chunk_size'] * args['left_chunks']
        offset = required_cache_size
        # Real cache
        att_cache = torch.zeros(
            (args['num_blocks'], args['head'], required_cache_size,
             args['output_size'] // args['head'] * 2))
        # Real mask
        att_mask = torch.ones(
            (args['batch'], 1, required_cache_size + args['chunk_size']),
            dtype=torch.bool)
        att_mask[:, :, :required_cache_size] = 0
    elif args['left_chunks'] <= 0:  # 16/-1, -1/-1, 16/0
        required_cache_size = -1 if args['left_chunks'] < 0 else 0
        # Fake cache
        att_cache = torch.zeros(
            (args['num_blocks'], args['head'], 0,
             args['output_size'] // args['head'] * 2))
        # Fake mask
        att_mask = torch.ones((0, 0, 0), dtype=torch.bool)
    cnn_cache = torch.zeros(
        (args['num_blocks'], args['batch'],
         args['output_size'], args['cnn_module_kernel'] - 1))
    inputs = (chunk, offset, required_cache_size,
              att_cache, cnn_cache, att_mask)
    print("\t\tchunk.size(): {}\n".format(chunk.size()),
          "\t\toffset: {}\n".format(offset),
          "\t\trequired_cache: {}\n".format(required_cache_size),
          "\t\tatt_cache.size(): {}\n".format(att_cache.size()),
          "\t\tcnn_cache.size(): {}\n".format(cnn_cache.size()),
          "\t\tatt_mask.size(): {}\n".format(att_mask.size()))

    print("\tStage-1.2: torch.onnx.export")
    dynamic_axes = {
        'chunk': {1: 'T'},
        'att_cache': {2: 'T_CACHE'},
        'att_mask': {2: 'T_ADD_T_CACHE'},
        'output': {1: 'T'},
        'r_att_cache': {2: 'T_CACHE'},
    }
    # NOTE(xcsong): We keep dynamic axes even if in 16/4 mode, this is
    #   to avoid padding the last chunk (which usually contains less
    #   frames than required). For users who want static axes, just pop
    #   out specific axis.
    # if args['chunk_size'] > 0:  # 16/4, 16/-1, 16/0
    #     dynamic_axes.pop('chunk')
    #     dynamic_axes.pop('output')
    # if args['left_chunks'] >= 0:  # 16/4, 16/0
    #     # NOTE(xsong): since we feed real cache & real mask into the
    #     #   model when left_chunks > 0, the shape of cache will never
    #     #   be changed.
    #     dynamic_axes.pop('att_cache')
    #     dynamic_axes.pop('r_att_cache')
    torch.onnx.export(
        encoder, inputs, encoder_outpath, opset_version=13,
        export_params=True, do_constant_folding=True,
        input_names=['chunk', 'offset', 'required_cache_size',
                     'att_cache', 'cnn_cache', 'att_mask'],
        output_names=['output', 'r_att_cache', 'r_cnn_cache'],
        dynamic_axes=dynamic_axes, verbose=False)
    onnx_encoder = onnx.load(encoder_outpath)
    for (k, v) in args.items():
        meta = onnx_encoder.metadata_props.add()
        meta.key, meta.value = str(k), str(v)
    onnx.checker.check_model(onnx_encoder)
    onnx.helper.printable_graph(onnx_encoder.graph)
    # NOTE(xcsong): to add those metadatas we need to reopen
    #   the file and resave it.
    onnx.save(onnx_encoder, encoder_outpath)
    print_input_output_info(onnx_encoder, "onnx_encoder")
    # Dynamic quantization
    model_fp32 = encoder_outpath
    model_quant = os.path.join(args['output_dir'], 'encoder.quant.onnx')
    quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
    print('\t\tExport onnx_encoder, done! see {}'.format(encoder_outpath))

    print("\tStage-1.3: check onnx_encoder and torch_encoder")
    torch_output = []
    torch_chunk = copy.deepcopy(chunk)
    torch_offset = copy.deepcopy(offset)
    torch_required_cache_size = copy.deepcopy(required_cache_size)
    torch_att_cache = copy.deepcopy(att_cache)
    torch_cnn_cache = copy.deepcopy(cnn_cache)
    torch_att_mask = copy.deepcopy(att_mask)
    for i in range(10):
        print("\t\ttorch chunk-{}: {}, offset: {}, att_cache: {},"
              " cnn_cache: {}, att_mask: {}".format(
                  i, list(torch_chunk.size()), torch_offset,
                  list(torch_att_cache.size()),
                  list(torch_cnn_cache.size()),
                  list(torch_att_mask.size())))
        # NOTE(xsong): att_mask of the first few batches need changes if
        #   we use 16/4 mode.
        if args['left_chunks'] > 0:  # 16/4
            torch_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1
        out, torch_att_cache, torch_cnn_cache = encoder(
            torch_chunk, torch_offset, torch_required_cache_size,
            torch_att_cache, torch_cnn_cache, torch_att_mask)
        torch_output.append(out)
        torch_offset += out.size(1)
    torch_output = torch.cat(torch_output, dim=1)

    onnx_output = []
    onnx_chunk = to_numpy(chunk)
    onnx_offset = np.array((offset)).astype(np.int64)
    onnx_required_cache_size = np.array((required_cache_size)).astype(np.int64)
    onnx_att_cache = to_numpy(att_cache)
    onnx_cnn_cache = to_numpy(cnn_cache)
    onnx_att_mask = to_numpy(att_mask)
    ort_session = onnxruntime.InferenceSession(encoder_outpath)
    input_names = [node.name for node in onnx_encoder.graph.input]
    for i in range(10):
        print("\t\tonnx chunk-{}: {}, offset: {}, att_cache: {},"
              " cnn_cache: {}, att_mask: {}".format(
                  i, onnx_chunk.shape, onnx_offset, onnx_att_cache.shape,
                  onnx_cnn_cache.shape, onnx_att_mask.shape))
        # NOTE(xsong): att_mask of the first few batches need changes if
        #   we use 16/4 mode.
        if args['left_chunks'] > 0:  # 16/4
            onnx_att_mask[:, :, -(args['chunk_size'] * (i + 1)):] = 1
        ort_inputs = {
            'chunk': onnx_chunk, 'offset': onnx_offset,
            'required_cache_size': onnx_required_cache_size,
            'att_cache': onnx_att_cache,
            'cnn_cache': onnx_cnn_cache, 'att_mask': onnx_att_mask
        }
        # NOTE(xcsong): If we use 16/-1, -1/-1 or 16/0 mode, `next_cache_start`
        #   will be hardcoded to 0 or chunk_size by ONNX, thus
        #   required_cache_size and att_mask are no more needed and they will
        #   be removed by ONNX automatically.
        for k in list(ort_inputs):
            if k not in input_names:
                ort_inputs.pop(k)
        ort_outs = ort_session.run(None, ort_inputs)
        onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2]
        onnx_output.append(ort_outs[0])
        onnx_offset += ort_outs[0].shape[1]
    onnx_output = np.concatenate(onnx_output, axis=1)

    np.testing.assert_allclose(to_numpy(torch_output), onnx_output,
                               rtol=1e-03, atol=1e-05)
    meta = ort_session.get_modelmeta()
    print("\t\tcustom_metadata_map={}".format(meta.custom_metadata_map))
    print("\t\tCheck onnx_encoder, pass!")
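
# NOTE: a minimal sketch (not wired into the export flow above) showing how the
# dynamically quantized encoder written by export_encoder() could be loaded for a
# quick smoke test; `output_dir` and `ort_inputs` are assumed to be built the same
# way as in the fp32 check above.
def _smoke_test_quant_encoder(output_dir, ort_inputs):
    quant_path = os.path.join(output_dir, 'encoder.quant.onnx')
    session = onnxruntime.InferenceSession(quant_path)
    # Quantization changes numerics, so only output shapes are inspected here.
    return [o.shape for o in session.run(None, ort_inputs)]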
def export_ctc(asr_model, args):
    print("Stage-2: export ctc")
    ctc = asr_model.ctc
    ctc.forward = ctc.log_softmax
    ctc_outpath = os.path.join(args['output_dir'], 'ctc.onnx')

    print("\tStage-2.1: prepare inputs for ctc")
    hidden = torch.randn(
        (args['batch'], args['chunk_size'] if args['chunk_size'] > 0 else 16,
         args['output_size']))

    print("\tStage-2.2: torch.onnx.export")
    dynamic_axes = {'hidden': {1: 'T'}, 'probs': {1: 'T'}}
    torch.onnx.export(
        ctc, hidden, ctc_outpath, opset_version=13,
        export_params=True, do_constant_folding=True,
        input_names=['hidden'], output_names=['probs'],
        dynamic_axes=dynamic_axes, verbose=False)
    onnx_ctc = onnx.load(ctc_outpath)
    for (k, v) in args.items():
        meta = onnx_ctc.metadata_props.add()
        meta.key, meta.value = str(k), str(v)
    onnx.checker.check_model(onnx_ctc)
    onnx.helper.printable_graph(onnx_ctc.graph)
    onnx.save(onnx_ctc, ctc_outpath)
    print_input_output_info(onnx_ctc, "onnx_ctc")
    # Dynamic quantization
    model_fp32 = ctc_outpath
    model_quant = os.path.join(args['output_dir'], 'ctc.quant.onnx')
    quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
    print('\t\tExport onnx_ctc, done! see {}'.format(ctc_outpath))

    print("\tStage-2.3: check onnx_ctc and torch_ctc")
    torch_output = ctc(hidden)
    ort_session = onnxruntime.InferenceSession(ctc_outpath)
    onnx_output = ort_session.run(None, {'hidden': to_numpy(hidden)})

    np.testing.assert_allclose(to_numpy(torch_output), onnx_output[0],
                               rtol=1e-03, atol=1e-05)
    print("\t\tCheck onnx_ctc, pass!")
def export_decoder(asr_model, args):
    print("Stage-3: export decoder")
    decoder = asr_model
    # NOTE(lzhin): parameters of encoder will be automatically removed
    #   since they are not used during rescoring.
    decoder.forward = decoder.forward_attention_decoder
    decoder_outpath = os.path.join(args['output_dir'], 'decoder.onnx')

    print("\tStage-3.1: prepare inputs for decoder")
    # hardcode time->200 nbest->10 len->20, they are dynamic axes.
    encoder_out = torch.randn((1, 200, args['output_size']))
    hyps = torch.randint(low=0, high=args['vocab_size'], size=[10, 20])
    hyps[:, 0] = args['vocab_size'] - 1  # <sos>
    hyps_lens = torch.randint(low=15, high=21, size=[10])

    print("\tStage-3.2: torch.onnx.export")
    dynamic_axes = {
        'hyps': {0: 'NBEST', 1: 'L'}, 'hyps_lens': {0: 'NBEST'},
        'encoder_out': {1: 'T'},
        'score': {0: 'NBEST', 1: 'L'}, 'r_score': {0: 'NBEST', 1: 'L'}
    }
    inputs = (hyps, hyps_lens, encoder_out, args['reverse_weight'])
    torch.onnx.export(
        decoder, inputs, decoder_outpath, opset_version=13,
        export_params=True, do_constant_folding=True,
        input_names=['hyps', 'hyps_lens', 'encoder_out', 'reverse_weight'],
        output_names=['score', 'r_score'],
        dynamic_axes=dynamic_axes, verbose=False)
    onnx_decoder = onnx.load(decoder_outpath)
    for (k, v) in args.items():
        meta = onnx_decoder.metadata_props.add()
        meta.key, meta.value = str(k), str(v)
    onnx.checker.check_model(onnx_decoder)
    onnx.helper.printable_graph(onnx_decoder.graph)
    onnx.save(onnx_decoder, decoder_outpath)
    print_input_output_info(onnx_decoder, "onnx_decoder")
    model_fp32 = decoder_outpath
    model_quant = os.path.join(args['output_dir'], 'decoder.quant.onnx')
    quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
    print('\t\tExport onnx_decoder, done! see {}'.format(decoder_outpath))

    print("\tStage-3.3: check onnx_decoder and torch_decoder")
    torch_score, torch_r_score = decoder(
        hyps, hyps_lens, encoder_out, args['reverse_weight'])
    ort_session = onnxruntime.InferenceSession(decoder_outpath)
    input_names = [node.name for node in onnx_decoder.graph.input]
    ort_inputs = {
        'hyps': to_numpy(hyps),
        'hyps_lens': to_numpy(hyps_lens),
        'encoder_out': to_numpy(encoder_out),
        'reverse_weight': np.array((args['reverse_weight'])),
    }
    for k in list(ort_inputs):
        if k not in input_names:
            ort_inputs.pop(k)
    onnx_output = ort_session.run(None, ort_inputs)
    np.testing.assert_allclose(to_numpy(torch_score), onnx_output[0],
                               rtol=1e-03, atol=1e-05)
    if args['is_bidirectional_decoder'] and args['reverse_weight'] > 0.0:
        np.testing.assert_allclose(to_numpy(torch_r_score), onnx_output[1],
                                   rtol=1e-03, atol=1e-05)
    print("\t\tCheck onnx_decoder, pass!")
def main():
    torch.manual_seed(777)
    args = get_args()
    output_dir = args.output_dir
    os.system("mkdir -p " + output_dir)
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    model = init_model(configs)
    load_checkpoint(model, args.checkpoint)
    model.eval()
    print(model)

    arguments = {}
    arguments['output_dir'] = output_dir
    arguments['batch'] = 1
    arguments['chunk_size'] = args.chunk_size
    arguments['left_chunks'] = args.num_decoding_left_chunks
    arguments['reverse_weight'] = args.reverse_weight
    arguments['output_size'] = configs['encoder_conf']['output_size']
    arguments['num_blocks'] = configs['encoder_conf']['num_blocks']
    arguments['cnn_module_kernel'] = configs['encoder_conf'].get(
        'cnn_module_kernel', 1)
    arguments['head'] = configs['encoder_conf']['attention_heads']
    arguments['feature_size'] = configs['input_dim']
    arguments['vocab_size'] = configs['output_dim']
    # NOTE(xcsong): if chunk_size == -1, hardcode to 67
    arguments['decoding_window'] = (args.chunk_size - 1) * \
        model.encoder.embed.subsampling_rate + \
        model.encoder.embed.right_context + 1 if args.chunk_size > 0 else 67
    arguments['encoder'] = configs['encoder']
    arguments['decoder'] = configs['decoder']
    arguments['subsampling_rate'] = model.subsampling_rate()
    arguments['right_context'] = model.right_context()
    arguments['sos_symbol'] = model.sos_symbol()
    arguments['eos_symbol'] = model.eos_symbol()
    arguments['is_bidirectional_decoder'] = 1 \
        if model.is_bidirectional_decoder() else 0

    # NOTE(xcsong): Please note that -1/-1 means non-streaming model! It is
    #   not a [16/4 16/-1 16/0] all-in-one model and it should not be used in
    #   streaming mode (i.e., setting chunk_size=16 in `decoder_main`). If you
    #   want to use 16/-1 or any other streaming mode in `decoder_main`,
    #   please export onnx in the same config.
    if arguments['left_chunks'] > 0:
        assert arguments['chunk_size'] > 0  # -1/4 not supported

    export_encoder(model, arguments)
    export_ctc(model, arguments)
    export_decoder(model, arguments)


if __name__ == '__main__':
    main()
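
# NOTE: a minimal usage sketch (paths are placeholders; the flags come from
# get_args() above):
#   python wenet/bin/export_onnx_cpu.py --config train.yaml --checkpoint final.pt \
#       --output_dir onnx_dir --chunk_size 16 --num_decoding_left_chunks 4 \
#       --reverse_weight 0.5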
examples/aishell/s0/wenet/bin/export_onnx_gpu.py
0 → 100644
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import argparse
import os
import sys

import torch
import yaml
import logging

from wenet.utils.checkpoint import load_checkpoint
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import BaseEncoder
from wenet.utils.init_model import init_model
from wenet.utils.mask import make_pad_mask

try:
    import onnxruntime
except ImportError:
    print('Please install onnxruntime-gpu!')
    sys.exit(1)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)
class Encoder(torch.nn.Module):
    def __init__(self,
                 encoder: BaseEncoder,
                 ctc: CTC,
                 beam_size: int = 10):
        super().__init__()
        self.encoder = encoder
        self.ctc = ctc
        self.beam_size = beam_size

    def forward(self, speech: torch.Tensor, speech_lengths: torch.Tensor,):
        """Encoder
        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
        Returns:
            encoder_out: B x T x F
            encoder_out_lens: B
            ctc_log_probs: B x T x V
            beam_log_probs: B x T x beam_size
            beam_log_probs_idx: B x T x beam_size
        """
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths, -1, -1)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)
        ctc_log_probs = self.ctc.log_softmax(encoder_out)
        encoder_out_lens = encoder_out_lens.int()
        beam_log_probs, beam_log_probs_idx = torch.topk(
            ctc_log_probs, self.beam_size, dim=2)
        return encoder_out, encoder_out_lens, ctc_log_probs, \
            beam_log_probs, beam_log_probs_idx
class StreamingEncoder(torch.nn.Module):
    def __init__(self, model, required_cache_size, beam_size,
                 transformer=False):
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder
        self.transformer = transformer

    def forward(self, chunk_xs, chunk_lens, offset,
                att_cache, cnn_cache, cache_mask):
        """Streaming Encoder
        Args:
            xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            offset (torch.Tensor): offset with shape (b, 1)
                1 is retained for triton deployment
            required_cache_size (int): cache size required for next chunk
                computation
                > 0: actual cache size
                <= 0: not allowed in streaming gpu encoder
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (b, elayers, b, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            cache_mask (torch.Tensor): cache mask with shape (b, required_cache_size)
                in a batch of request, each request may have different
                history cache. Cache mask is used to indicate the effective
                cache for each request
        Returns:
            torch.Tensor: log probabilities of ctc output and cutoff by beam size
                with shape (b, chunk_size, beam)
            torch.Tensor: index of top beam size probabilities for each timestep
                with shape (b, chunk_size, beam)
            torch.Tensor: output of current input xs,
                with shape (b, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                same shape (b, elayers, head, cache_t1, d_k * 2)
                as the original att_cache
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
            torch.Tensor: new cache mask, with same shape as the original
                cache mask
        """
        offset = offset.squeeze(1)
        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)  # B X 1 X T
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # chunk mask is important for batch inferencing since
        # different sequence in a batch has different length
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        cache_size = att_cache.size(3)  # required cache size
        masks = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size

        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        next_cache_start = -self.required_cache_size
        r_cache_mask = masks[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        for i, layer in enumerate(self.encoder.encoders):
            xs, _, new_att_cache, new_cnn_cache = layer(
                xs, masks, pos_emb,
                att_cache=att_cache[i],
                cnn_cache=cnn_cache[i])
            # shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
            # shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
            r_att_cache.append(
                new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
            if not self.transformer:
                r_cnn_cache.append(new_cnn_cache.unsqueeze(1))
        if self.encoder.normalize_before:
            chunk_out = self.encoder.after_norm(xs)
        else:
            chunk_out = xs

        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        if not self.transformer:
            r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers
        # <---------forward_chunk END--------->

        log_ctc_probs = self.ctc.log_softmax(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs,
                                              self.beam_size, dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        r_offset = offset + chunk_out.shape[1]
        # the below ops not supported in Tensorrt
        # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
        #                            rounding_mode='floor')
        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)

        return log_probs, log_probs_idx, chunk_out, chunk_out_lens, \
            r_offset, r_att_cache, r_cnn_cache, r_cache_mask
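
# NOTE: a small illustration, assuming a hypothetical conformer config rather than
# anything in this repository. With 12 blocks, 4 attention heads, output_size 256,
# chunk_size 16 and 5 left chunks (so required_cache_size = 80), the streaming
# caches fed to StreamingEncoder.forward would follow the docstring shapes:
#   att_cache:  (B, 12, 4, 80, 2 * (256 // 4)) == (B, 12, 4, 80, 128)
#   cnn_cache:  (B, 12, 256, cnn_module_kernel - 1)
#   cache_mask: (B, 1, 80)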
class StreamingSqueezeformerEncoder(torch.nn.Module):
    def __init__(self, model, required_cache_size, beam_size):
        super().__init__()
        self.ctc = model.ctc
        self.subsampling_rate = model.encoder.embed.subsampling_rate
        self.embed = model.encoder.embed
        self.global_cmvn = model.encoder.global_cmvn
        self.required_cache_size = required_cache_size
        self.beam_size = beam_size
        self.encoder = model.encoder
        self.reduce_idx = model.encoder.reduce_idx
        self.recover_idx = model.encoder.recover_idx
        if self.reduce_idx is None:
            self.time_reduce = None
        else:
            if self.recover_idx is None:
                self.time_reduce = 'normal'  # no recovery at the end
            else:
                self.time_reduce = 'recover'  # recovery at the end
                assert len(self.reduce_idx) == len(self.recover_idx)

    def calculate_downsampling_factor(self, i: int) -> int:
        if self.reduce_idx is None:
            return 1
        else:
            reduce_exp, recover_exp = 0, 0
            for exp, rd_idx in enumerate(self.reduce_idx):
                if i >= rd_idx:
                    reduce_exp = exp + 1
            if self.recover_idx is not None:
                for exp, rc_idx in enumerate(self.recover_idx):
                    if i >= rc_idx:
                        recover_exp = exp + 1
            return int(2 ** (reduce_exp - recover_exp))
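
    # NOTE: a worked example (not in the original file) of the factor computed
    # above. With reduce_idx == [3] and recover_idx == [9], layer i == 5 lies after
    # the time reduction but before the recovery, so reduce_exp == 1 and
    # recover_exp == 0, giving a factor of 2 ** (1 - 0) == 2; for i == 10 both
    # counters are 1 and the factor falls back to 1.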
    def forward(self, chunk_xs, chunk_lens, offset,
                att_cache, cnn_cache, cache_mask):
        """Streaming Encoder
        Args:
            xs (torch.Tensor): chunk input, with shape (b, time, mel-dim),
                where `time == (chunk_size - 1) * subsample_rate + \
                        subsample.right_context + 1`
            offset (torch.Tensor): offset with shape (b, 1)
                1 is retained for triton deployment
            required_cache_size (int): cache size required for next chunk
                computation
                > 0: actual cache size
                <= 0: not allowed in streaming gpu encoder
            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
                transformer/conformer attention, with shape
                (b, elayers, head, cache_t1, d_k * 2), where
                `head * d_k == hidden-dim` and
                `cache_t1 == chunk_size * num_decoding_left_chunks`.
            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
                (b, elayers, b, hidden-dim, cache_t2), where
                `cache_t2 == cnn.lorder - 1`
            cache_mask (torch.Tensor): cache mask with shape (b, required_cache_size)
                in a batch of request, each request may have different
                history cache. Cache mask is used to indicate the effective
                cache for each request
        Returns:
            torch.Tensor: log probabilities of ctc output and cutoff by beam size
                with shape (b, chunk_size, beam)
            torch.Tensor: index of top beam size probabilities for each timestep
                with shape (b, chunk_size, beam)
            torch.Tensor: output of current input xs,
                with shape (b, chunk_size, hidden-dim).
            torch.Tensor: new attention cache required for next chunk, with
                same shape (b, elayers, head, cache_t1, d_k * 2)
                as the original att_cache
            torch.Tensor: new conformer cnn cache required for next chunk, with
                same shape as the original cnn_cache.
            torch.Tensor: new cache mask, with same shape as the original
                cache mask
        """
        offset = offset.squeeze(1)
        T = chunk_xs.size(1)
        chunk_mask = ~make_pad_mask(chunk_lens, T).unsqueeze(1)  # B X 1 X T
        chunk_mask = chunk_mask.to(chunk_xs.dtype)
        # transpose batch & num_layers dim
        att_cache = torch.transpose(att_cache, 0, 1)
        cnn_cache = torch.transpose(cnn_cache, 0, 1)

        # rewrite encoder.forward_chunk
        # <---------forward_chunk START--------->
        xs = self.global_cmvn(chunk_xs)
        # chunk mask is important for batch inferencing since
        # different sequence in a batch has different length
        xs, pos_emb, chunk_mask = self.embed(xs, chunk_mask, offset)
        elayers, cache_size = att_cache.size(0), att_cache.size(3)
        att_mask = torch.cat((cache_mask, chunk_mask), dim=2)
        index = offset - cache_size

        pos_emb = self.embed.position_encoding(index, cache_size + xs.size(1))
        pos_emb = pos_emb.to(dtype=xs.dtype)

        next_cache_start = -self.required_cache_size
        r_cache_mask = att_mask[:, :, next_cache_start:]

        r_att_cache = []
        r_cnn_cache = []
        mask_pad = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
        mask_pad = mask_pad.unsqueeze(1)
        max_att_len: int = 0
        recover_activations: \
            List[Tuple[torch.Tensor, torch.Tensor,
                       torch.Tensor, torch.Tensor]] = []
        index = 0
        xs_lens = torch.tensor([xs.size(1)], device=xs.device, dtype=torch.int)
        xs = self.encoder.preln(xs)
        for i, layer in enumerate(self.encoder.encoders):
            if self.reduce_idx is not None:
                if self.time_reduce is not None and i in self.reduce_idx:
                    recover_activations.append((xs, att_mask, pos_emb, mask_pad))
                    xs, xs_lens, att_mask, mask_pad = \
                        self.encoder.time_reduction_layer(
                            xs, xs_lens, att_mask, mask_pad)
                    pos_emb = pos_emb[:, ::2, :]
                    if self.encoder.pos_enc_layer_type == "rel_pos_repaired":
                        pos_emb = pos_emb[:, :xs.size(1) * 2 - 1, :]
                    index += 1

            if self.recover_idx is not None:
                if self.time_reduce == 'recover' and i in self.recover_idx:
                    index -= 1
                    (recover_tensor, recover_att_mask,
                     recover_pos_emb, recover_mask_pad) \
                        = recover_activations[index]
                    # recover output length for ctc decode
                    xs = xs.unsqueeze(2).repeat(1, 1, 2, 1).flatten(1, 2)
                    xs = self.encoder.time_recover_layer(xs)
                    recoverd_t = recover_tensor.size(1)
                    xs = recover_tensor + xs[:, :recoverd_t, :].contiguous()
                    att_mask = recover_att_mask
                    pos_emb = recover_pos_emb
                    mask_pad = recover_mask_pad

            factor = self.calculate_downsampling_factor(i)

            xs, _, new_att_cache, new_cnn_cache = layer(
                xs, att_mask, pos_emb,
                att_cache=att_cache[i][:, :, ::factor, :]
                [:, :, :pos_emb.size(1) - xs.size(1), :]
                if elayers > 0 else att_cache[:, :, ::factor, :],
                cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache
            )
            cached_att \
                = new_att_cache[:, :, next_cache_start // factor:, :]
            cached_cnn = new_cnn_cache.unsqueeze(1)
            cached_att = cached_att.unsqueeze(3). \
                repeat(1, 1, 1, factor, 1).flatten(2, 3)
            if i == 0:
                # record length for the first block as max length
                max_att_len = cached_att.size(2)
            r_att_cache.append(cached_att[:, :, :max_att_len, :].unsqueeze(1))
            r_cnn_cache.append(cached_cnn)

        chunk_out = xs
        r_att_cache = torch.cat(r_att_cache, dim=1)  # concat on layers idx
        r_cnn_cache = torch.cat(r_cnn_cache, dim=1)  # concat on layers
        # <---------forward_chunk END--------->

        log_ctc_probs = self.ctc.log_softmax(chunk_out)
        log_probs, log_probs_idx = torch.topk(log_ctc_probs,
                                              self.beam_size, dim=2)
        log_probs = log_probs.to(chunk_xs.dtype)

        r_offset = offset + chunk_out.shape[1]
        # the below ops not supported in Tensorrt
        # chunk_out_lens = torch.div(chunk_lens, subsampling_rate,
        #                            rounding_mode='floor')
        chunk_out_lens = chunk_lens // self.subsampling_rate
        r_offset = r_offset.unsqueeze(1)

        return log_probs, log_probs_idx, chunk_out, chunk_out_lens, \
            r_offset, r_att_cache, r_cnn_cache, r_cache_mask
class Decoder(torch.nn.Module):
    def __init__(self,
                 decoder: TransformerDecoder,
                 ctc_weight: float = 0.5,
                 reverse_weight: float = 0.0,
                 beam_size: int = 10,
                 decoder_fastertransformer: bool = False):
        super().__init__()
        self.decoder = decoder
        self.ctc_weight = ctc_weight
        self.reverse_weight = reverse_weight
        self.beam_size = beam_size
        self.decoder_fastertransformer = decoder_fastertransformer

    def forward(self,
                encoder_out: torch.Tensor,
                encoder_lens: torch.Tensor,
                hyps_pad_sos_eos: torch.Tensor,
                hyps_lens_sos: torch.Tensor,
                r_hyps_pad_sos_eos: torch.Tensor,
                ctc_score: torch.Tensor):
        """Decoder
        Args:
            encoder_out: B x T x F
            encoder_lens: B
            hyps_pad_sos_eos: B x beam x (T2+1),
                hyps with sos & eos and padded by ignore id
            hyps_lens_sos: B x beam, length for each hyp with sos
            r_hyps_pad_sos_eos: B x beam x (T2+1),
                reversed hyps with sos & eos and padded by ignore id
            ctc_score: B x beam, ctc score for each hyp
        Returns:
            decoder_out: B x beam x T2 x V
            r_decoder_out: B x beam x T2 x V
            best_index: B
        """
        B, T, F = encoder_out.shape
        bz = self.beam_size
        B2 = B * bz
        encoder_out = encoder_out.repeat(1, bz, 1).view(B2, T, F)
        encoder_mask = ~make_pad_mask(encoder_lens, T).unsqueeze(1)
        encoder_mask = encoder_mask.repeat(1, bz, 1).view(B2, 1, T)
        T2 = hyps_pad_sos_eos.shape[2] - 1
        hyps_pad = hyps_pad_sos_eos.view(B2, T2 + 1)
        hyps_lens = hyps_lens_sos.view(B2,)
        hyps_pad_sos = hyps_pad[:, :-1].contiguous()
        hyps_pad_eos = hyps_pad[:, 1:].contiguous()

        r_hyps_pad = r_hyps_pad_sos_eos.view(B2, T2 + 1)
        r_hyps_pad_sos = r_hyps_pad[:, :-1].contiguous()
        r_hyps_pad_eos = r_hyps_pad[:, 1:].contiguous()

        decoder_out, r_decoder_out, _ = self.decoder(
            encoder_out, encoder_mask, hyps_pad_sos, hyps_lens,
            r_hyps_pad_sos, self.reverse_weight)
        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
        V = decoder_out.shape[-1]
        decoder_out = decoder_out.view(B2, T2, V)
        mask = ~make_pad_mask(hyps_lens, T2)  # B2 x T2
        # mask index, remove ignore id
        index = torch.unsqueeze(hyps_pad_eos * mask, 2)
        score = decoder_out.gather(2, index).squeeze(2)  # B2 X T2
        # mask padded part
        score = score * mask
        decoder_out = decoder_out.view(B, bz, T2, V)
        if self.reverse_weight > 0:
            r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out,
                                                            dim=-1)
            r_decoder_out = r_decoder_out.view(B2, T2, V)
            index = torch.unsqueeze(r_hyps_pad_eos * mask, 2)
            r_score = r_decoder_out.gather(2, index).squeeze(2)
            r_score = r_score * mask
            score = score * (1 - self.reverse_weight) + \
                self.reverse_weight * r_score
            r_decoder_out = r_decoder_out.view(B, bz, T2, V)
        score = torch.sum(score, axis=1)  # B2
        score = torch.reshape(score, (B, bz)) + self.ctc_weight * ctc_score
        best_index = torch.argmax(score, dim=1)
        if self.decoder_fastertransformer:
            return decoder_out, best_index
        else:
            return best_index
def to_numpy(tensors):
    out = []
    if isinstance(tensors, torch.Tensor):
        # wrap a bare tensor so the loop below works on a list
        tensors = [tensors]
    for tensor in tensors:
        if tensor.requires_grad:
            tensor = tensor.detach().cpu().numpy()
        else:
            tensor = tensor.cpu().numpy()
        out.append(tensor)
    return out


def test(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True):
    for a, b in zip(xlist, blist):
        try:
            torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
        except AssertionError as error:
            if tolerate_small_mismatch:
                print(error)
            else:
                raise
def export_offline_encoder(model, configs, args, logger, encoder_onnx_path):
    bz = 32
    seq_len = 100
    beam_size = args.beam_size
    feature_size = configs["input_dim"]

    speech = torch.randn(bz, seq_len, feature_size, dtype=torch.float32)
    speech_lens = torch.randint(low=10, high=seq_len, size=(bz,),
                                dtype=torch.int32)
    encoder = Encoder(model.encoder, model.ctc, beam_size)
    encoder.eval()

    torch.onnx.export(
        encoder, (speech, speech_lens), encoder_onnx_path,
        export_params=True, opset_version=13, do_constant_folding=True,
        input_names=['speech', 'speech_lengths'],
        output_names=['encoder_out', 'encoder_out_lens', 'ctc_log_probs',
                      'beam_log_probs', 'beam_log_probs_idx'],
        dynamic_axes={
            'speech': {0: 'B', 1: 'T'},
            'speech_lengths': {0: 'B'},
            'encoder_out': {0: 'B', 1: 'T_OUT'},
            'encoder_out_lens': {0: 'B'},
            'ctc_log_probs': {0: 'B', 1: 'T_OUT'},
            'beam_log_probs': {0: 'B', 1: 'T_OUT'},
            'beam_log_probs_idx': {0: 'B', 1: 'T_OUT'},
        },
        verbose=False)

    with torch.no_grad():
        o0, o1, o2, o3, o4 = encoder(speech, speech_lens)

    providers = ["CUDAExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(encoder_onnx_path,
                                               providers=providers)
    ort_inputs = {'speech': to_numpy(speech),
                  'speech_lengths': to_numpy(speech_lens)}
    ort_outs = ort_session.run(None, ort_inputs)

    # check encoder output
    test(to_numpy([o0, o1, o2, o3, o4]), ort_outs)
    logger.info("export offline onnx encoder succeed!")
    onnx_config = {
        "beam_size": args.beam_size,
        "reverse_weight": args.reverse_weight,
        "ctc_weight": args.ctc_weight,
        "fp16": args.fp16
    }
    return onnx_config
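
# NOTE: a minimal smoke-test sketch, not part of the original file and not called by
# it; the input names and feature dimension follow the export code above, while the
# batch size, lengths, and CPU provider are arbitrary assumptions.
def _offline_encoder_smoke_test(encoder_onnx_path, feature_size):
    session = onnxruntime.InferenceSession(encoder_onnx_path,
                                           providers=["CPUExecutionProvider"])
    speech = torch.randn(2, 57, feature_size, dtype=torch.float32)
    speech_lens = torch.tensor([57, 30], dtype=torch.int32)
    outs = session.run(None, {'speech': speech.numpy(),
                              'speech_lengths': speech_lens.numpy()})
    # The dynamic axes exported above allow any batch size and length here.
    return [o.shape for o in outs]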
def export_online_encoder(model, configs, args, logger, encoder_onnx_path):
    decoding_chunk_size = args.decoding_chunk_size
    subsampling = model.encoder.embed.subsampling_rate
    context = model.encoder.embed.right_context + 1
    decoding_window = (decoding_chunk_size - 1) * subsampling + context
    batch_size = 32
    audio_len = decoding_window
    feature_size = configs["input_dim"]
    output_size = configs["encoder_conf"]["output_size"]
    num_layers = configs["encoder_conf"]["num_blocks"]
    # in transformer the cnn module will not be available
    transformer = False
    cnn_module_kernel = configs["encoder_conf"].get("cnn_module_kernel", 1) - 1
    if not cnn_module_kernel:
        transformer = True
    num_decoding_left_chunks = args.num_decoding_left_chunks
    required_cache_size = decoding_chunk_size * num_decoding_left_chunks
    if configs['encoder'] == 'squeezeformer':
        encoder = StreamingSqueezeformerEncoder(model, required_cache_size,
                                                args.beam_size)
    else:
        encoder = StreamingEncoder(model, required_cache_size, args.beam_size,
                                   transformer)
    encoder.eval()

    # begin to export encoder
    chunk_xs = torch.randn(batch_size, audio_len, feature_size,
                           dtype=torch.float32)
    chunk_lens = torch.ones(batch_size, dtype=torch.int32) * audio_len

    offset = torch.arange(0, batch_size).unsqueeze(1)
    # (elayers, b, head, cache_t1, d_k * 2)
    head = configs["encoder_conf"]["attention_heads"]
    d_k = configs["encoder_conf"]["output_size"] // head
    att_cache = torch.randn(batch_size, num_layers, head,
                            required_cache_size, d_k * 2, dtype=torch.float32)
    cnn_cache = torch.randn(batch_size, num_layers, output_size,
                            cnn_module_kernel, dtype=torch.float32)

    cache_mask = torch.ones(batch_size, 1, required_cache_size,
                            dtype=torch.float32)
    input_names = ['chunk_xs', 'chunk_lens', 'offset',
                   'att_cache', 'cnn_cache', 'cache_mask']
    output_names = ['log_probs', 'log_probs_idx', 'chunk_out',
                    'chunk_out_lens', 'r_offset', 'r_att_cache',
                    'r_cnn_cache', 'r_cache_mask']
    input_tensors = (chunk_xs, chunk_lens, offset, att_cache,
                     cnn_cache, cache_mask)
    if transformer:
        output_names.pop(6)

    all_names = input_names + output_names
    dynamic_axes = {}
    for name in all_names:
        # only the first dimension is dynamic
        # all other dimension is fixed
        dynamic_axes[name] = {0: 'B'}
    torch.onnx.export(
        encoder, input_tensors, encoder_onnx_path,
        export_params=True, opset_version=14, do_constant_folding=True,
        input_names=input_names, output_names=output_names,
        dynamic_axes=dynamic_axes, verbose=False)

    with torch.no_grad():
        torch_outs = encoder(chunk_xs, chunk_lens, offset,
                             att_cache, cnn_cache, cache_mask)
    if transformer:
        torch_outs = list(torch_outs)
        torch_outs.pop(6)  # drop r_cnn_cache, which a pure transformer does not emit
    ort_session = onnxruntime.InferenceSession(
        encoder_onnx_path, providers=["CUDAExecutionProvider"])
    ort_inputs = {}
    input_tensors = to_numpy(input_tensors)
    for idx, name in enumerate(input_names):
        ort_inputs[name] = input_tensors[idx]
    if transformer:
        del ort_inputs['cnn_cache']
    ort_outs = ort_session.run(None, ort_inputs)
    test(to_numpy(torch_outs), ort_outs, rtol=1e-03, atol=1e-05)
    logger.info("export to onnx streaming encoder succeed!")
    onnx_config = {
        "subsampling_rate": subsampling,
        "context": context,
        "decoding_chunk_size": decoding_chunk_size,
        "num_decoding_left_chunks": num_decoding_left_chunks,
        "beam_size": args.beam_size,
        "fp16": args.fp16,
        "feat_size": feature_size,
        "decoding_window": decoding_window,
        "cnn_module_kernel_cache": cnn_module_kernel
    }
    return onnx_config
def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
                             decoder_fastertransformer):
    bz, seq_len = 32, 100
    beam_size = args.beam_size
    decoder = Decoder(model.decoder, model.ctc_weight, model.reverse_weight,
                      beam_size, decoder_fastertransformer)
    decoder.eval()

    hyps_pad_sos_eos = torch.randint(low=3, high=1000,
                                     size=(bz, beam_size, seq_len))
    hyps_lens_sos = torch.randint(low=3, high=seq_len,
                                  size=(bz, beam_size), dtype=torch.int32)
    r_hyps_pad_sos_eos = torch.randint(low=3, high=1000,
                                       size=(bz, beam_size, seq_len))

    output_size = configs["encoder_conf"]["output_size"]
    encoder_out = torch.randn(bz, seq_len, output_size, dtype=torch.float32)
    encoder_out_lens = torch.randint(low=3, high=seq_len, size=(bz,),
                                     dtype=torch.int32)
    ctc_score = torch.randn(bz, beam_size, dtype=torch.float32)

    input_names = ['encoder_out', 'encoder_out_lens', 'hyps_pad_sos_eos',
                   'hyps_lens_sos', 'r_hyps_pad_sos_eos', 'ctc_score']
    output_names = ['best_index']
    if decoder_fastertransformer:
        output_names.insert(0, 'decoder_out')

    torch.onnx.export(
        decoder,
        (encoder_out, encoder_out_lens, hyps_pad_sos_eos,
         hyps_lens_sos, r_hyps_pad_sos_eos, ctc_score),
        decoder_onnx_path,
        export_params=True, opset_version=13, do_constant_folding=True,
        input_names=input_names, output_names=output_names,
        dynamic_axes={
            'encoder_out': {0: 'B', 1: 'T'},
            'encoder_out_lens': {0: 'B'},
            'hyps_pad_sos_eos': {0: 'B', 2: 'T2'},
            'hyps_lens_sos': {0: 'B'},
            'r_hyps_pad_sos_eos': {0: 'B', 2: 'T2'},
            'ctc_score': {0: 'B'},
            'best_index': {0: 'B'},
        },
        verbose=False)

    with torch.no_grad():
        o0 = decoder(encoder_out, encoder_out_lens, hyps_pad_sos_eos,
                     hyps_lens_sos, r_hyps_pad_sos_eos, ctc_score)
    providers = ["CUDAExecutionProvider"]
    ort_session = onnxruntime.InferenceSession(decoder_onnx_path,
                                               providers=providers)

    input_tensors = [encoder_out, encoder_out_lens, hyps_pad_sos_eos,
                     hyps_lens_sos, r_hyps_pad_sos_eos, ctc_score]
    ort_inputs = {}
    input_tensors = to_numpy(input_tensors)
    for idx, name in enumerate(input_names):
        ort_inputs[name] = input_tensors[idx]

    # if model.reverse_weight == 0, the r_hyps_pad will be removed
    # from the onnx decoder since it doesn't play any role
    if model.reverse_weight == 0:
        del ort_inputs['r_hyps_pad_sos_eos']
    ort_outs = ort_session.run(None, ort_inputs)

    # check decoder output
    if decoder_fastertransformer:
        test(to_numpy(o0), ort_outs, rtol=1e-03, atol=1e-05)
    else:
        test(to_numpy([o0]), ort_outs, rtol=1e-03, atol=1e-05)
    logger.info("export to onnx decoder succeed!")
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='export x86_gpu model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--cmvn_file', required=False, default='', type=str,
                        help='global_cmvn file, default path is in config file')
    parser.add_argument('--reverse_weight', default=-1.0, type=float,
                        required=False,
                        help='reverse weight for bitransformer,' +
                        'default value is in config file')
    parser.add_argument('--ctc_weight', default=-1.0, type=float,
                        required=False,
                        help='ctc weight, default value is in config file')
    parser.add_argument('--beam_size', default=10, type=int, required=False,
                        help="beam size would be ctc output size")
    parser.add_argument('--output_onnx_dir', default="onnx_model",
                        help='output onnx encoder and decoder directory')
    parser.add_argument('--fp16', action='store_true',
                        help='whether to export fp16 model, default false')
    # arguments for streaming encoder
    parser.add_argument('--streaming', action='store_true',
                        help="whether to export streaming encoder, default false")
    parser.add_argument('--decoding_chunk_size', default=16, type=int,
                        required=False,
                        help='the decoding chunk size, <=0 is not supported')
    parser.add_argument('--num_decoding_left_chunks', default=5, type=int,
                        required=False,
                        help="number of left chunks, <= 0 is not supported")
    parser.add_argument('--decoder_fastertransformer', action='store_true',
                        help='return decoder_out and best_index for ft')
    args = parser.parse_args()

    torch.manual_seed(0)
    torch.set_printoptions(precision=10)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if args.cmvn_file and os.path.exists(args.cmvn_file):
        configs['cmvn_file'] = args.cmvn_file
    if args.reverse_weight != -1.0 and 'reverse_weight' in configs['model_conf']:
        configs['model_conf']['reverse_weight'] = args.reverse_weight
        print("Update reverse weight to", args.reverse_weight)
    if args.ctc_weight != -1:
        print("Update ctc weight to ", args.ctc_weight)
        configs['model_conf']['ctc_weight'] = args.ctc_weight
    configs["encoder_conf"]["use_dynamic_chunk"] = False

    model = init_model(configs)
    load_checkpoint(model, args.checkpoint)
    model.eval()

    if not os.path.exists(args.output_onnx_dir):
        os.mkdir(args.output_onnx_dir)
    encoder_onnx_path = os.path.join(args.output_onnx_dir, 'encoder.onnx')
    export_enc_func = None
    if args.streaming:
        assert args.decoding_chunk_size > 0
        assert args.num_decoding_left_chunks > 0
        export_enc_func = export_online_encoder
    else:
        export_enc_func = export_offline_encoder

    onnx_config = export_enc_func(model, configs, args, logger,
                                  encoder_onnx_path)

    decoder_onnx_path = os.path.join(args.output_onnx_dir, 'decoder.onnx')
    export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
                             args.decoder_fastertransformer)

    if args.fp16:
        try:
            import onnxmltools
            from onnxmltools.utils.float16_converter import convert_float_to_float16
        except ImportError:
            print('Please install onnxmltools!')
            sys.exit(1)
        encoder_onnx_model = onnxmltools.utils.load_model(encoder_onnx_path)
        encoder_onnx_model = convert_float_to_float16(encoder_onnx_model)
        encoder_onnx_path = os.path.join(args.output_onnx_dir,
                                         'encoder_fp16.onnx')
        onnxmltools.utils.save_model(encoder_onnx_model, encoder_onnx_path)
        decoder_onnx_model = onnxmltools.utils.load_model(decoder_onnx_path)
        decoder_onnx_model = convert_float_to_float16(decoder_onnx_model)
        decoder_onnx_path = os.path.join(args.output_onnx_dir,
                                         'decoder_fp16.onnx')
        onnxmltools.utils.save_model(decoder_onnx_model, decoder_onnx_path)
    # dump configurations
    config_dir = os.path.join(args.output_onnx_dir, "config.yaml")
    with open(config_dir, "w") as out:
        yaml.dump(onnx_config, out)
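
# NOTE: a minimal usage sketch (paths are placeholders; the flags come from the
# argparse definition above):
#   offline:   python wenet/bin/export_onnx_gpu.py --config train.yaml \
#                  --checkpoint final.pt --output_onnx_dir onnx_gpu --fp16
#   streaming: python wenet/bin/export_onnx_gpu.py --config train.yaml \
#                  --checkpoint final.pt --output_onnx_dir onnx_gpu --streaming \
#                  --decoding_chunk_size 16 --num_decoding_left_chunks 5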
examples/aishell/s0/wenet/bin/recognize.py
0 → 100644
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import argparse
import copy
import logging
import os
import sys

import torch
import yaml
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model


def get_args():
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--data_type', default='raw',
                        choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--gpu', type=int, default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--checkpoint', required=True, help='checkpoint model')
    parser.add_argument('--dict', required=True, help='dict file')
    parser.add_argument("--non_lang_syms",
                        help="non-linguistic symbol file. One symbol per line.")
    parser.add_argument('--beam_size', type=int, default=10,
                        help='beam size for search')
    parser.add_argument('--penalty', type=float, default=0.0,
                        help='length penalty')
    parser.add_argument('--result_file', required=True, help='asr result file')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='batch size for decoding')
    parser.add_argument('--mode',
                        choices=[
                            'attention', 'ctc_greedy_search',
                            'ctc_prefix_beam_search', 'attention_rescoring',
                            'rnnt_greedy_search', 'rnnt_beam_search',
                            'rnnt_beam_attn_rescoring',
                            'ctc_beam_td_attn_rescoring', 'hlg_onebest',
                            'hlg_rescore'
                        ],
                        default='attention',
                        help='decoding mode')
    parser.add_argument('--search_ctc_weight', type=float, default=1.0,
                        help='ctc weight for nbest generation')
    parser.add_argument('--search_transducer_weight', type=float, default=0.0,
                        help='transducer weight for nbest generation')
    parser.add_argument('--ctc_weight', type=float, default=0.0,
                        help='ctc weight used for rescoring in attention '
                             'rescoring decode mode, and in transducer '
                             'attention rescore decode mode')
    parser.add_argument('--transducer_weight', type=float, default=0.0,
                        help='transducer weight for rescoring weight in '
                             'transducer attention rescore mode')
    parser.add_argument('--attn_weight', type=float, default=0.0,
                        help='attention weight for rescoring weight in '
                             'transducer attention rescore mode')
    parser.add_argument('--decoding_chunk_size', type=int, default=-1,
                        help='''decoding chunk size,
                                <0: for decoding, use full chunk.
                                >0: for decoding, use fixed chunk size as set.
                                0: used for training, it's prohibited here''')
    parser.add_argument('--num_decoding_left_chunks', type=int, default=-1,
                        help='number of left chunks for decoding')
    parser.add_argument('--simulate_streaming', action='store_true',
                        help='simulate streaming inference')
    parser.add_argument('--reverse_weight', type=float, default=0.0,
                        help='''right to left weight for attention rescoring
                                decode mode''')
    parser.add_argument('--bpe_model', default=None, type=str,
                        help='bpe model for english part')
    parser.add_argument('--override_config', action='append', default=[],
                        help="override yaml config")
    parser.add_argument('--connect_symbol', default='', type=str,
                        help='used to connect the output characters')
    parser.add_argument('--word', default='', type=str,
                        help='word file, only used for hlg decode')
    parser.add_argument('--hlg', default='', type=str,
                        help='hlg file, only used for hlg decode')
    parser.add_argument('--lm_scale', type=float, default=0.0,
                        help='lm scale for hlg attention rescore decode')
    parser.add_argument('--decoder_scale', type=float, default=0.0,
                        help='decoder scale for hlg attention rescore decode')
    parser.add_argument('--r_decoder_scale', type=float, default=0.0,
                        help='right-to-left decoder scale for hlg attention '
                             'rescore decode')
    args = parser.parse_args()
    print(args)
    return args
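

# NOTE: a minimal usage sketch (paths are placeholders; the flags come from
# get_args() above, and batch_size must be 1 for attention_rescoring as enforced
# in main() below):
#   python wenet/bin/recognize.py --config train.yaml --checkpoint final.pt \
#       --test_data data/test/data.list --data_type raw --dict lang_char.txt \
#       --mode attention_rescoring --batch_size 1 --result_file exp/text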
def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['HIP_VISIBLE_DEVICES'] = str(args.gpu)

    if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring'
                     ] and args.batch_size > 1:
        logging.fatal('decoding mode {} must be running with batch_size == 1'
                      .format(args.mode))
        sys.exit(1)

    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    symbol_table = read_symbol_table(args.dict)
    test_conf = copy.deepcopy(configs['dataset_conf'])

    test_conf['filter_conf']['max_length'] = 102400
    test_conf['filter_conf']['min_length'] = 0
    test_conf['filter_conf']['token_max_length'] = 102400
    test_conf['filter_conf']['token_min_length'] = 0
    test_conf['filter_conf']['max_output_input_ratio'] = 102400
    test_conf['filter_conf']['min_output_input_ratio'] = 0
    test_conf['speed_perturb'] = False
    test_conf['spec_aug'] = False
    test_conf['spec_sub'] = False
    test_conf['spec_trim'] = False
    test_conf['shuffle'] = False
    test_conf['sort'] = False
    if 'fbank_conf' in test_conf:
        test_conf['fbank_conf']['dither'] = 0.0
    elif 'mfcc_conf' in test_conf:
        test_conf['mfcc_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_type'] = "static"
    test_conf['batch_conf']['batch_size'] = args.batch_size
    non_lang_syms = read_non_lang_symbols(args.non_lang_syms)

    test_dataset = Dataset(args.data_type,
                           args.test_data,
                           symbol_table,
                           test_conf,
                           args.bpe_model,
                           non_lang_syms,
                           partition=False)

    test_data_loader = DataLoader(test_dataset,
                                  batch_size=None,
                                  num_workers=8,
                                  pin_memory=True)

    # Init asr model from configs
    model = init_model(configs)

    # Load dict
    char_dict = {v: k for k, v in symbol_table.items()}
    eos = len(char_dict) - 1

    load_checkpoint(model, args.checkpoint)
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model = model.to(device)
    model.eval()
    with torch.no_grad(), open(args.result_file, 'w') as fout:
        for batch_idx, batch in enumerate(test_data_loader):
            keys, feats, target, feats_lengths, target_lengths = batch
            feats = feats.to(device)
            target = target.to(device)
            feats_lengths = feats_lengths.to(device)
            target_lengths = target_lengths.to(device)
            if args.mode == 'attention':
                hyps, _ = model.recognize(
                    feats,
                    feats_lengths,
                    beam_size=args.beam_size,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming)
                hyps = [hyp.tolist() for hyp in hyps]
            elif args.mode == 'ctc_greedy_search':
                hyps, _ = model.ctc_greedy_search(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming)
            elif args.mode == 'rnnt_greedy_search':
                assert (feats.size(0) == 1)
                assert 'predictor' in configs
                hyps = model.greedy_search(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming)
            elif args.mode == 'rnnt_beam_search':
                assert (feats.size(0) == 1)
                assert 'predictor' in configs
                hyps = model.beam_search(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    beam_size=args.beam_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                    ctc_weight=args.search_ctc_weight,
                    transducer_weight=args.search_transducer_weight)
            elif args.mode == 'rnnt_beam_attn_rescoring':
                assert (feats.size(0) == 1)
                assert 'predictor' in configs
                hyps = model.transducer_attention_rescoring(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    beam_size=args.beam_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                    ctc_weight=args.ctc_weight,
                    transducer_weight=args.transducer_weight,
                    attn_weight=args.attn_weight,
                    reverse_weight=args.reverse_weight,
                    search_ctc_weight=args.search_ctc_weight,
                    search_transducer_weight=args.search_transducer_weight)
            elif args.mode == 'ctc_beam_td_attn_rescoring':
                assert (feats.size(0) == 1)
                assert 'predictor' in configs
                hyps = model.transducer_attention_rescoring(
                    feats,
                    feats_lengths,
                    decoding_chunk_size=args.decoding_chunk_size,
                    beam_size=args.beam_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming,
                    ctc_weight=args.ctc_weight,
                    transducer_weight=args.transducer_weight,
                    attn_weight=args.attn_weight,
                    reverse_weight=args.reverse_weight,
                    search_ctc_weight=args.search_ctc_weight,
                    search_transducer_weight=args.search_transducer_weight,
                    beam_search_type='ctc')
            # ctc_prefix_beam_search and attention_rescoring only return one
            # result in List[int], change it to List[List[int]] for compatible
            # with other batch decoding mode
            elif args.mode == 'ctc_prefix_beam_search':
                assert (feats.size(0) == 1)
                hyp, _ = model.ctc_prefix_beam_search(
                    feats,
                    feats_lengths,
                    args.beam_size,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    simulate_streaming=args.simulate_streaming)
                hyps = [hyp]
            elif args.mode == 'attention_rescoring':
                assert (feats.size(0) == 1)
                hyp, source = model.attention_rescoring(
                    feats,
                    feats_lengths,
                    args.beam_size,
                    decoding_chunk_size=args.decoding_chunk_size,
                    num_decoding_left_chunks=args.num_decoding_left_chunks,
                    ctc_weight=args.ctc_weight,
                    simulate_streaming=args.simulate_streaming,
                    reverse_weight=args.reverse_weight)
                hyps = [hyp]
            elif args.mode == 'hlg_onebest':
hyps
=
model
.
hlg_onebest
(
feats
,
feats_lengths
,
decoding_chunk_size
=
args
.
decoding_chunk_size
,
num_decoding_left_chunks
=
args
.
num_decoding_left_chunks
,
simulate_streaming
=
args
.
simulate_streaming
,
hlg
=
args
.
hlg
,
word
=
args
.
word
,
symbol_table
=
symbol_table
)
elif
args
.
mode
==
'hlg_rescore'
:
hyps
=
model
.
hlg_rescore
(
feats
,
feats_lengths
,
decoding_chunk_size
=
args
.
decoding_chunk_size
,
num_decoding_left_chunks
=
args
.
num_decoding_left_chunks
,
simulate_streaming
=
args
.
simulate_streaming
,
lm_scale
=
args
.
lm_scale
,
decoder_scale
=
args
.
decoder_scale
,
r_decoder_scale
=
args
.
r_decoder_scale
,
hlg
=
args
.
hlg
,
word
=
args
.
word
,
symbol_table
=
symbol_table
)
for
i
,
key
in
enumerate
(
keys
):
content
=
[]
for
w
in
hyps
[
i
]:
if
w
==
eos
:
break
content
.
append
(
char_dict
[
w
])
logging
.
info
(
'{} {}'
.
format
(
key
,
args
.
connect_symbol
.
join
(
content
)))
fout
.
write
(
'{} {}
\n
'
.
format
(
key
,
args
.
connect_symbol
.
join
(
content
)))
if
__name__
==
'__main__'
:
main
()
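The result-writing loop above cuts each integer hypothesis at `eos` and joins the mapped characters with `--connect_symbol`. A minimal sketch of that mapping on made-up token ids (the symbol table here is hypothetical, not the real AISHELL dict produced by `read_symbol_table`):

```python
# Hypothetical 5-symbol table; a real one comes from read_symbol_table(args.dict).
char_dict = {0: '<blank>', 1: '你', 2: '好', 3: '吗', 4: '<sos/eos>'}
eos = len(char_dict) - 1          # 4, same convention as in main()
hyp = [1, 2, 3, 4, 0]             # decoder output; eos marks the end

content = []
for w in hyp:
    if w == eos:                  # stop at the first eos token
        break
    content.append(char_dict[w])

print(''.join(content))           # -> 你好吗  (connect_symbol='' for Mandarin)
```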
examples/aishell/s0/wenet/bin/recognize_onnx_gpu.py
0 → 100644
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is for testing exported onnx encoder and decoder from
export_onnx_gpu.py. The exported onnx models only support batch offline
ASR inference. It requires a python wrapped c++ ctc decoder.
Please install it by following:
https://github.com/Slyne/ctc_decoder.git
"""
from __future__ import print_function

import argparse
import copy
import logging
import os
import sys

import torch
import yaml
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.utils.common import IGNORE_ID
from wenet.utils.file_utils import read_symbol_table
from wenet.utils.config import override_config

import onnxruntime as rt
import multiprocessing
import numpy as np

try:
    from swig_decoders import map_batch, \
        ctc_beam_search_decoder_batch, \
        TrieVector, PathTrie
except ImportError:
    print('Please install ctc decoders first by refering to\n' +
          'https://github.com/Slyne/ctc_decoder.git')
    sys.exit(1)


def get_args():
    parser = argparse.ArgumentParser(description='recognize with your model')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--test_data', required=True, help='test data file')
    parser.add_argument('--data_type',
                        default='raw',
                        choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this rank, -1 for cpu')
    parser.add_argument('--dict', required=True, help='dict file')
    parser.add_argument('--encoder_onnx',
                        required=True,
                        help='encoder onnx file')
    parser.add_argument('--decoder_onnx',
                        required=True,
                        help='decoder onnx file')
    parser.add_argument('--result_file', required=True, help='asr result file')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='asr result file')
    parser.add_argument('--mode',
                        choices=[
                            'ctc_greedy_search', 'ctc_prefix_beam_search',
                            'attention_rescoring'
                        ],
                        default='attention_rescoring',
                        help='decoding mode')
    parser.add_argument('--bpe_model',
                        default=None,
                        type=str,
                        help='bpe model for english part')
    parser.add_argument('--override_config',
                        action='append',
                        default=[],
                        help="override yaml config")
    parser.add_argument('--fp16',
                        action='store_true',
                        help='whether to export fp16 model, default false')
    args = parser.parse_args()
    print(args)
    return args


def main():
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    reverse_weight = configs["model_conf"].get("reverse_weight", 0.0)
    symbol_table = read_symbol_table(args.dict)
    test_conf = copy.deepcopy(configs['dataset_conf'])
    test_conf['filter_conf']['max_length'] = 102400
    test_conf['filter_conf']['min_length'] = 0
    test_conf['filter_conf']['token_max_length'] = 102400
    test_conf['filter_conf']['token_min_length'] = 0
    test_conf['filter_conf']['max_output_input_ratio'] = 102400
    test_conf['filter_conf']['min_output_input_ratio'] = 0
    test_conf['speed_perturb'] = False
    test_conf['spec_aug'] = False
    test_conf['spec_trim'] = False
    test_conf['shuffle'] = False
    test_conf['sort'] = False
    test_conf['fbank_conf']['dither'] = 0.0
    test_conf['batch_conf']['batch_type'] = "static"
    test_conf['batch_conf']['batch_size'] = args.batch_size

    test_dataset = Dataset(args.data_type,
                           args.test_data,
                           symbol_table,
                           test_conf,
                           args.bpe_model,
                           partition=False)

    test_data_loader = DataLoader(test_dataset,
                                  batch_size=None,
                                  num_workers=0)

    # Init asr model from configs
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    if use_cuda:
        EP_list = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    else:
        EP_list = ['CPUExecutionProvider']

    encoder_ort_session = rt.InferenceSession(args.encoder_onnx,
                                              providers=EP_list)
    decoder_ort_session = None
    if args.mode == "attention_rescoring":
        decoder_ort_session = rt.InferenceSession(args.decoder_onnx,
                                                  providers=EP_list)

    # Load dict
    vocabulary = []
    char_dict = {}
    with open(args.dict, 'r') as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            char_dict[int(arr[1])] = arr[0]
            vocabulary.append(arr[0])
    eos = sos = len(char_dict) - 1

    with torch.no_grad(), open(args.result_file, 'w') as fout:
        for _, batch in enumerate(test_data_loader):
            keys, feats, _, feats_lengths, _ = batch
            feats, feats_lengths = feats.numpy(), feats_lengths.numpy()
            if args.fp16:
                feats = feats.astype(np.float16)
            ort_inputs = {
                encoder_ort_session.get_inputs()[0].name: feats,
                encoder_ort_session.get_inputs()[1].name: feats_lengths
            }
            ort_outs = encoder_ort_session.run(None, ort_inputs)
            encoder_out, encoder_out_lens, ctc_log_probs, \
                beam_log_probs, beam_log_probs_idx = ort_outs
            beam_size = beam_log_probs.shape[-1]
            batch_size = beam_log_probs.shape[0]
            num_processes = min(multiprocessing.cpu_count(), batch_size)
            if args.mode == 'ctc_greedy_search':
                if beam_size != 1:
                    log_probs_idx = beam_log_probs_idx[:, :, 0]
                batch_sents = []
                for idx, seq in enumerate(log_probs_idx):
                    batch_sents.append(seq[0:encoder_out_lens[idx]].tolist())
                hyps = map_batch(batch_sents, vocabulary, num_processes,
                                 True, 0)
            elif args.mode in ('ctc_prefix_beam_search',
                               "attention_rescoring"):
                batch_log_probs_seq_list = beam_log_probs.tolist()
                batch_log_probs_idx_list = beam_log_probs_idx.tolist()
                batch_len_list = encoder_out_lens.tolist()
                batch_log_probs_seq = []
                batch_log_probs_ids = []
                batch_start = []  # only effective in streaming deployment
                batch_root = TrieVector()
                root_dict = {}
                for i in range(len(batch_len_list)):
                    num_sent = batch_len_list[i]
                    batch_log_probs_seq.append(
                        batch_log_probs_seq_list[i][0:num_sent])
                    batch_log_probs_ids.append(
                        batch_log_probs_idx_list[i][0:num_sent])
                    root_dict[i] = PathTrie()
                    batch_root.append(root_dict[i])
                    batch_start.append(True)
                score_hyps = ctc_beam_search_decoder_batch(
                    batch_log_probs_seq, batch_log_probs_ids, batch_root,
                    batch_start, beam_size, num_processes, 0, -2, 0.99999)
                if args.mode == 'ctc_prefix_beam_search':
                    hyps = []
                    for cand_hyps in score_hyps:
                        hyps.append(cand_hyps[0][1])
                    hyps = map_batch(hyps, vocabulary, num_processes, False, 0)
            if args.mode == 'attention_rescoring':
                ctc_score, all_hyps = [], []
                max_len = 0
                for hyps in score_hyps:
                    cur_len = len(hyps)
                    if len(hyps) < beam_size:
                        hyps += (beam_size - cur_len) * [(-float("INF"), (0,))]
                    cur_ctc_score = []
                    for hyp in hyps:
                        cur_ctc_score.append(hyp[0])
                        all_hyps.append(list(hyp[1]))
                        if len(hyp[1]) > max_len:
                            max_len = len(hyp[1])
                    ctc_score.append(cur_ctc_score)
                if args.fp16:
                    ctc_score = np.array(ctc_score, dtype=np.float16)
                else:
                    ctc_score = np.array(ctc_score, dtype=np.float32)
                hyps_pad_sos_eos = np.ones(
                    (batch_size, beam_size, max_len + 2),
                    dtype=np.int64) * IGNORE_ID
                r_hyps_pad_sos_eos = np.ones(
                    (batch_size, beam_size, max_len + 2),
                    dtype=np.int64) * IGNORE_ID
                hyps_lens_sos = np.ones((batch_size, beam_size),
                                        dtype=np.int32)
                k = 0
                for i in range(batch_size):
                    for j in range(beam_size):
                        cand = all_hyps[k]
                        l = len(cand) + 2
                        hyps_pad_sos_eos[i][j][0:l] = [sos] + cand + [eos]
                        r_hyps_pad_sos_eos[i][j][0:l] = [sos] + cand[::-1] + [eos]
                        hyps_lens_sos[i][j] = len(cand) + 1
                        k += 1
                decoder_ort_inputs = {
                    decoder_ort_session.get_inputs()[0].name: encoder_out,
                    decoder_ort_session.get_inputs()[1].name: encoder_out_lens,
                    decoder_ort_session.get_inputs()[2].name: hyps_pad_sos_eos,
                    decoder_ort_session.get_inputs()[3].name: hyps_lens_sos,
                    decoder_ort_session.get_inputs()[-1].name: ctc_score
                }
                if reverse_weight > 0:
                    r_hyps_pad_sos_eos_name = \
                        decoder_ort_session.get_inputs()[4].name
                    decoder_ort_inputs[r_hyps_pad_sos_eos_name] = \
                        r_hyps_pad_sos_eos
                best_index = decoder_ort_session.run(None,
                                                     decoder_ort_inputs)[0]
                best_sents = []
                k = 0
                for idx in best_index:
                    cur_best_sent = all_hyps[k:k + beam_size][idx]
                    best_sents.append(cur_best_sent)
                    k += beam_size
                hyps = map_batch(best_sents, vocabulary, num_processes)

            for i, key in enumerate(keys):
                content = hyps[i]
                logging.info('{} {}'.format(key, content))
                fout.write('{} {}\n'.format(key, content))


if __name__ == '__main__':
    main()
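For a quick sanity check of an exported encoder outside this script, the session can be fed dummy features the same way `ort_inputs` is built above. This is only a sketch: the model path and feature shape are hypothetical, and the exact integer dtype of the length input depends on how export_onnx_gpu.py was run.

```python
import numpy as np
import onnxruntime as rt

# Hypothetical path/shape; real values come from the export step and the fbank config.
session = rt.InferenceSession('encoder.onnx', providers=['CPUExecutionProvider'])
feats = np.random.randn(1, 100, 80).astype(np.float32)   # (batch, frames, mel bins)
feats_lengths = np.array([100], dtype=np.int32)           # may need int64 for some exports

ort_inputs = {
    session.get_inputs()[0].name: feats,
    session.get_inputs()[1].name: feats_lengths,
}
# The same five outputs the script unpacks above.
encoder_out, encoder_out_lens, ctc_log_probs, beam_log_probs, beam_log_probs_idx = \
    session.run(None, ort_inputs)
print(encoder_out.shape, beam_log_probs.shape)
```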
examples/aishell/s0/wenet/bin/train.py
0 → 100644
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import copy
import logging
import os
import time

import torch
import torch.distributed as dist
import torch.optim as optim
import yaml
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import (load_checkpoint, save_checkpoint,
                                    load_trained_modules)
from wenet.utils.executor import Executor
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model
from wenet.utils.global_vars import get_global_steps, get_num_trained_samples
from wenet.utils.compute_acc import compute_char_acc


def write_pid_file(pid_file_path):
    '''Write a pid file so the process can be watched later.
    For each run of the case, the current pid is written to the same path.
    '''
    if os.path.exists(pid_file_path):
        os.remove(pid_file_path)
    file_d = open(pid_file_path, "w")
    file_d.write("%s\n" % os.getpid())
    file_d.close()


def get_args():
    parser = argparse.ArgumentParser(description='training your network')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--data_type',
                        default='raw',
                        choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this local rank, -1 for cpu')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--tensorboard_dir',
                        default='tensorboard',
                        help='tensorboard log dir')
    parser.add_argument('--ddp.rank',
                        dest='rank',
                        default=0,
                        type=int,
                        help='global rank for distributed training')
    parser.add_argument('--ddp.world_size',
                        dest='world_size',
                        default=-1,
                        type=int,
                        help='''number of total processes/gpus for
                        distributed training''')
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo'],
                        help='distributed backend')
    parser.add_argument('--ddp.init_method',
                        dest='init_method',
                        default=None,
                        help='ddp init method')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--pin_memory',
                        action='store_true',
                        default=False,
                        help='Use pinned memory buffers used for reading')
    parser.add_argument('--use_amp',
                        action='store_true',
                        default=False,
                        help='Use automatic mixed precision training')
    parser.add_argument('--fp16_grad_sync',
                        action='store_true',
                        default=False,
                        help='Use fp16 gradient sync for ddp')
    parser.add_argument('--cmvn', default=None, help='global cmvn file')
    parser.add_argument('--symbol_table',
                        required=True,
                        help='model unit symbol table for training')
    parser.add_argument("--non_lang_syms",
                        help="non-linguistic symbol file. One symbol per line.")
    parser.add_argument('--prefetch',
                        default=100,
                        type=int,
                        help='prefetch number')
    parser.add_argument('--bpe_model',
                        default=None,
                        type=str,
                        help='bpe model for english part')
    parser.add_argument('--override_config',
                        action='append',
                        default=[],
                        help="override yaml config")
    parser.add_argument("--enc_init",
                        default=None,
                        type=str,
                        help="Pre-trained model to initialize encoder")
    parser.add_argument("--enc_init_mods",
                        default="encoder.",
                        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
                        help="List of encoder modules \
                        to initialize ,separated by a comma")
    parser.add_argument('--val_ref_file',
                        dest='val_ref_file',
                        default='data/test/text',
                        help='validation ref file')
    parser.add_argument('--val_hyp_file',
                        dest='val_hyp_file',
                        default='exp/conformer/test_attention_rescoring/text',
                        help='validation hyp file')
    parser.add_argument('--log_dir',
                        type=str,
                        default='/data/flagperf/training/result/',
                        help='Log directory in container.')
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    if args.rank == 0:
        write_pid_file(args.log_dir)
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    # Set random seed
    torch.manual_seed(777)
    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    if len(args.override_config) > 0:
        configs = override_config(configs, args.override_config)

    distributed = args.world_size > 1
    if distributed:
        logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
        dist.init_process_group(args.dist_backend,
                                init_method=args.init_method,
                                world_size=args.world_size,
                                rank=args.rank)

    symbol_table = read_symbol_table(args.symbol_table)

    train_conf = configs['dataset_conf']
    cv_conf = copy.deepcopy(train_conf)
    cv_conf['speed_perturb'] = False
    cv_conf['spec_aug'] = False
    cv_conf['spec_sub'] = False
    cv_conf['spec_trim'] = False
    cv_conf['shuffle'] = False

    non_lang_syms = read_non_lang_symbols(args.non_lang_syms)

    train_dataset = Dataset(args.data_type, args.train_data, symbol_table,
                            train_conf, args.bpe_model, non_lang_syms, True)
    cv_dataset = Dataset(args.data_type,
                         args.cv_data,
                         symbol_table,
                         cv_conf,
                         args.bpe_model,
                         non_lang_syms,
                         partition=False)

    train_data_loader = DataLoader(train_dataset,
                                   batch_size=None,
                                   pin_memory=args.pin_memory,
                                   num_workers=args.num_workers,
                                   prefetch_factor=args.prefetch)
    cv_data_loader = DataLoader(cv_dataset,
                                batch_size=None,
                                pin_memory=args.pin_memory,
                                num_workers=args.num_workers,
                                prefetch_factor=args.prefetch)

    if 'fbank_conf' in configs['dataset_conf']:
        input_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
    else:
        input_dim = configs['dataset_conf']['mfcc_conf']['num_mel_bins']
    vocab_size = len(symbol_table)

    # Save configs to model_dir/train.yaml for inference and export
    configs['input_dim'] = input_dim
    configs['output_dim'] = vocab_size
    configs['cmvn_file'] = args.cmvn
    configs['is_json_cmvn'] = True
    if args.rank == 0:
        saved_config_path = os.path.join(args.model_dir, 'train.yaml')
        with open(saved_config_path, 'w') as fout:
            data = yaml.dump(configs)
            fout.write(data)

    # Init asr model from configs
    model = init_model(configs)
    print(model)
    num_params = sum(p.numel() for p in model.parameters())
    print('the number of model params: {:,d}'.format(num_params))

    # !!!IMPORTANT!!!
    # Try to export the model by script, if fails, we should refine
    # the code to satisfy the script export requirements
    if args.rank == 0:
        script_model = torch.jit.script(model)
        script_model.save(os.path.join(args.model_dir, 'init.zip'))
    executor = Executor()
    # If specify checkpoint, load some info from checkpoint
    if args.checkpoint is not None:
        infos = load_checkpoint(model, args.checkpoint)
    elif args.enc_init is not None:
        logging.info('load pretrained encoders: {}'.format(args.enc_init))
        infos = load_trained_modules(model, args)
    else:
        infos = {}
    start_epoch = infos.get('epoch', -1) + 1
    cv_loss = infos.get('cv_loss', 0.0)
    step = infos.get('step', -1)

    num_epochs = configs.get('max_epoch', 100)
    model_dir = args.model_dir
    writer = None
    if args.rank == 0:
        os.makedirs(model_dir, exist_ok=True)
        exp_id = os.path.basename(model_dir)
        # writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))

    if distributed:
        assert (torch.cuda.is_available())
        # cuda model is required for nn.parallel.DistributedDataParallel
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=False)
        device = torch.device("cuda")
        if args.fp16_grad_sync:
            from torch.distributed.algorithms.ddp_comm_hooks import (
                default as comm_hooks,
            )
            model.register_comm_hook(
                state=None, hook=comm_hooks.fp16_compress_hook
            )
    else:
        use_cuda = args.gpu >= 0 and torch.cuda.is_available()
        device = torch.device('cuda' if use_cuda else 'cpu')
        model = model.to(device)

    if configs['optim'] == 'adam':
        optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
    elif configs['optim'] == 'adamw':
        optimizer = optim.AdamW(model.parameters(), **configs['optim_conf'])
    else:
        raise ValueError("unknown optimizer: " + configs['optim'])
    if configs['scheduler'] == 'warmuplr':
        scheduler = WarmupLR(optimizer, **configs['scheduler_conf'])
    elif configs['scheduler'] == 'NoamHoldAnnealing':
        scheduler = NoamHoldAnnealing(optimizer, **configs['scheduler_conf'])
    else:
        raise ValueError("unknown scheduler: " + configs['scheduler'])

    final_epoch = None
    target_acc = 93.0
    final_acc = 0
    training_only = 0
    configs['rank'] = args.rank
    configs['is_distributed'] = distributed
    configs['use_amp'] = args.use_amp
    if start_epoch == 0 and args.rank == 0:
        save_model_path = os.path.join(model_dir, 'init.pt')
        save_checkpoint(model, save_model_path)

    # Start training loop
    executor.step = step
    scheduler.set_step(step)
    # used for pytorch amp mixed precision training
    scaler = None
    if args.use_amp:
        scaler = torch.cuda.amp.GradScaler()

    training_start = time.time()
    for epoch in range(start_epoch, num_epochs):
        start = time.time()
        train_dataset.set_epoch(epoch)
        configs['epoch'] = epoch
        lr = optimizer.param_groups[0]['lr']
        logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
        executor.train(model, optimizer, scheduler, train_data_loader, device,
                       writer, configs, scaler)
        total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device,
                                                configs)
        cv_loss = total_loss / num_seen_utts
        epoch_time = time.time() - start
        training_only += epoch_time

        dist.barrier()
        # logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
        if args.rank == 0:
            save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
            save_checkpoint(
                model, save_model_path, {
                    'epoch': epoch,
                    'lr': lr,
                    'cv_loss': cv_loss,
                    'step': executor.step
                })
            # writer.add_scalar('epoch/cv_loss', cv_loss, epoch)
            # writer.add_scalar('epoch/lr', lr, epoch)
        final_epoch = epoch

        char_acc = 0.0
        # Run validation by calling run.sh stage=5
        # Only run in rank 0
        if args.rank == 0:
            start = time.time()
            if final_epoch is not None:
                final_model_path = os.path.join(model_dir, 'final.pt')
                os.remove(final_model_path) if os.path.exists(
                    final_model_path) else None
                os.symlink('{}.pt'.format(final_epoch), final_model_path)
            val_cmd = os.path.join(os.getcwd(), "validate.sh")
            logging.info(f'rank {args.rank}: ' + "Start validation")
            os.system(val_cmd)
            time.sleep(0.5)
            char_acc = compute_char_acc(args)
            logging.info(f'rank {args.rank}: ' + "Finish validation")
            eval_time = time.time() - start
            global_steps = get_global_steps()
            eval_output = f'[PerfLog] {{"event": "EVALUATE_END", "value": {{"global_steps": {global_steps},"eval_mlm_accuracy": {char_acc:.4f},"eval_time": {eval_time:.2f},"epoch_time": {epoch_time:.2f}}}}}'
            logging.info(f'rank {args.rank}: ' + eval_output)
        dist.barrier()
        torch.cuda.synchronize()
        t = torch.tensor([char_acc], device='cuda')
        dist.broadcast(t, 0)
        char_acc = t[0].item()
        if char_acc >= target_acc:
            final_acc = char_acc
            break

    train_time = time.time() - training_start
    num_trained_samples = get_num_trained_samples()
    samples_sec = num_trained_samples / training_only
    train_output = f'[PerfLog] {{"event": "TRAIN_END", "value": {{"accuracy": {final_acc:.4f},"train_time": {train_time:.2f},"samples/sec": {samples_sec:.2f},"num_trained_samples": {num_trained_samples}}}}}'
    logging.info(f'rank {args.rank}: ' + train_output)


if __name__ == '__main__':
    main()
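Executor.train is not part of this commit, so the role of the `scaler` created in main() is easiest to see as the standard torch.cuda.amp pattern. The snippet below is a generic sketch, not the actual Executor code; `model`, `optimizer`, `loss_fn`, and `batch` are placeholders.

```python
import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, optimizer, loss_fn, batch):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_fn(model, batch)   # forward runs in reduced precision where safe
    scaler.scale(loss).backward()      # scale the loss so fp16 grads don't underflow
    scaler.step(optimizer)             # unscale the grads, then optimizer.step()
    scaler.update()                    # adapt the loss scale for the next step
```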
examples/aishell/s0/wenet/dataset/__pycache__/dataset.cpython-38.pyc
0 → 100644
File added
examples/aishell/s0/wenet/dataset/__pycache__/processor.cpython-38.pyc
0 → 100644
File added
examples/aishell/s0/wenet/dataset/dataset.py
0 → 100644
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset

import wenet.dataset.processor as processor
from wenet.utils.file_utils import read_lists


class Processor(IterableDataset):
    def __init__(self, source, f, *args, **kw):
        assert callable(f)
        self.source = source
        self.f = f
        self.args = args
        self.kw = kw

    def set_epoch(self, epoch):
        self.source.set_epoch(epoch)

    def __iter__(self):
        """ Return an iterator over the source dataset processed by the
            given processor.
        """
        assert self.source is not None
        assert callable(self.f)
        return self.f(iter(self.source), *self.args, **self.kw)

    def apply(self, f):
        assert callable(f)
        return Processor(self, f, *self.args, **self.kw)


class DistributedSampler:
    def __init__(self, shuffle=True, partition=True):
        self.epoch = -1
        self.update()
        self.shuffle = shuffle
        self.partition = partition

    def update(self):
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return dict(rank=self.rank,
                    world_size=self.world_size,
                    worker_id=self.worker_id,
                    num_workers=self.num_workers)

    def set_epoch(self, epoch):
        self.epoch = epoch

    def sample(self, data):
        """ Sample data according to rank/world_size/num_workers

            Args:
                data(List): input data list

            Returns:
                List: data list after sample
        """
        data = list(range(len(data)))
        # TODO(Binbin Zhang): fix this
        # We can not handle uneven data for CV on DDP, so we don't
        # sample data by rank, that means every GPU gets the same
        # and all the CV data
        if self.partition:
            if self.shuffle:
                random.Random(self.epoch).shuffle(data)
            data = data[self.rank::self.world_size]
        data = data[self.worker_id::self.num_workers]
        return data


class DataList(IterableDataset):
    def __init__(self, lists, shuffle=True, partition=True):
        self.lists = lists
        self.sampler = DistributedSampler(shuffle, partition)

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def __iter__(self):
        sampler_info = self.sampler.update()
        indexes = self.sampler.sample(self.lists)
        for index in indexes:
            # yield dict(src=src)
            data = dict(src=self.lists[index])
            data.update(sampler_info)
            yield data


def Dataset(data_type,
            data_list_file,
            symbol_table,
            conf,
            bpe_model=None,
            non_lang_syms=None,
            partition=True):
    """ Construct dataset from arguments

        We have two shuffle stage in the Dataset. The first is global
        shuffle at shards tar/raw file level. The second is global shuffle
        at training samples level.

        Args:
            data_type(str): raw/shard
            bpe_model(str): model for english bpe part
            partition(bool): whether to do data partition in terms of rank
    """
    assert data_type in ['raw', 'shard']
    lists = read_lists(data_list_file)
    shuffle = conf.get('shuffle', True)
    dataset = DataList(lists, shuffle=shuffle, partition=partition)
    if data_type == 'shard':
        dataset = Processor(dataset, processor.url_opener)
        dataset = Processor(dataset, processor.tar_file_and_group)
    else:
        dataset = Processor(dataset, processor.parse_raw)

    dataset = Processor(dataset, processor.tokenize, symbol_table, bpe_model,
                        non_lang_syms, conf.get('split_with_space', False))
    filter_conf = conf.get('filter_conf', {})
    dataset = Processor(dataset, processor.filter, **filter_conf)

    resample_conf = conf.get('resample_conf', {})
    dataset = Processor(dataset, processor.resample, **resample_conf)

    speed_perturb = conf.get('speed_perturb', False)
    if speed_perturb:
        dataset = Processor(dataset, processor.speed_perturb)

    feats_type = conf.get('feats_type', 'fbank')
    assert feats_type in ['fbank', 'mfcc']
    if feats_type == 'fbank':
        fbank_conf = conf.get('fbank_conf', {})
        dataset = Processor(dataset, processor.compute_fbank, **fbank_conf)
    elif feats_type == 'mfcc':
        mfcc_conf = conf.get('mfcc_conf', {})
        dataset = Processor(dataset, processor.compute_mfcc, **mfcc_conf)

    spec_aug = conf.get('spec_aug', True)
    spec_sub = conf.get('spec_sub', False)
    spec_trim = conf.get('spec_trim', False)
    if spec_aug:
        spec_aug_conf = conf.get('spec_aug_conf', {})
        dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)
    if spec_sub:
        spec_sub_conf = conf.get('spec_sub_conf', {})
        dataset = Processor(dataset, processor.spec_sub, **spec_sub_conf)
    if spec_trim:
        spec_trim_conf = conf.get('spec_trim_conf', {})
        dataset = Processor(dataset, processor.spec_trim, **spec_trim_conf)

    if shuffle:
        shuffle_conf = conf.get('shuffle_conf', {})
        dataset = Processor(dataset, processor.shuffle, **shuffle_conf)

    sort = conf.get('sort', True)
    if sort:
        sort_conf = conf.get('sort_conf', {})
        dataset = Processor(dataset, processor.sort, **sort_conf)
    batch_conf = conf.get('batch_conf', {})
    dataset = Processor(dataset, processor.batch, **batch_conf)
    dataset = Processor(dataset, processor.padding)

    return dataset
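The whole pipeline built by `Dataset()` is just nested `Processor` objects, each wrapping a generator function. A toy composition, using the `Processor` class defined above with made-up stage functions in place of the real ones in wenet.dataset.processor:

```python
# Each stage is a generator that consumes the previous stage's iterator.
def double(data):
    for x in data:
        yield x * 2

def keep_small(data, limit=10):
    for x in data:
        if x < limit:
            yield x

pipeline = Processor([1, 3, 7], double)              # yields 2, 6, 14
pipeline = Processor(pipeline, keep_small, limit=10) # keeps 2, 6
print(list(pipeline))                                # -> [2, 6]
```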
examples/aishell/s0/wenet/dataset/kaldi_io.py
0 → 100644
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2014-2016 Brno University of Technology (author: Karel Vesely)
# Licensed under the Apache License, Version 2.0 (the "License")
import
numpy
as
np
import
sys
,
os
,
re
,
gzip
,
struct
#################################################
# Adding kaldi tools to shell path,
# Select kaldi,
if
not
'KALDI_ROOT'
in
os
.
environ
:
# Default! To change run python with 'export KALDI_ROOT=/some_dir python'
os
.
environ
[
'KALDI_ROOT'
]
=
'/mnt/matylda5/iveselyk/Tools/kaldi-trunk'
# Add kaldi tools to path,
os
.
environ
[
'PATH'
]
=
os
.
popen
(
'echo $KALDI_ROOT/src/bin:$KALDI_ROOT/tools/openfst/bin:$KALDI_ROOT/src/fstbin/:$KALDI_ROOT/src/gmmbin/:$KALDI_ROOT/src/featbin/:$KALDI_ROOT/src/lm/:$KALDI_ROOT/src/sgmmbin/:$KALDI_ROOT/src/sgmm2bin/:$KALDI_ROOT/src/fgmmbin/:$KALDI_ROOT/src/latbin/:$KALDI_ROOT/src/nnetbin:$KALDI_ROOT/src/nnet2bin:$KALDI_ROOT/src/nnet3bin:$KALDI_ROOT/src/online2bin/:$KALDI_ROOT/src/ivectorbin/:$KALDI_ROOT/src/lmbin/'
).
readline
().
strip
()
+
':'
+
os
.
environ
[
'PATH'
]
#################################################
# Define all custom exceptions,
class
UnsupportedDataType
(
Exception
):
pass
class
UnknownVectorHeader
(
Exception
):
pass
class
UnknownMatrixHeader
(
Exception
):
pass
class
BadSampleSize
(
Exception
):
pass
class
BadInputFormat
(
Exception
):
pass
class
SubprocessFailed
(
Exception
):
pass
#################################################
# Data-type independent helper functions,
def
open_or_fd
(
file
,
mode
=
'rb'
):
""" fd = open_or_fd(file)
Open file, gzipped file, pipe, or forward the file-descriptor.
Eventually seeks in the 'file' argument contains ':offset' suffix.
"""
offset
=
None
try
:
# strip 'ark:' prefix from r{x,w}filename (optional),
if
re
.
search
(
'^(ark|scp)(,scp|,b|,t|,n?f|,n?p|,b?o|,n?s|,n?cs)*:'
,
file
):
(
prefix
,
file
)
=
file
.
split
(
':'
,
1
)
# separate offset from filename (optional),
if
re
.
search
(
':[0-9]+$'
,
file
):
(
file
,
offset
)
=
file
.
rsplit
(
':'
,
1
)
# input pipe?
if
file
[
-
1
]
==
'|'
:
fd
=
popen
(
file
[:
-
1
],
'rb'
)
# custom,
# output pipe?
elif
file
[
0
]
==
'|'
:
fd
=
popen
(
file
[
1
:],
'wb'
)
# custom,
# is it gzipped?
elif
file
.
split
(
'.'
)[
-
1
]
==
'gz'
:
fd
=
gzip
.
open
(
file
,
mode
)
# a normal file...
else
:
fd
=
open
(
file
,
mode
)
except
TypeError
:
# 'file' is opened file descriptor,
fd
=
file
# Eventually seek to offset,
if
offset
!=
None
:
fd
.
seek
(
int
(
offset
))
return
fd
# based on '/usr/local/lib/python3.4/os.py'
def
popen
(
cmd
,
mode
=
"rb"
):
if
not
isinstance
(
cmd
,
str
):
raise
TypeError
(
"invalid cmd type (%s, expected string)"
%
type
(
cmd
))
import
subprocess
,
io
,
threading
# cleanup function for subprocesses,
def
cleanup
(
proc
,
cmd
):
ret
=
proc
.
wait
()
if
ret
>
0
:
raise
SubprocessFailed
(
'cmd %s returned %d !'
%
(
cmd
,
ret
))
return
# text-mode,
if
mode
==
"r"
:
proc
=
subprocess
.
Popen
(
cmd
,
shell
=
True
,
stdout
=
subprocess
.
PIPE
)
threading
.
Thread
(
target
=
cleanup
,
args
=
(
proc
,
cmd
)).
start
()
# clean-up thread,
return
io
.
TextIOWrapper
(
proc
.
stdout
)
elif
mode
==
"w"
:
proc
=
subprocess
.
Popen
(
cmd
,
shell
=
True
,
stdin
=
subprocess
.
PIPE
)
threading
.
Thread
(
target
=
cleanup
,
args
=
(
proc
,
cmd
)).
start
()
# clean-up thread,
return
io
.
TextIOWrapper
(
proc
.
stdin
)
# binary,
elif
mode
==
"rb"
:
proc
=
subprocess
.
Popen
(
cmd
,
shell
=
True
,
stdout
=
subprocess
.
PIPE
)
threading
.
Thread
(
target
=
cleanup
,
args
=
(
proc
,
cmd
)).
start
()
# clean-up thread,
return
proc
.
stdout
elif
mode
==
"wb"
:
proc
=
subprocess
.
Popen
(
cmd
,
shell
=
True
,
stdin
=
subprocess
.
PIPE
)
threading
.
Thread
(
target
=
cleanup
,
args
=
(
proc
,
cmd
)).
start
()
# clean-up thread,
return
proc
.
stdin
# sanity,
else
:
raise
ValueError
(
"invalid mode %s"
%
mode
)
def
read_key
(
fd
):
""" [key] = read_key(fd)
Read the utterance-key from the opened ark/stream descriptor 'fd'.
"""
key
=
''
while
1
:
char
=
fd
.
read
(
1
).
decode
(
"latin1"
)
if
char
==
''
:
break
if
char
==
' '
:
break
key
+=
char
key
=
key
.
strip
()
if
key
==
''
:
return
None
# end of file,
assert
(
re
.
match
(
'^\S+$'
,
key
)
!=
None
)
# check format (no whitespace!)
return
key
#################################################
# Integer vectors (alignments, ...),
def
read_ali_ark
(
file_or_fd
):
""" Alias to 'read_vec_int_ark()' """
return
read_vec_int_ark
(
file_or_fd
)
def
read_vec_int_ark
(
file_or_fd
):
""" generator(key,vec) = read_vec_int_ark(file_or_fd)
Create generator of (key,vector<int>) tuples, which reads from the ark file/stream.
file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
Read ark to a 'dictionary':
d = { u:d for u,d in kaldi_io.read_vec_int_ark(file) }
"""
fd
=
open_or_fd
(
file_or_fd
)
try
:
key
=
read_key
(
fd
)
while
key
:
ali
=
read_vec_int
(
fd
)
yield
key
,
ali
key
=
read_key
(
fd
)
finally
:
if
fd
is
not
file_or_fd
:
fd
.
close
()
def
read_vec_int_scp
(
file_or_fd
):
""" generator(key,vec) = read_vec_int_scp(file_or_fd)
Returns generator of (key,vector<int>) tuples, read according to kaldi scp.
file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
Iterate the scp:
for key,vec in kaldi_io.read_vec_int_scp(file):
...
Read scp to a 'dictionary':
d = { key:vec for key,mat in kaldi_io.read_vec_int_scp(file) }
"""
fd
=
open_or_fd
(
file_or_fd
)
try
:
for
line
in
fd
:
(
key
,
rxfile
)
=
line
.
decode
().
split
(
' '
)
vec
=
read_vec_int
(
rxfile
)
yield
key
,
vec
finally
:
if
fd
is
not
file_or_fd
:
fd
.
close
()
def
read_vec_int
(
file_or_fd
):
""" [int-vec] = read_vec_int(file_or_fd)
Read kaldi integer vector, ascii or binary input,
"""
fd
=
open_or_fd
(
file_or_fd
)
binary
=
fd
.
read
(
2
).
decode
()
if
binary
==
'
\0
B'
:
# binary flag
assert
(
fd
.
read
(
1
).
decode
()
==
'
\4
'
);
# int-size
vec_size
=
np
.
frombuffer
(
fd
.
read
(
4
),
dtype
=
'int32'
,
count
=
1
)[
0
]
# vector dim
# Elements from int32 vector are sored in tuples: (sizeof(int32), value),
vec
=
np
.
frombuffer
(
fd
.
read
(
vec_size
*
5
),
dtype
=
[(
'size'
,
'int8'
),(
'value'
,
'int32'
)],
count
=
vec_size
)
assert
(
vec
[
0
][
'size'
]
==
4
)
# int32 size,
ans
=
vec
[:][
'value'
]
# values are in 2nd column,
else
:
# ascii,
arr
=
(
binary
+
fd
.
readline
().
decode
()).
strip
().
split
()
try
:
arr
.
remove
(
'['
);
arr
.
remove
(
']'
)
# optionally
except
ValueError
:
pass
ans
=
np
.
array
(
arr
,
dtype
=
int
)
if
fd
is
not
file_or_fd
:
fd
.
close
()
# cleanup
return
ans
# Writing,
def
write_vec_int
(
file_or_fd
,
v
,
key
=
''
):
""" write_vec_int(f, v, key='')
Write a binary kaldi integer vector to filename or stream.
Arguments:
file_or_fd : filename or opened file descriptor for writing,
v : the vector to be stored,
key (optional) : used for writing ark-file, the utterance-id gets written before the vector.
Example of writing single vector:
kaldi_io.write_vec_int(filename, vec)
Example of writing arkfile:
with open(ark_file,'w') as f:
for key,vec in dict.iteritems():
kaldi_io.write_vec_flt(f, vec, key=key)
"""
fd
=
open_or_fd
(
file_or_fd
,
mode
=
'wb'
)
if
sys
.
version_info
[
0
]
==
3
:
assert
(
fd
.
mode
==
'wb'
)
try
:
if
key
!=
''
:
fd
.
write
((
key
+
' '
).
encode
(
"latin1"
))
# ark-files have keys (utterance-id),
fd
.
write
(
'
\0
B'
.
encode
())
# we write binary!
# dim,
fd
.
write
(
'
\4
'
.
encode
())
# int32 type,
fd
.
write
(
struct
.
pack
(
np
.
dtype
(
'int32'
).
char
,
v
.
shape
[
0
]))
# data,
for
i
in
range
(
len
(
v
)):
fd
.
write
(
'
\4
'
.
encode
())
# int32 type,
fd
.
write
(
struct
.
pack
(
np
.
dtype
(
'int32'
).
char
,
v
[
i
]))
# binary,
finally
:
if
fd
is
not
file_or_fd
:
fd
.
close
()
#################################################
# Float vectors (confidences, ivectors, ...),
# Reading,
def
read_vec_flt_scp
(
file_or_fd
):
""" generator(key,mat) = read_vec_flt_scp(file_or_fd)
Returns generator of (key,vector) tuples, read according to kaldi scp.
file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
Iterate the scp:
for key,vec in kaldi_io.read_vec_flt_scp(file):
...
Read scp to a 'dictionary':
d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) }
"""
fd
=
open_or_fd
(
file_or_fd
)
try
:
for
line
in
fd
:
(
key
,
rxfile
)
=
line
.
decode
().
split
(
' '
)
vec
=
read_vec_flt
(
rxfile
)
yield
key
,
vec
finally
:
if
fd
is
not
file_or_fd
:
fd
.
close
()
def
read_vec_flt_ark
(
file_or_fd
):
""" generator(key,vec) = read_vec_flt_ark(file_or_fd)
Create generator of (key,vector<float>) tuples, reading from an ark file/stream.
file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
Read ark to a 'dictionary':
d = { u:d for u,d in kaldi_io.read_vec_flt_ark(file) }
"""
fd
=
open_or_fd
(
file_or_fd
)
try
:
key
=
read_key
(
fd
)
while
key
:
ali
=
read_vec_flt
(
fd
)
yield
key
,
ali
key
=
read_key
(
fd
)
finally
:
if
fd
is
not
file_or_fd
:
fd
.
close
()
def
read_vec_flt
(
file_or_fd
):
""" [flt-vec] = read_vec_flt(file_or_fd)
Read kaldi float vector, ascii or binary input,
"""
fd
=
open_or_fd
(
file_or_fd
)
binary
=
fd
.
read
(
2
).
decode
()
if
binary
==
'
\0
B'
:
# binary flag
# Data type,
header
=
fd
.
read
(
3
).
decode
()
if
header
==
'FV '
:
sample_size
=
4
# floats
elif
header
==
'DV '
:
sample_size
=
8
# doubles
else
:
raise
UnknownVectorHeader
(
"The header contained '%s'"
%
header
)
assert
(
sample_size
>
0
)
# Dimension,
assert
(
fd
.
read
(
1
).
decode
()
==
'
\4
'
);
# int-size
vec_size
=
np
.
frombuffer
(
fd
.
read
(
4
),
dtype
=
'int32'
,
count
=
1
)[
0
]
# vector dim
# Read whole vector,
buf
=
fd
.
read
(
vec_size
*
sample_size
)
if
sample_size
==
4
:
ans
=
np
.
frombuffer
(
buf
,
dtype
=
'float32'
)
elif
sample_size
==
8
:
ans
=
np
.
frombuffer
(
buf
,
dtype
=
'float64'
)
else
:
raise
BadSampleSize
return
ans
else
:
# ascii,
arr
=
(
binary
+
fd
.
readline
().
decode
()).
strip
().
split
()
try
:
arr
.
remove
(
'['
);
arr
.
remove
(
']'
)
# optionally
except
ValueError
:
pass
ans
=
np
.
array
(
arr
,
dtype
=
float
)
if
fd
is
not
file_or_fd
:
fd
.
close
()
# cleanup
return
ans
# Writing,
def
write_vec_flt
(
file_or_fd
,
v
,
key
=
''
):
""" write_vec_flt(f, v, key='')
Write a binary kaldi vector to filename or stream. Supports 32bit and 64bit floats.
Arguments:
file_or_fd : filename or opened file descriptor for writing,
v : the vector to be stored,
key (optional) : used for writing ark-file, the utterance-id gets written before the vector.
Example of writing single vector:
kaldi_io.write_vec_flt(filename, vec)
Example of writing arkfile:
with open(ark_file,'w') as f:
for key,vec in dict.iteritems():
kaldi_io.write_vec_flt(f, vec, key=key)
"""
fd
=
open_or_fd
(
file_or_fd
,
mode
=
'wb'
)
if
sys
.
version_info
[
0
]
==
3
:
assert
(
fd
.
mode
==
'wb'
)
try
:
if
key
!=
''
:
fd
.
write
((
key
+
' '
).
encode
(
"latin1"
))
# ark-files have keys (utterance-id),
fd
.
write
(
'
\0
B'
.
encode
())
# we write binary!
# Data-type,
if
v
.
dtype
==
'float32'
:
fd
.
write
(
'FV '
.
encode
())
elif
v
.
dtype
==
'float64'
:
fd
.
write
(
'DV '
.
encode
())
else
:
raise
UnsupportedDataType
(
"'%s', please use 'float32' or 'float64'"
%
v
.
dtype
)
# Dim,
fd
.
write
(
'
\04
'
.
encode
())
fd
.
write
(
struct
.
pack
(
np
.
dtype
(
'uint32'
).
char
,
v
.
shape
[
0
]))
# dim
# Data,
fd
.
write
(
v
.
tobytes
())
finally
:
if
fd
is
not
file_or_fd
:
fd
.
close
()
#################################################
# Float matrices (features, transformations, ...),
# Reading,
def
read_mat_scp
(
file_or_fd
):
""" generator(key,mat) = read_mat_scp(file_or_fd)
Returns generator of (key,matrix) tuples, read according to kaldi scp.
file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
Iterate the scp:
for key,mat in kaldi_io.read_mat_scp(file):
...
Read scp to a 'dictionary':
d = { key:mat for key,mat in kaldi_io.read_mat_scp(file) }
"""
fd
=
open_or_fd
(
file_or_fd
)
try
:
for
line
in
fd
:
(
key
,
rxfile
)
=
line
.
decode
().
split
(
' '
)
mat
=
read_mat
(
rxfile
)
yield
key
,
mat
finally
:
if
fd
is
not
file_or_fd
:
fd
.
close
()
def
read_mat_ark
(
file_or_fd
):
""" generator(key,mat) = read_mat_ark(file_or_fd)
Returns generator of (key,matrix) tuples, read from ark file/stream.
file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
Iterate the ark:
for key,mat in kaldi_io.read_mat_ark(file):
...
Read ark to a 'dictionary':
d = { key:mat for key,mat in kaldi_io.read_mat_ark(file) }
"""
fd
=
open_or_fd
(
file_or_fd
)
try
:
key
=
read_key
(
fd
)
while
key
:
mat
=
read_mat
(
fd
)
yield
key
,
mat
key
=
read_key
(
fd
)
finally
:
if
fd
is
not
file_or_fd
:
fd
.
close
()
def
read_mat
(
file_or_fd
):
""" [mat] = read_mat(file_or_fd)
Reads single kaldi matrix, supports ascii and binary.
file_or_fd : file, gzipped file, pipe or opened file descriptor.
"""
fd
=
open_or_fd
(
file_or_fd
)
try
:
binary
=
fd
.
read
(
2
).
decode
()
if
binary
==
'
\0
B'
:
mat
=
_read_mat_binary
(
fd
)
else
:
assert
(
binary
==
' ['
)
mat
=
_read_mat_ascii
(
fd
)
finally
:
if
fd
is
not
file_or_fd
:
fd
.
close
()
return
mat
def
_read_mat_binary
(
fd
):
# Data type
header
=
fd
.
read
(
3
).
decode
()
# 'CM', 'CM2', 'CM3' are possible values,
if
header
.
startswith
(
'CM'
):
return
_read_compressed_mat
(
fd
,
header
)
elif
header
==
'FM '
:
sample_size
=
4
# floats
elif
header
==
'DM '
:
sample_size
=
8
# doubles
else
:
raise
UnknownMatrixHeader
(
"The header contained '%s'"
%
header
)
assert
(
sample_size
>
0
)
# Dimensions
s1
,
rows
,
s2
,
cols
=
np
.
frombuffer
(
fd
.
read
(
10
),
dtype
=
'int8,int32,int8,int32'
,
count
=
1
)[
0
]
# Read whole matrix
buf
=
fd
.
read
(
rows
*
cols
*
sample_size
)
if
sample_size
==
4
:
vec
=
np
.
frombuffer
(
buf
,
dtype
=
'float32'
)
elif
sample_size
==
8
:
vec
=
np
.
frombuffer
(
buf
,
dtype
=
'float64'
)
else
:
raise
BadSampleSize
mat
=
np
.
reshape
(
vec
,(
rows
,
cols
))
return
mat
def
_read_mat_ascii
(
fd
):
rows
=
[]
while
1
:
line
=
fd
.
readline
().
decode
()
if
(
len
(
line
)
==
0
)
:
raise
BadInputFormat
# eof, should not happen!
if
len
(
line
.
strip
())
==
0
:
continue
# skip empty line
arr
=
line
.
strip
().
split
()
if
arr
[
-
1
]
!=
']'
:
rows
.
append
(
np
.
array
(
arr
,
dtype
=
'float32'
))
# not last line
else
:
rows
.
append
(
np
.
array
(
arr
[:
-
1
],
dtype
=
'float32'
))
# last line
mat
=
np
.
vstack
(
rows
)
return
mat
def
_read_compressed_mat
(
fd
,
format
):
""" Read a compressed matrix,
see: https://github.com/kaldi-asr/kaldi/blob/master/src/matrix/compressed-matrix.h
methods: CompressedMatrix::Read(...), CompressedMatrix::CopyToMat(...),
"""
assert
(
format
==
'CM '
)
# The formats CM2, CM3 are not supported...
# Format of header 'struct',
global_header
=
np
.
dtype
([(
'minvalue'
,
'float32'
),(
'range'
,
'float32'
),(
'num_rows'
,
'int32'
),(
'num_cols'
,
'int32'
)])
# member '.format' is not written,
per_col_header
=
np
.
dtype
([(
'percentile_0'
,
'uint16'
),(
'percentile_25'
,
'uint16'
),(
'percentile_75'
,
'uint16'
),(
'percentile_100'
,
'uint16'
)])
# Mapping for percentiles in col-headers,
def
uint16_to_float
(
value
,
min
,
range
):
return
np
.
float32
(
min
+
range
*
1.52590218966964e-05
*
value
)
# Mapping for matrix elements,
def
uint8_to_float_v2
(
vec
,
p0
,
p25
,
p75
,
p100
):
# Split the vector by masks,
mask_0_64
=
(
vec
<=
64
);
mask_193_255
=
(
vec
>
192
);
mask_65_192
=
(
~
(
mask_0_64
|
mask_193_255
));
# Sanity check (useful but slow...),
# assert(len(vec) == np.sum(np.hstack([mask_0_64,mask_65_192,mask_193_255])))
# assert(len(vec) == np.sum(np.any([mask_0_64,mask_65_192,mask_193_255], axis=0)))
# Build the float vector,
ans
=
np
.
empty
(
len
(
vec
),
dtype
=
'float32'
)
ans
[
mask_0_64
]
=
p0
+
(
p25
-
p0
)
/
64.
*
vec
[
mask_0_64
]
ans
[
mask_65_192
]
=
p25
+
(
p75
-
p25
)
/
128.
*
(
vec
[
mask_65_192
]
-
64
)
ans
[
mask_193_255
]
=
p75
+
(
p100
-
p75
)
/
63.
*
(
vec
[
mask_193_255
]
-
192
)
return
ans
# Read global header,
globmin
,
globrange
,
rows
,
cols
=
np
.
frombuffer
(
fd
.
read
(
16
),
dtype
=
global_header
,
count
=
1
)[
0
]
# The data is structed as [Colheader, ... , Colheader, Data, Data , .... ]
# { cols }{ size }
col_headers
=
np
.
frombuffer
(
fd
.
read
(
cols
*
8
),
dtype
=
per_col_header
,
count
=
cols
)
data
=
np
.
reshape
(
np
.
frombuffer
(
fd
.
read
(
cols
*
rows
),
dtype
=
'uint8'
,
count
=
cols
*
rows
),
newshape
=
(
cols
,
rows
))
# stored as col-major,
mat
=
np
.
empty
((
cols
,
rows
),
dtype
=
'float32'
)
for
i
,
col_header
in
enumerate
(
col_headers
):
col_header_flt
=
[
uint16_to_float
(
percentile
,
globmin
,
globrange
)
for
percentile
in
col_header
]
mat
[
i
]
=
uint8_to_float_v2
(
data
[
i
],
*
col_header_flt
)
return
mat
.
T
# transpose! col-major -> row-major,
def
write_ark_scp
(
key
,
mat
,
ark_fout
,
scp_out
):
mat_offset
=
write_mat
(
ark_fout
,
mat
,
key
)
scp_line
=
'{}
\t
{}:{}'
.
format
(
key
,
ark_fout
.
name
,
mat_offset
)
scp_out
.
write
(
scp_line
)
scp_out
.
write
(
'
\n
'
)
# Writing,
def
write_mat
(
file_or_fd
,
m
,
key
=
''
):
""" write_mat(f, m, key='')
Write a binary kaldi matrix to filename or stream. Supports 32bit and 64bit floats.
Arguments:
file_or_fd : filename of opened file descriptor for writing,
m : the matrix to be stored,
key (optional) : used for writing ark-file, the utterance-id gets written before the matrix.
Example of writing single matrix:
kaldi_io.write_mat(filename, mat)
Example of writing arkfile:
with open(ark_file,'w') as f:
for key,mat in dict.iteritems():
kaldi_io.write_mat(f, mat, key=key)
"""
mat_offset
=
0
fd
=
open_or_fd
(
file_or_fd
,
mode
=
'wb'
)
if
sys
.
version_info
[
0
]
==
3
:
assert
(
fd
.
mode
==
'wb'
)
try
:
if
key
!=
''
:
fd
.
write
((
key
+
' '
).
encode
(
"latin1"
))
# ark-files have keys (utterance-id),
mat_offset
=
fd
.
tell
()
fd
.
write
(
'
\0
B'
.
encode
())
# we write binary!
# Data-type,
if
m
.
dtype
==
'float32'
:
fd
.
write
(
'FM '
.
encode
())
elif
m
.
dtype
==
'float64'
:
fd
.
write
(
'DM '
.
encode
())
else
:
raise
UnsupportedDataType
(
"'%s', please use 'float32' or 'float64'"
%
m
.
dtype
)
# Dims,
fd
.
write
(
'
\04
'
.
encode
())
fd
.
write
(
struct
.
pack
(
np
.
dtype
(
'uint32'
).
char
,
m
.
shape
[
0
]))
# rows
fd
.
write
(
'
\04
'
.
encode
())
fd
.
write
(
struct
.
pack
(
np
.
dtype
(
'uint32'
).
char
,
m
.
shape
[
1
]))
# cols
# Data,
fd
.
write
(
m
.
tobytes
())
finally
:
if
fd
is
not
file_or_fd
:
fd
.
close
()
return
mat_offset
#################################################
# 'Posterior' kaldi type (posteriors, confusion network, nnet1 training targets, ...)
# Corresponds to: vector<vector<tuple<int,float> > >
# - outer vector: time axis
# - inner vector: records at the time
# - tuple: int = index, float = value
#
def
read_cnet_ark
(
file_or_fd
):
""" Alias of function 'read_post_ark()', 'cnet' = confusion network """
return
read_post_ark
(
file_or_fd
)
def
read_post_ark
(
file_or_fd
):
""" generator(key,vec<vec<int,float>>) = read_post_ark(file)
Returns generator of (key,posterior) tuples, read from ark file.
file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
Iterate the ark:
for key,post in kaldi_io.read_post_ark(file):
...
Read ark to a 'dictionary':
d = { key:post for key,post in kaldi_io.read_post_ark(file) }
"""
fd
=
open_or_fd
(
file_or_fd
)
try
:
key
=
read_key
(
fd
)
while
key
:
post
=
read_post
(
fd
)
yield
key
,
post
key
=
read_key
(
fd
)
finally
:
if
fd
is
not
file_or_fd
:
fd
.
close
()
def
read_post
(
file_or_fd
):
""" [post] = read_post(file_or_fd)
Reads single kaldi 'Posterior' in binary format.
The 'Posterior' is C++ type 'vector<vector<tuple<int,float> > >',
the outer-vector is usually time axis, inner-vector are the records
at given time, and the tuple is composed of an 'index' (integer)
and a 'float-value'. The 'float-value' can represent a probability
or any other numeric value.
Returns vector of vectors of tuples.
"""
fd
=
open_or_fd
(
file_or_fd
)
ans
=
[]
binary
=
fd
.
read
(
2
).
decode
();
assert
(
binary
==
'
\0
B'
);
# binary flag
assert
(
fd
.
read
(
1
).
decode
()
==
'
\4
'
);
# int-size
outer_vec_size
=
np
.
frombuffer
(
fd
.
read
(
4
),
dtype
=
'int32'
,
count
=
1
)[
0
]
# number of frames (or bins)
# Loop over 'outer-vector',
for
i
in
range
(
outer_vec_size
):
assert
(
fd
.
read
(
1
).
decode
()
==
'
\4
'
);
# int-size
inner_vec_size
=
np
.
frombuffer
(
fd
.
read
(
4
),
dtype
=
'int32'
,
count
=
1
)[
0
]
# number of records for frame (or bin)
data
=
np
.
frombuffer
(
fd
.
read
(
inner_vec_size
*
10
),
dtype
=
[(
'size_idx'
,
'int8'
),(
'idx'
,
'int32'
),(
'size_post'
,
'int8'
),(
'post'
,
'float32'
)],
count
=
inner_vec_size
)
assert
(
data
[
0
][
'size_idx'
]
==
4
)
assert
(
data
[
0
][
'size_post'
]
==
4
)
ans
.
append
(
data
[[
'idx'
,
'post'
]].
tolist
())
if
fd
is
not
file_or_fd
:
fd
.
close
()
return
ans
#################################################
# Kaldi Confusion Network bin begin/end times,
# (kaldi stores CNs time info separately from the Posterior).
#
def read_cntime_ark(file_or_fd):
    """ generator(key,vec<tuple<float,float>>) = read_cntime_ark(file_or_fd)
    Returns generator of (key,cntime) tuples, read from ark file.
    file_or_fd : file, gzipped file, pipe or opened file descriptor.

    Iterate the ark:
    for key,time in kaldi_io.read_cntime_ark(file):
        ...

    Read ark to a 'dictionary':
    d = { key:time for key,time in kaldi_io.read_cntime_ark(file) }
    """
    fd = open_or_fd(file_or_fd)
    try:
        key = read_key(fd)
        while key:
            cntime = read_cntime(fd)
            yield key, cntime
            key = read_key(fd)
    finally:
        if fd is not file_or_fd:
            fd.close()


def read_cntime(file_or_fd):
    """ [cntime] = read_cntime(file_or_fd)
    Reads single kaldi 'Confusion Network time info', in binary format:
    C++ type: vector<tuple<float,float> >.
    (begin/end times of bins at the confusion network).

    Binary layout is '<num-bins> <beg1> <end1> <beg2> <end2> ...'

    file_or_fd : file, gzipped file, pipe or opened file descriptor.

    Returns vector of tuples.
    """
    fd = open_or_fd(file_or_fd)
    binary = fd.read(2).decode()
    assert (binary == '\0B')  # assuming it's binary
    assert (fd.read(1).decode() == '\4')  # int-size
    vec_size = np.frombuffer(fd.read(4), dtype='int32',
                             count=1)[0]  # number of frames (or bins)
    data = np.frombuffer(fd.read(vec_size * 10),
                         dtype=[('size_beg', 'int8'), ('t_beg', 'float32'),
                                ('size_end', 'int8'), ('t_end', 'float32')],
                         count=vec_size)
    assert (data[0]['size_beg'] == 4)
    assert (data[0]['size_end'] == 4)
    # Return vector of tuples (t_beg,t_end),
    ans = data[['t_beg', 't_end']].tolist()
    if fd is not file_or_fd:
        fd.close()
    return ans
#################################################
# Segments related,
#
# Segments as 'Bool vectors' can be handy,
# - for 'superposing' the segmentations,
# - for frame-selection in Speaker-ID experiments,
def read_segments_as_bool_vec(segments_file):
    """ [ bool_vec ] = read_segments_as_bool_vec(segments_file)
    using kaldi 'segments' file for 1 wav, format : '<utt> <rec> <t-beg> <t-end>'
    - t-beg, t-end is in seconds,
    - assumed 100 frames/second,
    """
    segs = np.loadtxt(segments_file, dtype='object,object,f,f', ndmin=1)
    # Sanity checks,
    assert (len(segs) > 0)  # empty segmentation is an error,
    assert (len(np.unique([rec[1] for rec in segs])) ==
            1)  # segments with only 1 wav-file,
    # Convert time to frame-indexes,
    start = np.rint([100 * rec[2] for rec in segs]).astype(int)
    end = np.rint([100 * rec[3] for rec in segs]).astype(int)
    # Taken from 'read_lab_to_bool_vec', htk.py,
    frms = np.repeat(
        np.r_[np.tile([False, True], len(end)), False],
        np.r_[np.c_[start - np.r_[0, end[:-1]], end - start].flat, 0])
    assert np.sum(end - start) == np.sum(frms)
    return frms
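
As an aside (not part of kaldi_io.py), a minimal usage sketch of the segment reader above; the segments file content and the import path are assumptions, and the bool vector has one entry per 10 ms frame:

# Hypothetical 'segments' file for one recording (format: <utt> <rec> <t-beg> <t-end>):
#   utt1 rec1 0.00 1.20
#   utt2 rec1 2.50 3.00
import numpy as np
from wenet.dataset import kaldi_io  # assumed import path for this module

vec = kaldi_io.read_segments_as_bool_vec('data/segments')  # placeholder path
print(vec.shape)         # (300,): frames up to the end of the last segment
print(int(np.sum(vec)))  # 170 speech frames: 120 for utt1 + 50 for utt2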
examples/aishell/s0/wenet/dataset/processor.py
0 → 100644
View file @
a7785cc6
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import json
import random
import re
import tarfile
from subprocess import PIPE, Popen
from urllib.parse import urlparse

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.nn.utils.rnn import pad_sequence

AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])


def url_opener(data):
    """ Give url or local file, return file descriptor
        Inplace operation.

        Args:
            data(Iterable[str]): url or local file list

        Returns:
            Iterable[{src, stream}]
    """
    for sample in data:
        assert 'src' in sample
        # TODO(Binbin Zhang): support HTTP
        url = sample['src']
        try:
            pr = urlparse(url)
            # local file
            if pr.scheme == '' or pr.scheme == 'file':
                stream = open(url, 'rb')
            # network file, such as HTTP(HDFS/OSS/S3)/HTTPS/SCP
            else:
                cmd = f'wget -q -O - {url}'
                process = Popen(cmd, shell=True, stdout=PIPE)
                sample.update(process=process)
                stream = process.stdout
            sample.update(stream=stream)
            yield sample
        except Exception as ex:
            logging.warning('Failed to open {}'.format(url))
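
As an aside (not part of processor.py), the functions in this file are generator stages that each consume and yield dicts, so they chain directly; a minimal sketch of that chaining for shard-style data, with a hypothetical shard list file:

# Illustrative only: chain the generator stages defined in this file.
# 'shards.list' is a hypothetical file with one tar shard path per line.
data = [{'src': line.strip()} for line in open('shards.list')]
data = url_opener(data)          # {'src'} -> {'src', 'stream'}
data = tar_file_and_group(data)  # -> {'key', 'wav', 'txt', 'sample_rate'}
for example in data:
    print(example['key'], example['sample_rate'])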
def tar_file_and_group(data):
    """ Expand a stream of open tar files into a stream of tar file contents.
        And groups the file with same prefix

        Args:
            data: Iterable[{src, stream}]

        Returns:
            Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert 'stream' in sample
        stream = tarfile.open(fileobj=sample['stream'], mode="r|*")
        prev_prefix = None
        example = {}
        valid = True
        for tarinfo in stream:
            name = tarinfo.name
            pos = name.rfind('.')
            assert pos > 0
            prefix, postfix = name[:pos], name[pos + 1:]
            if prev_prefix is not None and prefix != prev_prefix:
                example['key'] = prev_prefix
                if valid:
                    yield example
                example = {}
                valid = True
            with stream.extractfile(tarinfo) as file_obj:
                try:
                    if postfix == 'txt':
                        example['txt'] = file_obj.read().decode('utf8').strip()
                    elif postfix in AUDIO_FORMAT_SETS:
                        waveform, sample_rate = torchaudio.load(file_obj)
                        example['wav'] = waveform
                        example['sample_rate'] = sample_rate
                    else:
                        example[postfix] = file_obj.read()
                except Exception as ex:
                    valid = False
                    logging.warning('error to parse {}'.format(name))
            prev_prefix = prefix
        if prev_prefix is not None:
            example['key'] = prev_prefix
            yield example
        stream.close()
        if 'process' in sample:
            sample['process'].communicate()
        sample['stream'].close()
def parse_raw(data):
    """ Parse key/wav/txt from json line

        Args:
            data: Iterable[str], str is a json line that has key/wav/txt

        Returns:
            Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert 'src' in sample
        json_line = sample['src']
        obj = json.loads(json_line)
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']
        try:
            if 'start' in obj:
                assert 'end' in obj
                sample_rate = torchaudio.backend.sox_io_backend.info(
                    wav_file).sample_rate
                start_frame = int(obj['start'] * sample_rate)
                end_frame = int(obj['end'] * sample_rate)
                waveform, _ = torchaudio.backend.sox_io_backend.load(
                    filepath=wav_file,
                    num_frames=end_frame - start_frame,
                    frame_offset=start_frame)
            else:
                waveform, sample_rate = torchaudio.load(wav_file)
            example = dict(key=key,
                           txt=txt,
                           wav=waveform,
                           sample_rate=sample_rate)
            yield example
        except Exception as ex:
            logging.warning('Failed to read {}'.format(wav_file))
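
As an aside (not part of processor.py), a sketch of the raw-mode input parse_raw expects: one JSON object per line with key/wav/txt, plus optional start/end segment bounds in seconds; every value below is a placeholder:

import json

# Hypothetical raw-mode line; the key, path and transcript are made up.
line = json.dumps({'key': 'utt1',
                   'wav': '/path/to/utt1.wav',
                   'txt': '你好 WENET'})
for example in parse_raw([{'src': line}]):
    print(example['key'], example['sample_rate'], example['wav'].shape)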
def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1):
    """ Filter sample according to feature and label length
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterance which is greater than max_length(10ms)
            min_length: drop utterance which is less than min_length(10ms)
            token_max_length: drop utterance which is greater than
                token_max_length, especially when use char unit for
                english modeling
            token_min_length: drop utterance which is
                less than token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length(10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length(10ms)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'label' in sample
        # sample['wav'] is torch.Tensor, we have 100 frames every second
        num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['label']) < token_min_length:
            continue
        if len(sample['label']) > token_max_length:
            continue
        if num_frames != 0:
            if len(sample['label']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['label']) / num_frames > max_output_input_ratio:
                continue
        yield sample
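
As an aside (not part of processor.py), a quick check of the frame arithmetic above with made-up numbers: 80,000 samples at 16 kHz give 80000 / 16000 * 100 = 500 frames at the assumed 100 frames per second, well inside the default bounds, and a 20-token label gives an output/input ratio of 20 / 500 = 0.04:

import torch

# Illustrative sample: 5 seconds of silence at 16 kHz with a dummy 20-token label.
sample = {'key': 'utt1', 'sample_rate': 16000,
          'wav': torch.zeros(1, 80000), 'label': list(range(20))}
kept = list(filter([sample]))  # the `filter` generator defined above
print(len(kept))               # 1 -> the sample survives the default thresholds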
def resample(data, resample_rate=16000):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        if sample_rate != resample_rate:
            sample['sample_rate'] = resample_rate
            sample['wav'] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        yield sample


def speed_perturb(data, speeds=None):
    """ Apply speed perturb to the data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            speeds(List[float]): optional speed

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    if speeds is None:
        speeds = [0.9, 1.0, 1.1]
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        speed = random.choice(speeds)
        if speed != 1.0:
            wav, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform, sample_rate,
                [['speed', str(speed)], ['rate', str(sample_rate)]])
            sample['wav'] = wav
        yield sample
def compute_fbank(data,
                  num_mel_bins=23,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0):
    """ Extract fbank

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'key' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.fbank(waveform,
                          num_mel_bins=num_mel_bins,
                          frame_length=frame_length,
                          frame_shift=frame_shift,
                          dither=dither,
                          energy_floor=0.0,
                          sample_frequency=sample_rate)
        yield dict(key=sample['key'], label=sample['label'], feat=mat)


def compute_mfcc(data,
                 num_mel_bins=23,
                 frame_length=25,
                 frame_shift=10,
                 dither=0.0,
                 num_ceps=40,
                 high_freq=0.0,
                 low_freq=20.0):
    """ Extract mfcc

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'key' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.mfcc(waveform,
                         num_mel_bins=num_mel_bins,
                         frame_length=frame_length,
                         frame_shift=frame_shift,
                         dither=dither,
                         num_ceps=num_ceps,
                         high_freq=high_freq,
                         low_freq=low_freq,
                         sample_frequency=sample_rate)
        yield dict(key=sample['key'], label=sample['label'], feat=mat)
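
As an aside (not part of processor.py), a sketch of what compute_fbank yields for one utterance; the waveform is random noise, so only the shapes matter, and the 80-bin setting is an assumption rather than the default 23:

import torch

# One second of fake 16 kHz audio with a dummy 5-token label.
sample = {'key': 'utt1', 'sample_rate': 16000,
          'wav': torch.rand(1, 16000) - 0.5, 'label': [1, 2, 3, 4, 5]}
out = next(compute_fbank([sample], num_mel_bins=80))
print(out['feat'].shape)  # roughly (98, 80): 25 ms windows every 10 ms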
def __tokenize_by_bpe_model(sp, txt):
    tokens = []
    # CJK(China Japan Korea) unicode range is [U+4E00, U+9FFF], ref:
    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    pattern = re.compile(r'([\u4e00-\u9fff])')
    # Example:
    #   txt   = "你好 ITS'S OKAY 的"
    #   chars = ["你", "好", " ITS'S OKAY ", "的"]
    chars = pattern.split(txt.upper())
    mix_chars = [w for w in chars if len(w.strip()) > 0]
    for ch_or_w in mix_chars:
        # ch_or_w is a single CJK character (i.e., "你"), do nothing.
        if pattern.fullmatch(ch_or_w) is not None:
            tokens.append(ch_or_w)
        # ch_or_w contains non-CJK characters (i.e., " IT'S OKAY "),
        # encode ch_or_w using bpe_model.
        else:
            for p in sp.encode_as_pieces(ch_or_w):
                tokens.append(p)

    return tokens


def tokenize(data,
             symbol_table,
             bpe_model=None,
             non_lang_syms=None,
             split_with_space=False):
    """ Decode text to chars or BPE
        Inplace operation

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    if non_lang_syms is not None:
        non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
    else:
        non_lang_syms = {}
        non_lang_syms_pattern = None

    if bpe_model is not None:
        import sentencepiece as spm
        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)
    else:
        sp = None

    for sample in data:
        assert 'txt' in sample
        txt = sample['txt'].strip()
        if non_lang_syms_pattern is not None:
            parts = non_lang_syms_pattern.split(txt.upper())
            parts = [w for w in parts if len(w.strip()) > 0]
        else:
            parts = [txt]

        label = []
        tokens = []
        for part in parts:
            if part in non_lang_syms:
                tokens.append(part)
            else:
                if bpe_model is not None:
                    tokens.extend(__tokenize_by_bpe_model(sp, part))
                else:
                    if split_with_space:
                        part = part.split(" ")
                    for ch in part:
                        if ch == ' ':
                            ch = "▁"
                        tokens.append(ch)

        for ch in tokens:
            if ch in symbol_table:
                label.append(symbol_table[ch])
            elif '<unk>' in symbol_table:
                label.append(symbol_table['<unk>'])

        sample['tokens'] = tokens
        sample['label'] = label
        yield sample
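
As an aside (not part of processor.py), a sketch of char-mode tokenization (no BPE model) with a toy symbol table; real tables are loaded from the recipe's dict file:

# Illustrative only: a tiny made-up symbol table.
symbol_table = {'<unk>': 1, '你': 2, '好': 3, 'H': 4, 'I': 5, '▁': 6}
sample = {'key': 'utt1', 'txt': '你好 HI'}
out = next(tokenize([sample], symbol_table))
print(out['tokens'])  # ['你', '好', '▁', 'H', 'I']
print(out['label'])   # [2, 3, 6, 4, 5]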
def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
    """ Do spec augmentation
        Inplace operation

        Args:
            data: Iterable[{key, feat, label}]
            num_t_mask: number of time mask to apply
            num_f_mask: number of freq mask to apply
            max_t: max width of time mask
            max_f: max width of freq mask
            max_w: max width of time warp

        Returns
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'feat' in sample
        x = sample['feat']
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        max_frames = y.size(0)
        max_freq = y.size(1)
        # time mask
        for i in range(num_t_mask):
            start = random.randint(0, max_frames - 1)
            length = random.randint(1, max_t)
            end = min(max_frames, start + length)
            y[start:end, :] = 0
        # freq mask
        for i in range(num_f_mask):
            start = random.randint(0, max_freq - 1)
            length = random.randint(1, max_f)
            end = min(max_freq, start + length)
            y[:, start:end] = 0
        sample['feat'] = y
        yield sample


def spec_sub(data, max_t=20, num_t_sub=3):
    """ Do spec substitute
        Inplace operation

        Args:
            data: Iterable[{key, feat, label}]
            max_t: max width of time substitute
            num_t_sub: number of time substitute to apply

        Returns
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'feat' in sample
        x = sample['feat']
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        max_frames = y.size(0)
        for i in range(num_t_sub):
            start = random.randint(0, max_frames - 1)
            length = random.randint(1, max_t)
            end = min(max_frames, start + length)
            # only substitute the earlier time chosen randomly for current time
            pos = random.randint(0, start)
            y[start:end, :] = x[start - pos:end - pos, :]
        sample['feat'] = y
        yield sample


def spec_trim(data, max_t=20):
    """ Trim tailing frames. Inplace operation.
        ref: TrimTail [https://arxiv.org/abs/2211.00522]

        Args:
            data: Iterable[{key, feat, label}]
            max_t: max width of length trimming

        Returns
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'feat' in sample
        x = sample['feat']
        assert isinstance(x, torch.Tensor)
        max_frames = x.size(0)
        length = random.randint(1, max_t)
        if length < max_frames / 2:
            y = x.clone().detach()[:max_frames - length]
            sample['feat'] = y
        yield sample
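
As an aside (not part of processor.py), since all three augmentations above operate on the feature tensor only, their effect can be checked by counting zeroed or dropped frames (random tensor, illustrative numbers):

import torch

feat = torch.rand(200, 80) + 0.1  # 200 frames x 80 mel bins, all entries non-zero
aug = next(spec_aug([{'key': 'utt1', 'feat': feat, 'label': [1]}]))
masked_frames = int((aug['feat'].sum(dim=1) == 0).sum())
print(masked_frames)              # up to 2 * 50 frames fully zeroed by time masks

trimmed = next(spec_trim([{'key': 'utt1', 'feat': feat, 'label': [1]}]))
print(trimmed['feat'].shape[0])   # 180..199: between 1 and 20 tail frames removed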
def shuffle(data, shuffle_size=10000):
    """ Local shuffle the data

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The sample left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The sample left over
    buf.sort(key=lambda x: x['feat'].size(0))
    for x in buf:
        yield x


def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf
def dynamic_batch(data, max_frames_in_batch=12000):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'feat' in sample
        assert isinstance(sample['feat'], torch.Tensor)
        new_sample_frames = sample['feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000):
    """ Wrapper for static/dynamic batch
    """
    if batch_type == 'static':
        return static_batch(data, batch_size)
    elif batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    else:
        logging.fatal('Unsupported batch type {}'.format(batch_type))
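
As an aside (not part of processor.py), the dynamic-batch criterion is padded frames, i.e. the longest utterance so far times the would-be batch size; a worked sketch with made-up lengths and a small budget:

import torch

# Utterances of 100, 120, 90 and 300 frames with a 400-frame budget.
feats = [torch.zeros(n, 80) for n in (100, 120, 90, 300)]
samples = [{'key': f'utt{i}', 'feat': f, 'label': [1]} for i, f in enumerate(feats)]
for b in dynamic_batch(samples, max_frames_in_batch=400):
    print([s['feat'].size(0) for s in b])
# [100, 120, 90]  -> padded cost 120 * 3 = 360 <= 400
# [300]           -> adding 300 would have cost 300 * 4 = 1200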
def padding(data):
    """ Padding the data into training data

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                    dtype=torch.int32)
        order = torch.argsort(feats_length, descending=True)
        feats_lengths = torch.tensor(
            [sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
        sorted_feats = [sample[i]['feat'] for i in order]
        sorted_keys = [sample[i]['key'] for i in order]
        sorted_labels = [
            torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order
        ]
        label_lengths = torch.tensor([x.size(0) for x in sorted_labels],
                                     dtype=torch.int32)

        padded_feats = pad_sequence(sorted_feats,
                                    batch_first=True,
                                    padding_value=0)
        padding_labels = pad_sequence(sorted_labels,
                                      batch_first=True,
                                      padding_value=-1)

        yield (sorted_keys, padded_feats, padding_labels, feats_lengths,
               label_lengths)
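
As an aside (not part of processor.py), a compact sketch wiring the stages above from raw JSON lines to one padded batch; the paths, symbol table and batch size are placeholders, and in the real recipe this composition is done by dataset.py rather than by hand:

# Illustrative wiring of the generator stages defined in this file.
raw_lines = [{'src': line.strip()} for line in open('data/train/data.list')]
data = parse_raw(raw_lines)
data = tokenize(data, symbol_table)  # symbol_table: {char: id} placeholder
data = filter(data)
data = resample(data, 16000)
data = compute_fbank(data, num_mel_bins=80)
data = spec_aug(data)
data = batch(data, batch_type='static', batch_size=8)
data = padding(data)
keys, feats, labels, feats_lens, label_lens = next(iter(data))
print(feats.shape, labels.shape)     # (8, T_max, 80), (8, L_max)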
examples/aishell/s0/wenet/dataset/wav_distortion.py
0 → 100644
View file @
a7785cc6
# Copyright (c) 2021 Mobvoi Inc (Chao Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import math

import torchaudio
import torch

torchaudio.set_audio_backend("sox_io")


def db2amp(db):
    return pow(10, db / 20)


def amp2db(amp):
    return 20 * math.log10(amp)
def make_poly_distortion(conf):
    """Generate a db-domain polynomial distortion function

        f(x) = a * x^m * (1-x)^n + x

        Args:
            conf: a dict {'a': #int, 'm': #int, 'n': #int}

        Returns:
            The polynomial function, which could be applied on
            a float amplitude value
    """
    a = conf['a']
    m = conf['m']
    n = conf['n']

    def poly_distortion(x):
        abs_x = abs(x)
        if abs_x < 0.000001:
            x = x
        else:
            db_norm = amp2db(abs_x) / 100 + 1
            if db_norm < 0:
                db_norm = 0
            db_norm = a * pow(db_norm, m) * pow((1 - db_norm), n) + db_norm
            if db_norm > 1:
                db_norm = 1
            db = (db_norm - 1) * 100
            amp = db2amp(db)
            if amp >= 0.9997:
                amp = 0.9997
            if x > 0:
                x = amp
            else:
                x = -amp
        return x

    return poly_distortion


def make_quad_distortion():
    return make_poly_distortion({'a': 1, 'm': 1, 'n': 1})
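
As an aside (not part of wav_distortion.py), a tiny numeric check of the quadratic case a = m = n = 1: an input at amplitude 0.1 is -20 dB, so db_norm = 0.8, the polynomial gives 0.8 * 0.2 + 0.8 = 0.96, i.e. -4 dB, which is roughly 0.63 in amplitude:

quad = make_quad_distortion()
print(round(quad(0.1), 3))   # ~0.631: -20 dB is pushed up to about -4 dB
print(round(quad(-0.1), 3))  # ~-0.631: the sign of the sample is preserved
print(quad(0.0))             # 0.0: near-zero samples are left untouched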
# the amplitude is set to max for all non-zero points
def make_max_distortion(conf):
    """Generate a max distortion function

        Args:
            conf: a dict {'max_db': float }
                'max_db': the maximum value.

        Returns:
            The max function, which could be applied on
            a float amplitude value
    """
    max_db = conf['max_db']
    if max_db:
        max_amp = db2amp(max_db)  # < 0.997
    else:
        max_amp = 0.997

    def max_distortion(x):
        if x > 0:
            x = max_amp
        elif x < 0:
            x = -max_amp
        else:
            x = 0.0
        return x

    return max_distortion


def make_amp_mask(db_mask=None):
    """Get an amplitude domain mask from db domain mask

        Args:
            db_mask: Optional. A list of tuple. if None, using default value.

        Returns:
            A list of tuple. The amplitude domain mask
    """
    if db_mask is None:
        db_mask = [(-110, -95), (-90, -80), (-65, -60), (-50, -30), (-15, 0)]
    amp_mask = [(db2amp(db[0]), db2amp(db[1])) for db in db_mask]
    return amp_mask


default_mask = make_amp_mask()
def generate_amp_mask(mask_num):
    """Generate amplitude domain mask randomly in [-100db, 0db]

        Args:
            mask_num: the slot number of the mask

        Returns:
            A list of tuple. each tuple defines a slot.
            e.g. [(-100, -80), (-65, -60), (-50, -30), (-15, 0)]
            for #mask_num = 4
    """
    a = [0] * 2 * mask_num
    a[0] = 0
    m = []
    for i in range(1, 2 * mask_num):
        a[i] = a[i - 1] + random.uniform(0.5, 1)
    max_val = a[2 * mask_num - 1]
    for i in range(0, mask_num):
        l = ((a[2 * i] - max_val) / max_val) * 100
        r = ((a[2 * i + 1] - max_val) / max_val) * 100
        m.append((l, r))
    return make_amp_mask(m)
def make_fence_distortion(conf):
    """Generate a fence distortion function

        In this fence-like shape function, the values in mask slots are
        set to maximum, while the values not in mask slots are set to 0.
        Use separated masks for positive and negative amplitudes.

        Args:
            conf: a dict {'mask_number': int, 'max_db': float }
                'mask_number': the slot number in mask.
                'max_db': the maximum value.

        Returns:
            The fence function, which could be applied on
            a float amplitude value
    """
    mask_number = conf['mask_number']
    max_db = conf['max_db']
    max_amp = db2amp(max_db)  # 0.997
    if mask_number <= 0:
        positive_mask = default_mask
        negative_mask = make_amp_mask([(-50, 0)])
    else:
        positive_mask = generate_amp_mask(mask_number)
        negative_mask = generate_amp_mask(mask_number)

    def fence_distortion(x):
        is_in_mask = False
        if x > 0:
            for mask in positive_mask:
                if x >= mask[0] and x <= mask[1]:
                    is_in_mask = True
                    return max_amp
            if not is_in_mask:
                return 0.0
        elif x < 0:
            abs_x = abs(x)
            for mask in negative_mask:
                if abs_x >= mask[0] and abs_x <= mask[1]:
                    is_in_mask = True
                    return max_amp
            if not is_in_mask:
                return 0.0
        return x

    return fence_distortion


#
def make_jag_distortion(conf):
    """Generate a jag distortion function

        In this jag-like shape function, the values in mask slots are
        not changed, while the values not in mask slots are set to 0.
        Use separated masks for positive and negative amplitudes.

        Args:
            conf: a dict {'mask_number': #int}
                'mask_number': the slot number in mask.

        Returns:
            The jag function, which could be applied on
            a float amplitude value
    """
    mask_number = conf['mask_number']
    if mask_number <= 0:
        positive_mask = default_mask
        negative_mask = make_amp_mask([(-50, 0)])
    else:
        positive_mask = generate_amp_mask(mask_number)
        negative_mask = generate_amp_mask(mask_number)

    def jag_distortion(x):
        is_in_mask = False
        if x > 0:
            for mask in positive_mask:
                if x >= mask[0] and x <= mask[1]:
                    is_in_mask = True
                    return x
            if not is_in_mask:
                return 0.0
        elif x < 0:
            abs_x = abs(x)
            for mask in negative_mask:
                if abs_x >= mask[0] and abs_x <= mask[1]:
                    is_in_mask = True
                    return x
            if not is_in_mask:
                return 0.0
        return x

    return jag_distortion
# gaining 20db means amp = amp * 10
# gaining -20db means amp = amp / 10
def make_gain_db(conf):
    """Generate a db domain gain function

        Args:
            conf: a dict {'db': #float}
                'db': the gaining value

        Returns:
            The db gain function, which could be applied on
            a float amplitude value
    """
    db = conf['db']

    def gain_db(x):
        return min(0.997, x * pow(10, db / 20))

    return gain_db


def distort(x, func, rate=0.8):
    """Distort a waveform in sample point level

        Args:
            x: the original waveform
            func: the distort function
            rate: sample point-level distort probability

        Returns:
            the distorted waveform
    """
    for i in range(0, x.shape[1]):
        a = random.uniform(0, 1)
        if a < rate:
            x[0][i] = func(float(x[0][i]))
    return x
def distort_chain(x, funcs, rate=0.8):
    for i in range(0, x.shape[1]):
        a = random.uniform(0, 1)
        if a < rate:
            for func in funcs:
                x[0][i] = func(float(x[0][i]))
    return x


# x is numpy
def distort_wav_conf(x, distort_type, distort_conf, rate=0.1):
    if distort_type == 'gain_db':
        gain_db = make_gain_db(distort_conf)
        x = distort(x, gain_db)
    elif distort_type == 'max_distortion':
        max_distortion = make_max_distortion(distort_conf)
        x = distort(x, max_distortion, rate=rate)
    elif distort_type == 'fence_distortion':
        fence_distortion = make_fence_distortion(distort_conf)
        x = distort(x, fence_distortion, rate=rate)
    elif distort_type == 'jag_distortion':
        jag_distortion = make_jag_distortion(distort_conf)
        x = distort(x, jag_distortion, rate=rate)
    elif distort_type == 'poly_distortion':
        poly_distortion = make_poly_distortion(distort_conf)
        x = distort(x, poly_distortion, rate=rate)
    elif distort_type == 'quad_distortion':
        quad_distortion = make_quad_distortion()
        x = distort(x, quad_distortion, rate=rate)
    elif distort_type == 'none_distortion':
        pass
    else:
        print('unsupport type')
    return x
def distort_wav_conf_and_save(distort_type, distort_conf, rate, wav_in,
                              wav_out):
    x, sr = torchaudio.load(wav_in)
    x = x.detach().numpy()
    out = distort_wav_conf(x, distort_type, distort_conf, rate)
    torchaudio.save(wav_out, torch.from_numpy(out), sr)


if __name__ == "__main__":
    distort_type = sys.argv[1]
    wav_in = sys.argv[2]
    wav_out = sys.argv[3]
    conf = None
    rate = 0.1
    if distort_type == 'new_jag_distortion':
        conf = {'mask_number': 4}
    elif distort_type == 'new_fence_distortion':
        conf = {'mask_number': 1, 'max_db': -30}
    elif distort_type == 'poly_distortion':
        conf = {'a': 4, 'm': 2, "n": 2}
    distort_wav_conf_and_save(distort_type, conf, rate, wav_in, wav_out)
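
As an aside (not part of wav_distortion.py), a usage sketch for the file above; the paths are placeholders, and the __main__ block only pre-fills configs for the three types it lists, so other types need an explicit config from Python:

# From a shell, using the __main__ entry point (placeholder paths):
#   python wav_distortion.py poly_distortion input.wav distorted.wav
# Or from Python, with an explicit config (illustrative values):
import torch
import torchaudio
import wav_distortion  # assumes this file is importable from the working directory

x, sr = torchaudio.load('input.wav')  # placeholder path
out = wav_distortion.distort_wav_conf(x.detach().numpy(), 'jag_distortion',
                                      {'mask_number': 4}, rate=0.2)
torchaudio.save('distorted.wav', torch.from_numpy(out), sr)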
examples/aishell/s0/wenet/efficient_conformer/__pycache__/attention.cpython-38.pyc
0 → 100644
View file @
a7785cc6
File added
examples/aishell/s0/wenet/efficient_conformer/__pycache__/convolution.cpython-38.pyc
0 → 100644
View file @
a7785cc6
File added
examples/aishell/s0/wenet/efficient_conformer/__pycache__/encoder.cpython-38.pyc
0 → 100644
View file @
a7785cc6
File added