gaoqiong / RapidASR / Commits

Commit cdab2875
Authored Apr 07, 2023 by SWHL

Update files

Pipeline #335 failed with stages in 0 seconds
Changes 38 / Pipelines 1

Showing 20 changed files with 2765 additions and 0 deletions (+2765 -0)
.gitignore                                                  +157 -0
README.md                                                    +70 -0
deepspeech2/__init__.py                                       +6 -0
deepspeech2/infer.py                                        +161 -0
deepspeech2/s2t/decoders/__init__.py                         +13 -0
deepspeech2/s2t/decoders/ctcdecoder/__init__.py              +18 -0
deepspeech2/s2t/decoders/ctcdecoder/decoders_deprecated.py  +250 -0
deepspeech2/s2t/decoders/ctcdecoder/swig_wrapper.py         +159 -0
deepspeech2/s2t/decoders/utils.py                           +128 -0
deepspeech2/s2t/deepspeech2.py                               +41 -0
deepspeech2/s2t/frontend/__init__.py                         +13 -0
deepspeech2/s2t/frontend/audio.py                           +730 -0
deepspeech2/s2t/frontend/augmentor/__init__.py               +13 -0
deepspeech2/s2t/frontend/augmentor/augmentation.py          +203 -0
deepspeech2/s2t/frontend/augmentor/base.py                   +59 -0
deepspeech2/s2t/frontend/featurizer/__init__.py              +16 -0
deepspeech2/s2t/frontend/featurizer/audio_featurizer.py     +363 -0
deepspeech2/s2t/frontend/featurizer/speech_featurizer.py    +103 -0
deepspeech2/s2t/frontend/featurizer/text_featurizer.py      +214 -0
deepspeech2/s2t/frontend/normalizer.py                       +48 -0
.gitignore 0 → 100644
*.pth
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.pytest_cache
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
# *.manifest
# *.spec
*.res
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
#idea
.vs
.vscode
.idea
#models
*.ttf
*.ttc
*.bin
*.mapping
*.xml
*.pdiparams
*.pdiparams.info
*.pdmodel
.DS_Store
\ No newline at end of file
README.md 0 → 100644
#### Inference code for models trained with PaddleSpeech

- Project source: [PaddleSpeech/s2t](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell/asr0)
- Runtime environment: Linux | Python 3.7 | CPU | does not depend on Paddle

#### Usage

1. Download the entire `python/base_paddlespeech` directory.
2. Install all dependencies at once:
    ```bash
    pip install -r requirements.txt -i https://pypi.douban.com/simple/

    # CentOS
    sudo yum install libsndfile
    ```
3. Download the `resources` model files into `base_paddlespeech`:
    - Download `resources` from [Google Drive](https://drive.google.com/file/d/1MWmKxsfCNQyQ5CPlaYxJKnYfIIC5OO5L/view?usp=sharing)
    - Download the language model file → [download link](https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm) and put it under the `base_paddlespeech/resources/models/language_model` directory
    - The final directory structure is shown below; please verify it yourself:
    ```text
    base_paddlespeech
    ├── deepspeech2
    │   ├── infer.py
    │   ├── __init__.py
    │   └── s2t
    │       ├── decoders
    │       ├── deepspeech2.py
    │       ├── frontend
    │       ├── io
    │       ├── modules
    │       ├── __pycache__
    │       ├── transform
    │       └── utils
    ├── main.py
    ├── requirements.txt
    ├── resources
    │   └── models
    │       ├── asr0_deepspeech2_online_aishell_ckpt_0.2.0.onnx
    │       ├── language_model
    │       │   └── zh_giga.no_cna_cmn.prune01244.klm
    │       └── model.yaml
    └── test_wav
        └── zh.wav
    ```
4. Run `python main.py`.
5. The output looks like the following:
    ```text
    checking the audio file format......
    The sample rate is 16000
    The audio file format is right
    Preprocess audio_file:/da2/SWHL/test_wav/zh.wav
    audio feat shape: (1, 498, 161)
    ASR Result: 我认为跑步最重要的就是给我们带来了身体健康
    ```

#### Converting the model to ONNX

```bash
model_dir="pretrained_models/deepspeech2online_aishell-zh-16k/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar/exp/deepspeech2_online/checkpoints"
pdmodel="avg_1.jit.pdmodel"
params_file="avg_1.jit.pdiparams"
save_onnx="pretrained_models/onnx/asr0_deepspeech2_online_aishell_ckpt_0.1.1.onnx"

paddle2onnx --model_dir ${model_dir} \
            --model_filename ${pdmodel} \
            --params_filename ${params_file} \
            --save_file ${save_onnx} \
            --opset_version 12
```
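
For orientation before the per-file diffs below, here is a minimal sketch of a driver script in the spirit of `main.py` (which is not part of this commit); the resource paths are assumptions based on the directory tree above.

```python
# Hypothetical driver; main.py is not shown in this commit, so treat
# the constructor arguments and paths here as assumptions.
from deepspeech2 import ASRExecutor

asr = ASRExecutor(
    sample_rate=16000,
    config_path="resources/models/model.yaml",
    onnx_path="resources/models/asr0_deepspeech2_online_aishell_ckpt_0.2.0.onnx",
    lan_model_path="resources/models/language_model/zh_giga.no_cna_cmn.prune01244.klm")

# __call__ checks the wav file, extracts features, and runs ONNX inference.
print("ASR Result:", asr("test_wav/zh.wav"))
```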
deepspeech2/__init__.py 0 → 100644
# !/usr/bin/env python
# -*- encoding: utf-8 -*-
# @File: __init__.py
# @Author: SWHL
# @Contact: liekkaskono@163.com
from .infer import ASRExecutor
deepspeech2/infer.py 0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from collections import OrderedDict
from typing import Union

import numpy as np
import soundfile
from yacs.config import CfgNode

from .s2t.deepspeech2 import DeepSpeech2ModelOnline
from .s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from .s2t.io.collator import SpeechCollator
from .s2t.utils.utility import UpdateConfig


class ASRExecutor(object):
    def __init__(self,
                 sample_rate: int = 16000,
                 config_path: os.PathLike = None,
                 onnx_path: os.PathLike = None,
                 decode_method: str = 'attention_rescoring',
                 lan_model_path=None):
        self.sample_rate = sample_rate
        self.config_path = config_path
        self.onnx_path = onnx_path
        self.decode_method = decode_method
        self.lan_model_path = lan_model_path

        self._inputs = OrderedDict()
        self._outputs = OrderedDict()

        self.config_path = os.path.abspath(self.config_path)
        self.res_path = os.path.dirname(
            os.path.dirname(os.path.abspath(self.config_path)))

        self.config = CfgNode(new_allowed=True)
        self.config.merge_from_file(self.config_path)

        with UpdateConfig(self.config):
            self.vocab = self.config.vocab_filepath
            self.config.decode.lang_model_path = self.lan_model_path

        self.collate_fn_test = SpeechCollator.from_config(self.config)
        self.text_feature = TextFeaturizer(
            unit_type=self.config.unit_type, vocab=self.vocab)
        self.model = DeepSpeech2ModelOnline(encoder_onnx_path=self.onnx_path)

    def __call__(self, audio_file, force_yes: bool = False):
        audio_file = os.path.abspath(audio_file)
        if not self._check(audio_file, self.sample_rate, force_yes):
            sys.exit(-1)

        self.preprocess(audio_file)
        res = self.infer()
        return res

    def preprocess(self, input: Union[str, os.PathLike]):
        audio_file = input
        if isinstance(audio_file, (str, os.PathLike)):
            print("Preprocess audio_file:" + audio_file)

        # Get the object for feature extraction
        audio, _ = self.collate_fn_test.process_utterance(
            audio_file=audio_file, transcript=" ")
        audio_len = audio.shape[0]
        audio = audio[np.newaxis, ...]

        self._inputs["audio"] = audio
        self._inputs["audio_len"] = audio_len
        print(f"audio feat shape: {audio.shape}")

    def infer(self):
        """Model inference, with the result stored in self.output."""
        cfg = self.config.decode
        audio = self._inputs["audio"]
        audio_len = self._inputs["audio_len"]

        decode_batch_size = audio.shape[0]
        self.model.decoder.init_decoder(
            decode_batch_size, self.text_feature.vocab_list,
            cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
            cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
            cfg.num_proc_bsearch)

        result_transcripts = self.model.decode(audio, audio_len)
        self.model.decoder.del_decoder()
        return result_transcripts[0]

    def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
        self.sample_rate = sample_rate
        if self.sample_rate != 16000 and self.sample_rate != 8000:
            print("invalid sample rate, please input --sr 8000 or --sr 16000")
            return False

        if isinstance(audio_file, (str, os.PathLike)):
            if not os.path.isfile(audio_file):
                print("Please input the right audio file path")
                return False

        print("checking the audio file format......")
        try:
            audio, audio_sample_rate = soundfile.read(
                audio_file, dtype="int16", always_2d=True)
        except Exception:
            print("can not open the audio file, please check that the audio "
                  "file format is 'wav'.\n"
                  "You can try to use sox to change the file format.\n"
                  "For example:\n"
                  "sample rate: 16k\n"
                  "sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav\n"
                  "sample rate: 8k\n"
                  "sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav")
            return False

        print("The sample rate is %d" % audio_sample_rate)
        if audio_sample_rate != self.sample_rate:
            print("The sample rate of the input file is not {}.\n"
                  "The program will resample the wav file to {}.\n"
                  "If the result does not meet your expectations,\n"
                  "please input a 16k 16 bit 1 channel wav file.".format(
                      self.sample_rate, self.sample_rate))
            if force_yes is False:
                while True:
                    print("Whether to change the sample rate and the channel. "
                          "Y: change the sample. N: exit the program.")
                    content = input("Input(Y/N):")
                    if content.strip() in ["Y", "y", "yes", "Yes"]:
                        print("change the sample rate and channel to 16k and 1 channel")
                        break
                    elif content.strip() in ["N", "n", "no", "No"]:
                        print("Exit the program")
                        exit(1)
                    else:
                        print("Not regular input, please input again")
            self.change_format = True
        else:
            print("The audio file format is right")
            self.change_format = False

        return True
deepspeech2/s2t/decoders/__init__.py 0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
deepspeech2/s2t/decoders/ctcdecoder/__init__.py 0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .swig_wrapper import ctc_beam_search_decoding
from .swig_wrapper import ctc_beam_search_decoding_batch
from .swig_wrapper import ctc_greedy_decoding
from .swig_wrapper import CTCBeamSearchDecoder
from .swig_wrapper import Scorer
deepspeech2/s2t/decoders/ctcdecoder/decoders_deprecated.py 0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
import multiprocessing
from itertools import groupby
from math import log

import numpy as np


def ctc_greedy_decoder(probs_seq, vocabulary):
    """CTC greedy (best path) decoder.

    The path consisting of the most probable tokens is further post-processed
    to remove consecutive repetitions and all blanks.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of float probabilities
                      for one character.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: str
    """
    # dimension verification
    for probs in probs_seq:
        if not len(probs) == len(vocabulary) + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")
    # argmax to get the best index for each time step
    max_index_list = list(np.array(probs_seq).argmax(axis=1))
    # remove consecutive duplicate indexes
    index_list = [index_group[0] for index_group in groupby(max_index_list)]
    # remove blank indexes
    blank_index = len(vocabulary)
    index_list = [index for index in index_list if index != blank_index]
    # convert index list to string
    return ''.join([vocabulary[index] for index in index_list])


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None,
                            nproc=False):
    """CTC Beam search decoder.

    It utilizes beam search to approximately select the top best decoding
    labels and returns results in descending order.
    The implementation is based on Prefix Beam Search
    (https://arxiv.org/abs/1408.2873), with the unclear parts redesigned.
    Two important modifications: 1) in the iterative computation of
    probabilities, the assignment operation is changed to accumulation,
    because one prefix may come from different paths; 2) the condition
    "if l^+ not in A_prev then" after the probability computation is dropped,
    as it is hard to understand and seems unnecessary.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param cutoff_prob: Cutoff probability in pruning;
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :param nproc: Whether the decoder is used in multiple processes.
    :type nproc: bool
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("The shape of prob_seq does not match with the "
                             "shape of the vocabulary.")

    # blank_id assign
    blank_id = len(vocabulary)

    # If the decoder is called in multiple processes, use the global scorer
    # instantiated in ctc_beam_search_decoder_batch().
    if nproc is True:
        global ext_nproc_scorer
        ext_scoring_func = ext_nproc_scorer

    # initialize
    # prefix_set_prev: the set containing selected prefixes
    # probs_b_prev: prefixes' probability ending with blank in previous step
    # probs_nb_prev: prefixes' probability ending with non-blank in previous step
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    # extend prefix in loop
    for time_step in range(len(probs_seq)):
        # prefix_set_next: the set containing candidate prefixes
        # probs_b_cur: prefixes' probability ending with blank in current step
        # probs_nb_cur: prefixes' probability ending with non-blank in current step
        prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}

        prob_idx = list(enumerate(probs_seq[time_step]))
        cutoff_len = len(prob_idx)
        # If pruning is enabled
        if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
            prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
            cutoff_len, cum_prob = 0, 0.0
            for i in range(len(prob_idx)):
                cum_prob += prob_idx[i][1]
                cutoff_len += 1
                if cum_prob >= cutoff_prob:
                    break
            cutoff_len = min(cutoff_len, cutoff_top_n)
            prob_idx = prob_idx[0:cutoff_len]

        for l in prefix_set_prev:
            if l not in prefix_set_next:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend prefix by traversing prob_idx
            for index in range(cutoff_len):
                c, prob_c = prob_idx[index][0], prob_idx[index][1]

                if c == blank_id:
                    probs_b_cur[l] += prob_c * (
                        probs_b_prev[l] + probs_nb_prev[l])
                else:
                    last_char = l[-1]
                    new_char = vocabulary[c]
                    l_plus = l + new_char
                    if l_plus not in prefix_set_next:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if new_char == last_char:
                        probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
                        probs_nb_cur[l] += prob_c * probs_nb_prev[l]
                    elif new_char == ' ':
                        if (ext_scoring_func is None) or (len(l) == 1):
                            score = 1.0
                        else:
                            prefix = l[1:]
                            score = ext_scoring_func(prefix)
                        probs_nb_cur[l_plus] += score * prob_c * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    else:
                        probs_nb_cur[l_plus] += prob_c * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    # add l_plus into prefix_set_next
                    prefix_set_next[l_plus] = probs_nb_cur[
                        l_plus] + probs_b_cur[l_plus]
            # add l into prefix_set_next
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
        # update probs
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        # store top beam_size prefixes
        prefix_set_prev = sorted(
            prefix_set_next.items(), key=lambda asd: asd[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    beam_result = []
    for seq, prob in prefix_set_prev.items():
        if prob > 0.0 and len(seq) > 1:
            result = seq[1:]
            # score last word by external scorer
            if (ext_scoring_func is not None) and (result[-1] != ' '):
                prob = prob * ext_scoring_func(result)
            log_prob = log(prob)
            beam_result.append((log_prob, result))
        else:
            beam_result.append((float('-inf'), ''))

    # output top beam_size decoding results
    beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
    return beam_result


def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None):
    """CTC beam search decoder using multiple processes.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in pruning;
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # use a global variable to pass the external scorer to the beam search decoder
    global ext_nproc_scorer
    ext_nproc_scorer = ext_scoring_func
    nproc = True

    pool = multiprocessing.Pool(processes=num_processes)
    results = []
    for i, probs_list in enumerate(probs_split):
        args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n,
                None, nproc)
        results.append(pool.apply_async(ctc_beam_search_decoder, args))

    pool.close()
    pool.join()
    beam_search_results = [result.get() for result in results]
    return beam_search_results
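
To make `ctc_greedy_decoder`'s merge-repeats-then-drop-blanks behavior concrete, here is a tiny hand-made example (the probability rows are invented for illustration):

```python
from deepspeech2.s2t.decoders.ctcdecoder.decoders_deprecated import (
    ctc_greedy_decoder)

# Two real characters plus the implicit blank at index len(vocabulary) == 2.
vocabulary = ['a', 'b']
probs_seq = [
    [0.10, 0.80, 0.10],   # argmax -> 'b'
    [0.20, 0.70, 0.10],   # argmax -> 'b' (consecutive repeat, merged)
    [0.10, 0.10, 0.80],   # argmax -> blank (removed)
    [0.90, 0.05, 0.05],   # argmax -> 'a'
]
print(ctc_greedy_decoder(probs_seq, vocabulary))  # -> "ba"
```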
deepspeech2/s2t/decoders/ctcdecoder/swig_wrapper.py 0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper for various CTC decoders in SWIG."""
import paddlespeech_ctcdecoders


class Scorer(paddlespeech_ctcdecoders.Scorer):
    """Wrapper for Scorer.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: str
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    """

    def __init__(self, alpha, beta, model_path, vocabulary):
        paddlespeech_ctcdecoders.Scorer.__init__(self, alpha, beta,
                                                 model_path, vocabulary)


def ctc_greedy_decoding(probs_seq, vocabulary, blank_id):
    """Wrapper for the CTC best path decoding function in SWIG.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: str
    """
    result = paddlespeech_ctcdecoders.ctc_greedy_decoding(
        probs_seq.tolist(), vocabulary, blank_id)
    return result


def ctc_beam_search_decoding(probs_seq,
                             vocabulary,
                             beam_size,
                             cutoff_prob=1.0,
                             cutoff_top_n=40,
                             ext_scoring_func=None,
                             blank_id=0):
    """Wrapper for the CTC beam search decoding function.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param cutoff_prob: Cutoff probability in pruning;
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                         characters with the highest probabilities in the
                         vocabulary will be used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding(
        probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
        ext_scoring_func, blank_id)
    beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
    return beam_results


def ctc_beam_search_decoding_batch(probs_split,
                                   vocabulary,
                                   beam_size,
                                   num_processes,
                                   cutoff_prob=1.0,
                                   cutoff_top_n=40,
                                   ext_scoring_func=None,
                                   blank_id=0):
    """Wrapper for the batched CTC beam search decoding function.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in vocabulary pruning;
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                         characters with the highest probabilities in the
                         vocabulary will be used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    probs_split = [probs_seq.tolist() for probs_seq in probs_split]

    batch_beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding_batch(
        probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
        cutoff_top_n, ext_scoring_func, blank_id)
    batch_beam_results = [[(res[0], res[1]) for res in beam_results]
                          for beam_results in batch_beam_results]
    return batch_beam_results


class CTCBeamSearchDecoder(paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch):
    """Wrapper for CtcBeamSearchDecoderBatch.

    Args:
        vocab_list (list): Vocabulary list.
        beam_size (int): Width for beam search.
        num_processes (int): Number of parallel processes.
        cutoff_prob (float): Cutoff probability in vocabulary pruning,
            default 1.0, no pruning.
        cutoff_top_n (int): Cutoff number in pruning; only the top cutoff_top_n
            characters with the highest probabilities in the vocabulary will be
            used in beam search, default 40.
        ext_scorer (Scorer): External scorer for partially decoded sentence,
            e.g. word count or language model.
    """

    def __init__(self, vocab_list, batch_size, beam_size, num_processes,
                 cutoff_prob, cutoff_top_n, _ext_scorer, blank_id):
        paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch.__init__(
            self, vocab_list, batch_size, beam_size, num_processes,
            cutoff_prob, cutoff_top_n, _ext_scorer, blank_id)
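
These wrappers only work where the compiled `paddlespeech_ctcdecoders` extension is installed. Under that assumption, a sketch of beam search with the KenLM file from the README might look like the following; the toy vocabulary and the alpha/beta weights are placeholders, not values taken from this repo's `model.yaml`:

```python
# Sketch only: requires the compiled paddlespeech_ctcdecoders extension.
import numpy as np

from deepspeech2.s2t.decoders.ctcdecoder.swig_wrapper import (
    Scorer, ctc_beam_search_decoding)

vocab = list("abcdefg")  # toy vocabulary; use the model's real vocab list
scorer = Scorer(
    alpha=1.9, beta=5.0,  # placeholder LM weights
    model_path="resources/models/language_model/zh_giga.no_cna_cmn.prune01244.klm",
    vocabulary=vocab)

probs = np.random.rand(20, len(vocab) + 1).astype(np.float32)
probs /= probs.sum(axis=1, keepdims=True)  # rows must be normalized

results = ctc_beam_search_decoding(probs, vocab, beam_size=10,
                                   ext_scoring_func=scorer)
print(results[0])  # (log probability, best sentence)
```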
deepspeech2/s2t/decoders/utils.py 0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import numpy as np

__all__ = ["end_detect", "parse_hypothesis", "add_results_to_json"]


def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection.

    Described in Eq. (50) of S. Watanabe et al.,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition".

    :param ended_hyps: dict
    :param i: int
    :param M: int
    :param D_end: float
    :return: bool
    """
    if len(ended_hyps) == 0:
        return False

    count = 0
    best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
    for m in range(M):
        # get ended_hyps whose length is i - m
        hyp_length = i - m
        hyps_same_length = [
            x for x in ended_hyps if len(x["yseq"]) == hyp_length
        ]
        if len(hyps_same_length) > 0:
            best_hyp_same_length = sorted(
                hyps_same_length, key=lambda x: x["score"], reverse=True)[0]
            if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
                count += 1

    if count == M:
        return True
    else:
        return False


# * ------------------ recognition related ------------------ *
def parse_hypothesis(hyp, char_list):
    """Parse hypothesis.

    Args:
        hyp (dict[str, Any]): Recognition hypothesis.
        char_list (list[str]): List of characters.

    Returns:
        tuple(str, str, str, float)
    """
    # remove sos and get results
    tokenid_as_list = list(map(int, hyp["yseq"][1:]))
    token_as_list = [char_list[idx] for idx in tokenid_as_list]
    score = float(hyp["score"])

    # convert to string
    tokenid = " ".join([str(idx) for idx in tokenid_as_list])
    token = " ".join(token_as_list)
    text = "".join(token_as_list).replace("<space>", " ")

    return text, token, tokenid, score


def add_results_to_json(js, nbest_hyps, char_list):
    """Add N-best results to json.

    Args:
        js (dict[str, Any]): Ground-truth utterance dict.
        nbest_hyps (list[dict[str, Any]]):
            List of hypotheses for multiple speakers: nutts x nspkrs.
        char_list (list[str]): List of characters.

    Returns:
        dict[str, Any]: Utterance dict with N-best results added.
    """
    # copy old json info
    new_js = dict()
    new_js["utt2spk"] = js["utt2spk"]
    new_js["output"] = []

    for n, hyp in enumerate(nbest_hyps, 1):
        # parse hypothesis
        rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp,
                                                                   char_list)

        # copy ground-truth
        if len(js["output"]) > 0:
            out_dic = dict(js["output"][0].items())
        else:
            # for no reference case (e.g., speech translation)
            out_dic = {"name": ""}

        # update name
        out_dic["name"] += "[%d]" % n

        # add recognition results
        out_dic["rec_text"] = rec_text
        out_dic["rec_token"] = rec_token
        out_dic["rec_tokenid"] = rec_tokenid
        out_dic["score"] = score

        # add to list of N-best result dicts
        new_js["output"].append(out_dic)

        # show 1-best result
        if n == 1:
            if "text" in out_dic.keys():
                print("groundtruth: %s" % out_dic["text"])
            print("prediction : %s" % out_dic["rec_text"])

    return new_js
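
A minimal illustration of `parse_hypothesis` with an invented hypothesis dict:

```python
from deepspeech2.s2t.decoders.utils import parse_hypothesis

char_list = ['<sos>', '<space>', 'h', 'i']
hyp = {"yseq": [0, 2, 3], "score": -1.23}  # leading sos token is stripped

text, token, tokenid, score = parse_hypothesis(hyp, char_list)
print(text, '|', token, '|', tokenid, '|', score)  # hi | h i | 2 3 | -1.23
```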
deepspeech2/s2t/deepspeech2.py 0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Online Model"""
import numpy as np
import onnxruntime as ort

from .modules.ctc import CTCDecoder


class DeepSpeech2ModelOnline(object):
    def __init__(self, encoder_onnx_path):
        self.encoder_sess = ort.InferenceSession(encoder_onnx_path)
        self.decoder = CTCDecoder()

    def decode(self, audio, audio_len):
        onnx_inputs_name = self.encoder_sess.get_inputs()
        ort_inputs = {
            onnx_inputs_name[0].name: np.array(audio).astype(np.float32),
            onnx_inputs_name[1].name: np.array([audio_len]).astype(np.int64),
            onnx_inputs_name[2].name: np.zeros([5, 1, 1024]).astype(np.float32),
            onnx_inputs_name[3].name: np.zeros([5, 1, 1024]).astype(np.float32)
        }
        ort_outputs = self.encoder_sess.run(None, ort_inputs)
        probs, eouts_len, _, _ = ort_outputs

        batch_size = probs.shape[0]
        self.decoder.reset_decoder(batch_size=batch_size)
        self.decoder.next(probs, eouts_len)
        trans_best, trans_beam = self.decoder.decode()
        return trans_best
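
The two zero tensors of shape (5, 1, 1024) that `decode()` feeds the encoder are presumably the initial hidden and cell states of the online encoder's recurrent stack (an assumption; the exported ONNX graph is the authority). The expected inputs can be checked directly:

```python
# Inspect the four encoder inputs that decode() fills in.
import onnxruntime as ort

sess = ort.InferenceSession(
    "resources/models/asr0_deepspeech2_online_aishell_ckpt_0.2.0.onnx")
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)
```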
deepspeech2/s2t/frontend/__init__.py 0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
deepspeech2/s2t/frontend/audio.py 0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the audio segment class."""
import
copy
import
io
import
random
import
re
import
struct
import
numpy
as
np
import
resampy
import
soundfile
from
scipy
import
signal
from
.utility
import
convert_samples_from_float32
from
.utility
import
convert_samples_to_float32
from
.utility
import
subfile_from_tar
class
AudioSegment
():
"""Monaural audio segment abstraction.
:param samples: Audio samples [num_samples x num_channels].
:type samples: ndarray.float32
:param sample_rate: Audio sample rate.
:type sample_rate: int
:raises TypeError: If the sample data type is not float or int.
"""
def
__init__
(
self
,
samples
,
sample_rate
):
"""Create audio segment from samples.
Samples are convert float32 internally, with int scaled to [-1, 1].
"""
self
.
_samples
=
self
.
_convert_samples_to_float32
(
samples
)
self
.
_sample_rate
=
sample_rate
if
self
.
_samples
.
ndim
>=
2
:
self
.
_samples
=
np
.
mean
(
self
.
_samples
,
1
)
def
__eq__
(
self
,
other
):
"""Return whether two objects are equal."""
if
not
isinstance
(
other
,
type
(
self
)):
return
False
if
self
.
_sample_rate
!=
other
.
_sample_rate
:
return
False
if
self
.
_samples
.
shape
!=
other
.
_samples
.
shape
:
return
False
if
np
.
any
(
self
.
samples
!=
other
.
_samples
):
return
False
return
True
def
__ne__
(
self
,
other
):
"""Return whether two objects are unequal."""
return
not
self
.
__eq__
(
other
)
def
__str__
(
self
):
"""Return human-readable representation of segment."""
return
(
"%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
"rms=%.2fdB"
%
(
type
(
self
),
self
.
num_samples
,
self
.
sample_rate
,
self
.
duration
,
self
.
rms_db
))
@
classmethod
def
from_file
(
cls
,
file
,
infos
=
None
):
"""Create audio segment from audio file.
Args:
filepath (str|file): Filepath or file object to audio file.
infos (TarLocalData, optional): tar2obj and tar2infos. Defaults to None.
Returns:
AudioSegment: Audio segment instance.
"""
if
isinstance
(
file
,
str
)
and
re
.
findall
(
r
".seqbin_\d+$"
,
file
):
return
cls
.
from_sequence_file
(
file
)
elif
isinstance
(
file
,
str
)
and
file
.
startswith
(
'tar:'
):
return
cls
.
from_file
(
subfile_from_tar
(
file
,
infos
))
else
:
samples
,
sample_rate
=
soundfile
.
read
(
file
,
dtype
=
'float32'
)
return
cls
(
samples
,
sample_rate
)
@
classmethod
def
slice_from_file
(
cls
,
file
,
start
=
None
,
end
=
None
):
"""Loads a small section of an audio without having to load
the entire file into the memory which can be incredibly wasteful.
:param file: Input audio filepath or file object.
:type file: str|file
:param start: Start time in seconds. If start is negative, it wraps
around from the end. If not provided, this function
reads from the very beginning.
:type start: float
:param end: End time in seconds. If end is negative, it wraps around
from the end. If not provided, the default behvaior is
to read to the end of the file.
:type end: float
:return: AudioSegment instance of the specified slice of the input
audio file.
:rtype: AudioSegment
:raise ValueError: If start or end is incorrectly set, e.g. out of
bounds in time.
"""
sndfile
=
soundfile
.
SoundFile
(
file
)
sample_rate
=
sndfile
.
samplerate
duration
=
float
(
len
(
sndfile
))
/
sample_rate
start
=
0.
if
start
is
None
else
start
end
=
duration
if
end
is
None
else
end
if
start
<
0.0
:
start
+=
duration
if
end
<
0.0
:
end
+=
duration
if
start
<
0.0
:
raise
ValueError
(
"The slice start position (%f s) is out of "
"bounds."
%
start
)
if
end
<
0.0
:
raise
ValueError
(
"The slice end position (%f s) is out of bounds."
%
end
)
if
start
>
end
:
raise
ValueError
(
"The slice start position (%f s) is later than "
"the slice end position (%f s)."
%
(
start
,
end
))
if
end
>
duration
:
raise
ValueError
(
"The slice end position (%f s) is out of bounds "
"(> %f s)"
%
(
end
,
duration
))
start_frame
=
int
(
start
*
sample_rate
)
end_frame
=
int
(
end
*
sample_rate
)
sndfile
.
seek
(
start_frame
)
data
=
sndfile
.
read
(
frames
=
end_frame
-
start_frame
,
dtype
=
'float32'
)
return
cls
(
data
,
sample_rate
)
@
classmethod
def
from_sequence_file
(
cls
,
filepath
):
"""Create audio segment from sequence file. Sequence file is a binary
file containing a collection of multiple audio files, with several
header bytes in the head indicating the offsets of each audio byte data
chunk.
The format is:
4 bytes (int, version),
4 bytes (int, num of utterance),
4 bytes (int, bytes per header),
[bytes_per_header*(num_utterance+1)] bytes (offsets for each audio),
audio_bytes_data_of_1st_utterance,
audio_bytes_data_of_2nd_utterance,
......
Sequence file name must end with ".seqbin". And the filename of the 5th
utterance's audio file in sequence file "xxx.seqbin" must be
"xxx.seqbin_5", with "5" indicating the utterance index within this
sequence file (starting from 1).
:param filepath: Filepath of sequence file.
:type filepath: str
:return: Audio segment instance.
:rtype: AudioSegment
"""
# parse filepath
matches
=
re
.
match
(
r
"(.+\.seqbin)_(\d+)"
,
filepath
)
if
matches
is
None
:
raise
IOError
(
"File type of %s is not supported"
%
filepath
)
filename
=
matches
.
group
(
1
)
fileno
=
int
(
matches
.
group
(
2
))
# read headers
f
=
io
.
open
(
filename
,
mode
=
'rb'
,
encoding
=
'utf8'
)
version
=
f
.
read
(
4
)
num_utterances
=
struct
.
unpack
(
"i"
,
f
.
read
(
4
))[
0
]
bytes_per_header
=
struct
.
unpack
(
"i"
,
f
.
read
(
4
))[
0
]
header_bytes
=
f
.
read
(
bytes_per_header
*
(
num_utterances
+
1
))
header
=
[
struct
.
unpack
(
"i"
,
header_bytes
[
bytes_per_header
*
i
:
bytes_per_header
*
(
i
+
1
)])[
0
]
for
i
in
range
(
num_utterances
+
1
)
]
# read audio bytes
f
.
seek
(
header
[
fileno
-
1
])
audio_bytes
=
f
.
read
(
header
[
fileno
]
-
header
[
fileno
-
1
])
f
.
close
()
# create audio segment
try
:
return
cls
.
from_bytes
(
audio_bytes
)
except
Exception
as
e
:
samples
=
np
.
frombuffer
(
audio_bytes
,
dtype
=
'int16'
)
return
cls
(
samples
=
samples
,
sample_rate
=
8000
)
@
classmethod
def
from_bytes
(
cls
,
bytes
):
"""Create audio segment from a byte string containing audio samples.
:param bytes: Byte string containing audio samples.
:type bytes: str
:return: Audio segment instance.
:rtype: AudioSegment
"""
samples
,
sample_rate
=
soundfile
.
read
(
io
.
BytesIO
(
bytes
),
dtype
=
'float32'
)
return
cls
(
samples
,
sample_rate
)
@
classmethod
def
from_pcm
(
cls
,
samples
,
sample_rate
):
"""Create audio segment from a byte string containing audio samples.
:param samples: Audio samples [num_samples x num_channels].
:type samples: numpy.ndarray
:param sample_rate: Audio sample rate.
:type sample_rate: int
:return: Audio segment instance.
:rtype: AudioSegment
"""
return
cls
(
samples
,
sample_rate
)
@
classmethod
def
concatenate
(
cls
,
*
segments
):
"""Concatenate an arbitrary number of audio segments together.
:param *segments: Input audio segments to be concatenated.
:type *segments: tuple of AudioSegment
:return: Audio segment instance as concatenating results.
:rtype: AudioSegment
:raises ValueError: If the number of segments is zero, or if the
sample_rate of any segments does not match.
:raises TypeError: If any segment is not AudioSegment instance.
"""
# Perform basic sanity-checks.
if
len
(
segments
)
==
0
:
raise
ValueError
(
"No audio segments are given to concatenate."
)
sample_rate
=
segments
[
0
].
_sample_rate
for
seg
in
segments
:
if
sample_rate
!=
seg
.
_sample_rate
:
raise
ValueError
(
"Can't concatenate segments with "
"different sample rates"
)
if
not
isinstance
(
seg
,
cls
):
raise
TypeError
(
"Only audio segments of the same type "
"can be concatenated."
)
samples
=
np
.
concatenate
([
seg
.
samples
for
seg
in
segments
])
return
cls
(
samples
,
sample_rate
)
@
classmethod
def
make_silence
(
cls
,
duration
,
sample_rate
):
"""Creates a silent audio segment of the given duration and sample rate.
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silent AudioSegment instance of the given duration.
:rtype: AudioSegment
"""
samples
=
np
.
zeros
(
int
(
duration
*
sample_rate
))
return
cls
(
samples
,
sample_rate
)
def
to_wav_file
(
self
,
filepath
,
dtype
=
'float32'
):
"""Save audio segment to disk as wav file.
:param filepath: WAV filepath or file object to save the
audio segment.
:type filepath: str|file
:param dtype: Subtype for audio file. Options: 'int16', 'int32',
'float32', 'float64'. Default is 'float32'.
:type dtype: str
:raises TypeError: If dtype is not supported.
"""
samples
=
self
.
_convert_samples_from_float32
(
self
.
_samples
,
dtype
)
subtype_map
=
{
'int16'
:
'PCM_16'
,
'int32'
:
'PCM_32'
,
'float32'
:
'FLOAT'
,
'float64'
:
'DOUBLE'
}
soundfile
.
write
(
filepath
,
samples
,
self
.
_sample_rate
,
format
=
'WAV'
,
subtype
=
subtype_map
[
dtype
])
def
superimpose
(
self
,
other
):
"""Add samples from another segment to those of this segment
(sample-wise addition, not segment concatenation).
Note that this is an in-place transformation.
:param other: Segment containing samples to be added in.
:type other: AudioSegments
:raise TypeError: If type of two segments don't match.
:raise ValueError: If the sample rates of the two segments are not
equal, or if the lengths of segments don't match.
"""
if
isinstance
(
other
,
type
(
self
)):
raise
TypeError
(
"Cannot add segments of different types: %s "
"and %s."
%
(
type
(
self
),
type
(
other
)))
if
self
.
_sample_rate
!=
other
.
_sample_rate
:
raise
ValueError
(
"Sample rates must match to add segments."
)
if
len
(
self
.
_samples
)
!=
len
(
other
.
_samples
):
raise
ValueError
(
"Segment lengths must match to add segments."
)
self
.
_samples
+=
other
.
_samples
def
to_bytes
(
self
,
dtype
=
'float32'
):
"""Create a byte string containing the audio content.
:param dtype: Data type for export samples. Options: 'int16', 'int32',
'float32', 'float64'. Default is 'float32'.
:type dtype: str
:return: Byte string containing audio content.
:rtype: str
"""
samples
=
self
.
_convert_samples_from_float32
(
self
.
_samples
,
dtype
)
return
samples
.
tostring
()
def
to
(
self
,
dtype
=
'int16'
):
"""Create a `dtype` audio content.
:param dtype: Data type for export samples. Options: 'int16', 'int32',
'float32', 'float64'. Default is 'float32'.
:type dtype: str
:return: np.ndarray containing `dtype` audio content.
:rtype: str
"""
samples
=
self
.
_convert_samples_from_float32
(
self
.
_samples
,
dtype
)
return
samples
def
gain_db
(
self
,
gain
):
"""Apply gain in decibels to samples.
Note that this is an in-place transformation.
:param gain: Gain in decibels to apply to samples.
:type gain: float|1darray
"""
self
.
_samples
*=
10.
**
(
gain
/
20.
)
def
change_speed
(
self
,
speed_rate
):
"""Change the audio speed by linear interpolation.
Note that this is an in-place transformation.
:param speed_rate: Rate of speed change:
speed_rate > 1.0, speed up the audio;
speed_rate = 1.0, unchanged;
speed_rate < 1.0, slow down the audio;
speed_rate <= 0.0, not allowed, raise ValueError.
:type speed_rate: float
:raises ValueError: If speed_rate <= 0.0.
"""
if
speed_rate
==
1.0
:
return
if
speed_rate
<=
0
:
raise
ValueError
(
"speed_rate should be greater than zero."
)
# numpy
# old_length = self._samples.shape[0]
# new_length = int(old_length / speed_rate)
# old_indices = np.arange(old_length)
# new_indices = np.linspace(start=0, stop=old_length, num=new_length)
# self._samples = np.interp(new_indices, old_indices, self._samples)
# sox, slow
try
:
import
soxbindings
as
sox
except
ImportError
:
try
:
from
paddlespeech.s2t.utils
import
dynamic_pip_install
package
=
"sox"
dynamic_pip_install
.
install
(
package
)
package
=
"soxbindings"
dynamic_pip_install
.
install
(
package
)
import
soxbindings
as
sox
except
Exception
:
raise
RuntimeError
(
"Can not install soxbindings on your system."
)
tfm
=
sox
.
Transformer
()
tfm
.
set_globals
(
multithread
=
False
)
tfm
.
speed
(
speed_rate
)
self
.
_samples
=
tfm
.
build_array
(
input_array
=
self
.
_samples
,
sample_rate_in
=
self
.
_sample_rate
).
squeeze
(
-
1
).
astype
(
np
.
float32
).
copy
()
def
normalize
(
self
,
target_db
=-
20
,
max_gain_db
=
300.0
):
"""Normalize audio to be of the desired RMS value in decibels.
Note that this is an in-place transformation.
:param target_db: Target RMS value in decibels. This value should be
less than 0.0 as 0.0 is full-scale audio.
:type target_db: float
:param max_gain_db: Max amount of gain in dB that can be applied for
normalization. This is to prevent nans when
attempting to normalize a signal consisting of
all zeros.
:type max_gain_db: float
:raises ValueError: If the required gain to normalize the segment to
the target_db value exceeds max_gain_db.
"""
gain
=
target_db
-
self
.
rms_db
if
gain
>
max_gain_db
:
raise
ValueError
(
"Unable to normalize segment to %f dB because the "
"the probable gain have exceeds max_gain_db (%f dB)"
%
(
target_db
,
max_gain_db
))
self
.
gain_db
(
min
(
max_gain_db
,
target_db
-
self
.
rms_db
))
def
normalize_online_bayesian
(
self
,
target_db
,
prior_db
,
prior_samples
,
startup_delay
=
0.0
):
"""Normalize audio using a production-compatible online/causal
algorithm. This uses an exponential likelihood and gamma prior to
make online estimates of the RMS even when there are very few samples.
Note that this is an in-place transformation.
:param target_db: Target RMS value in decibels.
:type target_bd: float
:param prior_db: Prior RMS estimate in decibels.
:type prior_db: float
:param prior_samples: Prior strength in number of samples.
:type prior_samples: float
:param startup_delay: Default 0.0s. If provided, this function will
accrue statistics for the first startup_delay
seconds before applying online normalization.
:type startup_delay: float
"""
# Estimate total RMS online.
startup_sample_idx
=
min
(
self
.
num_samples
-
1
,
int
(
self
.
sample_rate
*
startup_delay
))
prior_mean_squared
=
10.
**
(
prior_db
/
10.
)
prior_sum_of_squares
=
prior_mean_squared
*
prior_samples
cumsum_of_squares
=
np
.
cumsum
(
self
.
samples
**
2
)
sample_count
=
np
.
arange
(
self
.
num_samples
)
+
1
if
startup_sample_idx
>
0
:
cumsum_of_squares
[:
startup_sample_idx
]
=
\
cumsum_of_squares
[
startup_sample_idx
]
sample_count
[:
startup_sample_idx
]
=
\
sample_count
[
startup_sample_idx
]
mean_squared_estimate
=
((
cumsum_of_squares
+
prior_sum_of_squares
)
/
(
sample_count
+
prior_samples
))
rms_estimate_db
=
10
*
np
.
log10
(
mean_squared_estimate
)
# Compute required time-varying gain.
gain_db
=
target_db
-
rms_estimate_db
self
.
gain_db
(
gain_db
)
def
resample
(
self
,
target_sample_rate
,
filter
=
'kaiser_best'
):
"""Resample the audio to a target sample rate.
Note that this is an in-place transformation.
:param target_sample_rate: Target sample rate.
:type target_sample_rate: int
:param filter: The resampling filter to use one of {'kaiser_best',
'kaiser_fast'}.
:type filter: str
"""
self
.
_samples
=
resampy
.
resample
(
self
.
samples
,
self
.
sample_rate
,
target_sample_rate
,
filter
=
filter
)
self
.
_sample_rate
=
target_sample_rate
def
pad_silence
(
self
,
duration
,
sides
=
'both'
):
"""Pad this audio sample with a period of silence.
Note that this is an in-place transformation.
:param duration: Length of silence in seconds to pad.
:type duration: float
:param sides: Position for padding:
'beginning' - adds silence in the beginning;
'end' - adds silence in the end;
'both' - adds silence in both the beginning and the end.
:type sides: str
:raises ValueError: If sides is not supported.
"""
if
duration
==
0.0
:
return
self
cls
=
type
(
self
)
silence
=
self
.
make_silence
(
duration
,
self
.
_sample_rate
)
if
sides
==
"beginning"
:
padded
=
cls
.
concatenate
(
silence
,
self
)
elif
sides
==
"end"
:
padded
=
cls
.
concatenate
(
self
,
silence
)
elif
sides
==
"both"
:
padded
=
cls
.
concatenate
(
silence
,
self
,
silence
)
else
:
raise
ValueError
(
"Unknown value for the sides %s"
%
sides
)
self
.
_samples
=
padded
.
_samples
def
shift
(
self
,
shift_ms
):
"""Shift the audio in time. If `shift_ms` is positive, shift with time
advance; if negative, shift with time delay. Silence are padded to
keep the duration unchanged.
Note that this is an in-place transformation.
:param shift_ms: Shift time in millseconds. If positive, shift with
time advance; if negative; shift with time delay.
:type shift_ms: float
:raises ValueError: If shift_ms is longer than audio duration.
"""
if
abs
(
shift_ms
)
/
1000.0
>
self
.
duration
:
raise
ValueError
(
"Absolute value of shift_ms should be smaller "
"than audio duration."
)
shift_samples
=
int
(
shift_ms
*
self
.
_sample_rate
/
1000
)
if
shift_samples
>
0
:
# time advance
self
.
_samples
[:
-
shift_samples
]
=
self
.
_samples
[
shift_samples
:]
self
.
_samples
[
-
shift_samples
:]
=
0
elif
shift_samples
<
0
:
# time delay
self
.
_samples
[
-
shift_samples
:]
=
self
.
_samples
[:
shift_samples
]
self
.
_samples
[:
-
shift_samples
]
=
0
def
subsegment
(
self
,
start_sec
=
None
,
end_sec
=
None
):
"""Cut the AudioSegment between given boundaries.
Note that this is an in-place transformation.
:param start_sec: Beginning of subsegment in seconds.
:type start_sec: float
:param end_sec: End of subsegment in seconds.
:type end_sec: float
:raise ValueError: If start_sec or end_sec is incorrectly set, e.g. out
of bounds in time.
"""
start_sec
=
0.0
if
start_sec
is
None
else
start_sec
end_sec
=
self
.
duration
if
end_sec
is
None
else
end_sec
if
start_sec
<
0.0
:
start_sec
=
self
.
duration
+
start_sec
if
end_sec
<
0.0
:
end_sec
=
self
.
duration
+
end_sec
if
start_sec
<
0.0
:
raise
ValueError
(
"The slice start position (%f s) is out of "
"bounds."
%
start_sec
)
if
end_sec
<
0.0
:
raise
ValueError
(
"The slice end position (%f s) is out of bounds."
%
end_sec
)
if
start_sec
>
end_sec
:
raise
ValueError
(
"The slice start position (%f s) is later than "
"the end position (%f s)."
%
(
start_sec
,
end_sec
))
if
end_sec
>
self
.
duration
:
raise
ValueError
(
"The slice end position (%f s) is out of bounds "
"(> %f s)"
%
(
end_sec
,
self
.
duration
))
start_sample
=
int
(
round
(
start_sec
*
self
.
_sample_rate
))
end_sample
=
int
(
round
(
end_sec
*
self
.
_sample_rate
))
self
.
_samples
=
self
.
_samples
[
start_sample
:
end_sample
]
def
random_subsegment
(
self
,
subsegment_length
,
rng
=
None
):
"""Cut the specified length of the audiosegment randomly.
Note that this is an in-place transformation.
:param subsegment_length: Subsegment length in seconds.
:type subsegment_length: float
:param rng: Random number generator state.
:type rng: random.Random
        :raises ValueError: If the length of the subsegment is greater than
                            the original segment.
        """
        rng = random.Random() if rng is None else rng
        if subsegment_length > self.duration:
            raise ValueError("Length of subsegment must not be greater "
                             "than original segment.")
        start_time = rng.uniform(0.0, self.duration - subsegment_length)
        self.subsegment(start_time, start_time + subsegment_length)

    def convolve(self, impulse_segment, allow_resample=False):
        """Convolve this audio segment with the given impulse segment.

        Note that this is an in-place transformation.

        :param impulse_segment: Impulse response segment.
        :type impulse_segment: AudioSegment
        :param allow_resample: Indicates whether resampling is allowed when
                               the impulse_segment has a different sample
                               rate from this signal.
        :type allow_resample: bool
        :raises ValueError: If the sample rates of the two audio segments do
                            not match and resampling is not allowed.
        """
        if allow_resample and self.sample_rate != impulse_segment.sample_rate:
            impulse_segment.resample(self.sample_rate)
        if self.sample_rate != impulse_segment.sample_rate:
            raise ValueError("Impulse segment's sample rate (%d Hz) is not "
                             "equal to base signal sample rate (%d Hz)." %
                             (impulse_segment.sample_rate, self.sample_rate))
        samples = signal.fftconvolve(self.samples, impulse_segment.samples,
                                     "full")
        self._samples = samples

    def convolve_and_normalize(self, impulse_segment, allow_resample=False):
        """Convolve and normalize the resulting audio segment so that it
        has the same average power as the input signal.

        Note that this is an in-place transformation.

        :param impulse_segment: Impulse response segment.
        :type impulse_segment: AudioSegment
        :param allow_resample: Indicates whether resampling is allowed when
                               the impulse_segment has a different sample
                               rate from this signal.
        :type allow_resample: bool
        """
        target_db = self.rms_db
        self.convolve(impulse_segment, allow_resample=allow_resample)
        self.normalize(target_db)

    def add_noise(self,
                  noise,
                  snr_dB,
                  allow_downsampling=False,
                  max_gain_db=300.0,
                  rng=None):
        """Add the given noise segment at a specific signal-to-noise ratio.
        If the noise segment is longer than this segment, a random subsegment
        of matching length is sampled from it and used instead.

        Note that this is an in-place transformation.

        :param noise: Noise signal to add.
        :type noise: AudioSegment
        :param snr_dB: Signal-to-Noise Ratio, in decibels.
        :type snr_dB: float
        :param allow_downsampling: Whether to allow the noise signal to be
                                   downsampled to match the base signal sample
                                   rate.
        :type allow_downsampling: bool
        :param max_gain_db: Maximum amount of gain to apply to the noise
                            signal before adding it in. This is to prevent
                            attempting to apply infinite gain to a zero
                            signal.
        :type max_gain_db: float
        :param rng: Random number generator state.
        :type rng: None|random.Random
        :raises ValueError: If the sample rates of the two audio segments do
                            not match and downsampling is not allowed, or if
                            the noise segment is shorter than the original
                            audio segment.
        """
        rng = random.Random() if rng is None else rng
        if allow_downsampling and noise.sample_rate > self.sample_rate:
            noise = noise.resample(self.sample_rate)
        if noise.sample_rate != self.sample_rate:
            raise ValueError("Noise sample rate (%d Hz) is not equal to base "
                             "signal sample rate (%d Hz)." %
                             (noise.sample_rate, self.sample_rate))
        if noise.duration < self.duration:
            raise ValueError("Noise signal (%f sec) must be at least as long "
                             "as base signal (%f sec)." %
                             (noise.duration, self.duration))
        noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
        noise_new = copy.deepcopy(noise)
        noise_new.random_subsegment(self.duration, rng=rng)
        noise_new.gain_db(noise_gain_db)
        self.superimpose(noise_new)

    @property
    def samples(self):
        """Return audio samples.

        :return: Audio samples.
        :rtype: ndarray
        """
        return self._samples.copy()

    @property
    def sample_rate(self):
        """Return audio sample rate.

        :return: Audio sample rate.
        :rtype: int
        """
        return self._sample_rate

    @property
    def num_samples(self):
        """Return number of samples.

        :return: Number of samples.
        :rtype: int
        """
        return self._samples.shape[0]

    @property
    def duration(self):
        """Return audio duration.

        :return: Audio duration in seconds.
        :rtype: float
        """
        return self._samples.shape[0] / float(self._sample_rate)

    @property
    def rms_db(self):
        """Return root mean square energy of the audio in decibels.

        :return: Root mean square energy in decibels.
        :rtype: float
        """
        # mean_square is already a power quantity (amplitude squared),
        # so convert to decibels with 10 * log10 instead of 20 * log10
        mean_square = np.mean(self._samples**2)
        return 10 * np.log10(mean_square)

    def _convert_samples_to_float32(self, samples):
        """Convert sample type to float32.

        Audio sample type is usually integer or floating point.
        Integers will be scaled to [-1, 1] in float32.
        """
        return convert_samples_to_float32(samples)

    def _convert_samples_from_float32(self, samples, dtype):
        """Convert sample type from float32 to dtype.

        Audio sample type is usually integer or floating point. For integer
        types, float32 will be rescaled from [-1, 1] to the maximum range
        supported by the integer type.

        This is for writing an audio file.
        """
        return convert_samples_from_float32(samples, dtype)
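A minimal usage sketch of the augmentation APIs above (add_noise, convolve_and_normalize, and the rms_db/duration properties). It assumes the AudioSegment(samples, sample_rate) constructor defined earlier in audio.py; the signals are synthetic stand-ins.

import numpy as np

sr = 16000
speech = AudioSegment(
    np.random.uniform(-0.1, 0.1, 2 * sr).astype('float32'), sr)
noise = AudioSegment(
    np.random.uniform(-0.05, 0.05, 5 * sr).astype('float32'), sr)

# add_noise picks a random noise subsegment of speech.duration, then applies
# gain min(speech.rms_db - noise.rms_db - snr_dB, max_gain_db) before mixing
speech.add_noise(noise, snr_dB=10.0)

# simulate reverberation with a decaying random impulse response, keeping
# the original average power via convolve_and_normalize
decay = np.exp(-np.linspace(0.0, 8.0, sr // 2))
ir = AudioSegment((np.random.randn(sr // 2) * decay).astype('float32'), sr)
speech.convolve_and_normalize(ir)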
deepspeech2/s2t/frontend/augmentor/__init__.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
deepspeech2/s2t/frontend/augmentor/augmentation.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the data augmentation pipeline."""
import
json
import
os
from
collections.abc
import
Sequence
from
inspect
import
signature
import
numpy
as
np
__all__
=
[
"AugmentationPipeline"
]
class
AugmentationPipeline
():
"""Build a pre-processing pipeline with various augmentation models.Such a
data augmentation pipeline is oftern leveraged to augment the training
samples to make the model invariant to certain types of perturbations in the
real world, improving model's generalization ability.
The pipeline is built according the the augmentation configuration in json
string, e.g.
.. code-block::
[ {
"type": "noise",
"params": {"min_snr_dB": 10,
"max_snr_dB": 20,
"noise_manifest_path": "datasets/manifest.noise"},
"prob": 0.0
},
{
"type": "speed",
"params": {"min_speed_rate": 0.9,
"max_speed_rate": 1.1},
"prob": 1.0
},
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
},
{
"type": "volume",
"params": {"min_gain_dBFS": -10,
"max_gain_dBFS": 10},
"prob": 0.0
},
{
"type": "bayesian_normal",
"params": {"target_db": -20,
"prior_db": -20,
"prior_samples": 100},
"prob": 0.0
}
]
This augmentation configuration inserts two augmentation models
into the pipeline, with one is VolumePerturbAugmentor and the other
SpeedPerturbAugmentor. "prob" indicates the probability of the current
augmentor to take effect. If "prob" is zero, the augmentor does not take
effect.
Params:
preprocess_conf(str): Augmentation configuration in `json file` or `json string`.
random_seed(int): Random seed.
Raises:
ValueError: If the augmentation json config is in incorrect format".
"""
SPEC_TYPES
=
{
'specaug'
}
def
__init__
(
self
,
preprocess_conf
:
str
,
random_seed
:
int
=
0
):
self
.
_rng
=
np
.
random
.
RandomState
(
random_seed
)
self
.
conf
=
{
'mode'
:
'sequential'
,
'process'
:
[]}
if
preprocess_conf
:
if
os
.
path
.
isfile
(
preprocess_conf
):
with
open
(
preprocess_conf
,
'r'
)
as
fin
:
json_string
=
fin
.
read
()
else
:
json_string
=
preprocess_conf
process
=
json
.
loads
(
json_string
)
self
.
conf
[
'process'
]
+=
process
self
.
_augmentors
,
self
.
_rates
=
self
.
_parse_pipeline_from
(
'all'
)
self
.
_audio_augmentors
,
self
.
_audio_rates
=
self
.
_parse_pipeline_from
(
'audio'
)
self
.
_spec_augmentors
,
self
.
_spec_rates
=
self
.
_parse_pipeline_from
(
'feature'
)
def
__call__
(
self
,
xs
,
uttid_list
=
None
,
**
kwargs
):
if
not
isinstance
(
xs
,
Sequence
):
is_batch
=
False
xs
=
[
xs
]
else
:
is_batch
=
True
if
isinstance
(
uttid_list
,
str
):
uttid_list
=
[
uttid_list
for
_
in
range
(
len
(
xs
))]
if
self
.
conf
.
get
(
"mode"
,
"sequential"
)
==
"sequential"
:
for
idx
,
(
func
,
rate
)
in
enumerate
(
zip
(
self
.
_augmentors
,
self
.
_rates
),
0
):
if
self
.
_rng
.
uniform
(
0.
,
1.
)
>=
rate
:
continue
# Derive only the args which the func has
try
:
param
=
signature
(
func
).
parameters
except
ValueError
:
# Some function, e.g. built-in function, are failed
param
=
{}
_kwargs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
in
param
}
try
:
if
uttid_list
is
not
None
and
"uttid"
in
param
:
xs
=
[
func
(
x
,
u
,
**
_kwargs
)
for
x
,
u
in
zip
(
xs
,
uttid_list
)
]
else
:
xs
=
[
func
(
x
,
**
_kwargs
)
for
x
in
xs
]
except
Exception
:
logger
.
fatal
(
"Catch a exception from {}th func: {}"
.
format
(
idx
,
func
))
raise
else
:
raise
NotImplementedError
(
"Not supporting mode={}"
.
format
(
self
.
conf
[
"mode"
]))
if
is_batch
:
return
xs
else
:
return
xs
[
0
]
def
transform_audio
(
self
,
audio_segment
):
"""Run the pre-processing pipeline for data augmentation.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to process.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
for
augmentor
,
rate
in
zip
(
self
.
_audio_augmentors
,
self
.
_audio_rates
):
if
self
.
_rng
.
uniform
(
0.
,
1.
)
<
rate
:
augmentor
.
transform_audio
(
audio_segment
)
def
transform_feature
(
self
,
spec_segment
):
"""spectrogram augmentation.
Args:
spec_segment (np.ndarray): audio feature, (D, T).
"""
for
augmentor
,
rate
in
zip
(
self
.
_spec_augmentors
,
self
.
_spec_rates
):
if
self
.
_rng
.
uniform
(
0.
,
1.
)
<
rate
:
spec_segment
=
augmentor
.
transform_feature
(
spec_segment
)
return
spec_segment
def
_parse_pipeline_from
(
self
,
aug_type
=
'all'
):
"""Parse the config json to build a augmentation pipelien."""
assert
aug_type
in
(
'audio'
,
'feature'
,
'all'
),
aug_type
audio_confs
=
[]
feature_confs
=
[]
all_confs
=
[]
for
config
in
self
.
conf
[
'process'
]:
all_confs
.
append
(
config
)
if
config
[
"type"
]
in
self
.
SPEC_TYPES
:
feature_confs
.
append
(
config
)
else
:
audio_confs
.
append
(
config
)
if
aug_type
==
'audio'
:
aug_confs
=
audio_confs
elif
aug_type
==
'feature'
:
aug_confs
=
feature_confs
elif
aug_type
==
'all'
:
aug_confs
=
all_confs
else
:
raise
ValueError
(
f
"Not support:
{
aug_type
}
"
)
augmentors
=
[
self
.
_get_augmentor
(
config
[
"type"
],
config
[
"params"
])
for
config
in
aug_confs
]
rates
=
[
config
[
"prob"
]
for
config
in
aug_confs
]
return
augmentors
,
rates
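For orientation, a sketch of how the pipeline is driven. It assumes that _get_augmentor (referenced above but not shown in this excerpt) can build a 'speed' augmentor from the config below, and that audio_segment is an AudioSegment from audio.py.

import json

conf = json.dumps([{
    "type": "speed",
    "params": {"min_speed_rate": 0.9, "max_speed_rate": 1.1},
    "prob": 1.0
}])
pipeline = AugmentationPipeline(preprocess_conf=conf, random_seed=0)

# audio-level augmentors ('speed' is not in SPEC_TYPES) mutate in place
pipeline.transform_audio(audio_segment)

# feature-level augmentors (SPEC_TYPES, e.g. 'specaug') return a new array:
# spec_segment = pipeline.transform_feature(spec_segment)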
deepspeech2/s2t/frontend/augmentor/base.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the abstract base class for augmentation models."""
from
abc
import
ABCMeta
from
abc
import
abstractmethod
class
AugmentorBase
():
"""Abstract base class for augmentation model (augmentor) class.
All augmentor classes should inherit from this class, and implement the
following abstract methods.
"""
__metaclass__
=
ABCMeta
@
abstractmethod
def
__init__
(
self
):
pass
@
abstractmethod
def
__call__
(
self
,
xs
):
raise
NotImplementedError
(
"AugmentorBase: Not impl __call__"
)
@
abstractmethod
def
transform_audio
(
self
,
audio_segment
):
"""Adds various effects to the input audio segment. Such effects
will augment the training data to make the model invariant to certain
types of perturbations in the real world, improving model's
generalization ability.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
raise
NotImplementedError
(
"AugmentorBase: Not impl transform_audio"
)
@
abstractmethod
def
transform_feature
(
self
,
spec_segment
):
"""Adds various effects to the input audo feature segment. Such effects
will augment the training data to make the model invariant to certain
types of time_mask or freq_mask in the real world, improving model's
generalization ability.
Args:
spec_segment (Spectrogram): Spectrogram segment to add effects to.
"""
raise
NotImplementedError
(
"AugmentorBase: Not impl transform_feature"
)
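As a sketch of this contract, a hypothetical gain-perturbation augmentor built on the base class; gain_db is the in-place AudioSegment method already used by add_noise in audio.py.

import numpy as np

class RandomGainAugmentor(AugmentorBase):
    """Hypothetical augmentor: apply a uniformly sampled gain in dBFS."""

    def __init__(self, rng: np.random.RandomState,
                 min_gain_dBFS=-10, max_gain_dBFS=10):
        self._rng = rng
        self._min_gain_dBFS = min_gain_dBFS
        self._max_gain_dBFS = max_gain_dBFS

    def __call__(self, xs):
        self.transform_audio(xs)
        return xs

    def transform_audio(self, audio_segment):
        gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
        audio_segment.gain_db(gain)  # in-place, as the base class requires

    def transform_feature(self, spec_segment):
        # gain perturbation is audio-level only; features pass through
        return spec_segment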
deepspeech2/s2t/frontend/featurizer/__init__.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .audio_featurizer import AudioFeaturizer  # noqa: F401
from .speech_featurizer import SpeechFeaturizer
from .text_featurizer import TextFeaturizer
deepspeech2/s2t/frontend/featurizer/audio_featurizer.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the audio featurizer class."""
import
numpy
as
np
from
python_speech_features
import
delta
from
python_speech_features
import
logfbank
from
python_speech_features
import
mfcc
class
AudioFeaturizer
():
"""Audio featurizer, for extracting features from audio contents of
AudioSegment or SpeechSegment.
Currently, it supports feature types of linear spectrogram and mfcc.
:param spectrum_type: Specgram feature type. Options: 'linear'.
:type spectrum_type: str
:param stride_ms: Striding size (in milliseconds) for generating frames.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for generating frames.
:type window_ms: float
:param max_freq: When spectrum_type is 'linear', only FFT bins
corresponding to frequencies between [0, max_freq] are
returned; when spectrum_type is 'mfcc', max_feq is the
highest band edge of mel filters.
:types max_freq: None|float
:param target_sample_rate: Audio are resampled (if upsampling or
downsampling is allowed) to this before
extracting spectrogram features.
:type target_sample_rate: float
:param use_dB_normalization: Whether to normalize the audio to a certain
decibels before extracting the features.
:type use_dB_normalization: bool
:param target_dB: Target audio decibels for normalization.
:type target_dB: float
"""
def
__init__
(
self
,
spectrum_type
:
str
=
'linear'
,
feat_dim
:
int
=
None
,
delta_delta
:
bool
=
False
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
n_fft
=
None
,
max_freq
=
None
,
target_sample_rate
=
16000
,
use_dB_normalization
=
True
,
target_dB
=-
20
,
dither
=
1.0
):
self
.
_spectrum_type
=
spectrum_type
# mfcc and fbank using `feat_dim`
self
.
_feat_dim
=
feat_dim
# mfcc and fbank using `delta-delta`
self
.
_delta_delta
=
delta_delta
self
.
_stride_ms
=
stride_ms
self
.
_window_ms
=
window_ms
self
.
_max_freq
=
max_freq
self
.
_target_sample_rate
=
target_sample_rate
self
.
_use_dB_normalization
=
use_dB_normalization
self
.
_target_dB
=
target_dB
self
.
_fft_point
=
n_fft
self
.
_dither
=
dither
def
featurize
(
self
,
audio_segment
,
allow_downsampling
=
True
,
allow_upsampling
=
True
):
"""Extract audio features from AudioSegment or SpeechSegment.
:param audio_segment: Audio/speech segment to extract features from.
:type audio_segment: AudioSegment|SpeechSegment
:param allow_downsampling: Whether to allow audio downsampling before
featurizing.
:type allow_downsampling: bool
:param allow_upsampling: Whether to allow audio upsampling before
featurizing.
:type allow_upsampling: bool
:return: Spectrogram audio feature in 2darray.
:rtype: ndarray
:raises ValueError: If audio sample rate is not supported.
"""
# upsampling or downsampling
if
((
audio_segment
.
sample_rate
>
self
.
_target_sample_rate
and
allow_downsampling
)
or
(
audio_segment
.
sample_rate
<
self
.
_target_sample_rate
and
allow_upsampling
)):
audio_segment
.
resample
(
self
.
_target_sample_rate
)
if
audio_segment
.
sample_rate
!=
self
.
_target_sample_rate
:
raise
ValueError
(
"Audio sample rate is not supported. "
"Turn allow_downsampling or allow up_sampling on."
)
# decibel normalization
if
self
.
_use_dB_normalization
:
audio_segment
.
normalize
(
target_db
=
self
.
_target_dB
)
# extract spectrogram
return
self
.
_compute_specgram
(
audio_segment
)
@
property
def
stride_ms
(
self
):
return
self
.
_stride_ms
@
property
def
feature_size
(
self
):
"""audio feature size"""
feat_dim
=
0
if
self
.
_spectrum_type
==
'linear'
:
fft_point
=
self
.
_window_ms
if
self
.
_fft_point
is
None
else
self
.
_fft_point
feat_dim
=
int
(
fft_point
*
(
self
.
_target_sample_rate
/
1000
)
/
2
+
1
)
elif
self
.
_spectrum_type
==
'mfcc'
:
# mfcc, delta, delta-delta
feat_dim
=
int
(
self
.
_feat_dim
*
3
)
if
self
.
_delta_delta
else
int
(
self
.
_feat_dim
)
elif
self
.
_spectrum_type
==
'fbank'
:
# fbank, delta, delta-delta
feat_dim
=
int
(
self
.
_feat_dim
*
3
)
if
self
.
_delta_delta
else
int
(
self
.
_feat_dim
)
else
:
raise
ValueError
(
"Unknown spectrum_type %s. "
"Supported values: linear."
%
self
.
_spectrum_type
)
return
feat_dim
def
_compute_specgram
(
self
,
audio_segment
):
"""Extract various audio features."""
sample_rate
=
audio_segment
.
sample_rate
if
self
.
_spectrum_type
==
'linear'
:
samples
=
audio_segment
.
samples
return
self
.
_compute_linear_specgram
(
samples
,
sample_rate
,
stride_ms
=
self
.
_stride_ms
,
window_ms
=
self
.
_window_ms
,
max_freq
=
self
.
_max_freq
)
elif
self
.
_spectrum_type
==
'mfcc'
:
samples
=
audio_segment
.
to
(
'int16'
)
return
self
.
_compute_mfcc
(
samples
,
sample_rate
,
feat_dim
=
self
.
_feat_dim
,
stride_ms
=
self
.
_stride_ms
,
window_ms
=
self
.
_window_ms
,
max_freq
=
self
.
_max_freq
,
dither
=
self
.
_dither
,
delta_delta
=
self
.
_delta_delta
)
elif
self
.
_spectrum_type
==
'fbank'
:
samples
=
audio_segment
.
to
(
'int16'
)
return
self
.
_compute_fbank
(
samples
,
sample_rate
,
feat_dim
=
self
.
_feat_dim
,
stride_ms
=
self
.
_stride_ms
,
window_ms
=
self
.
_window_ms
,
max_freq
=
self
.
_max_freq
,
dither
=
self
.
_dither
,
delta_delta
=
self
.
_delta_delta
)
else
:
raise
ValueError
(
"Unknown spectrum_type %s. "
"Supported values: linear."
%
self
.
_spectrum_type
)
def
_specgram_real
(
self
,
samples
,
window_size
,
stride_size
,
sample_rate
):
"""Compute the spectrogram for samples from a real signal."""
# extract strided windows
truncate_size
=
(
len
(
samples
)
-
window_size
)
%
stride_size
samples
=
samples
[:
len
(
samples
)
-
truncate_size
]
nshape
=
(
window_size
,
(
len
(
samples
)
-
window_size
)
//
stride_size
+
1
)
nstrides
=
(
samples
.
strides
[
0
],
samples
.
strides
[
0
]
*
stride_size
)
windows
=
np
.
lib
.
stride_tricks
.
as_strided
(
samples
,
shape
=
nshape
,
strides
=
nstrides
)
assert
np
.
all
(
windows
[:,
1
]
==
samples
[
stride_size
:(
stride_size
+
window_size
)])
# window weighting, squared Fast Fourier Transform (fft), scaling
weighting
=
np
.
hanning
(
window_size
)[:,
None
]
# https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html
fft
=
np
.
fft
.
rfft
(
windows
*
weighting
,
n
=
None
,
axis
=
0
)
fft
=
np
.
absolute
(
fft
)
fft
=
fft
**
2
scale
=
np
.
sum
(
weighting
**
2
)
*
sample_rate
fft
[
1
:
-
1
,
:]
*=
(
2.0
/
scale
)
fft
[(
0
,
-
1
),
:]
/=
scale
# prepare fft frequency list
freqs
=
float
(
sample_rate
)
/
window_size
*
np
.
arange
(
fft
.
shape
[
0
])
return
fft
,
freqs
def
_compute_linear_specgram
(
self
,
samples
,
sample_rate
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
max_freq
=
None
,
eps
=
1e-14
):
"""Compute the linear spectrogram from FFT energy.
Args:
samples ([type]): [description]
sample_rate ([type]): [description]
stride_ms (float, optional): [description]. Defaults to 10.0.
window_ms (float, optional): [description]. Defaults to 20.0.
max_freq ([type], optional): [description]. Defaults to None.
eps ([type], optional): [description]. Defaults to 1e-14.
Raises:
ValueError: [description]
ValueError: [description]
Returns:
np.ndarray: log spectrogram, (time, freq)
"""
if
max_freq
is
None
:
max_freq
=
sample_rate
/
2
if
max_freq
>
sample_rate
/
2
:
raise
ValueError
(
"max_freq must not be greater than half of "
"sample rate."
)
if
stride_ms
>
window_ms
:
raise
ValueError
(
"Stride size must not be greater than "
"window size."
)
stride_size
=
int
(
0.001
*
sample_rate
*
stride_ms
)
window_size
=
int
(
0.001
*
sample_rate
*
window_ms
)
specgram
,
freqs
=
self
.
_specgram_real
(
samples
,
window_size
=
window_size
,
stride_size
=
stride_size
,
sample_rate
=
sample_rate
)
ind
=
np
.
where
(
freqs
<=
max_freq
)[
0
][
-
1
]
+
1
# (freq, time)
spec
=
np
.
log
(
specgram
[:
ind
,
:]
+
eps
)
return
np
.
transpose
(
spec
)
def
_concat_delta_delta
(
self
,
feat
):
"""append delat, delta-delta feature.
Args:
feat (np.ndarray): (T, D)
Returns:
np.ndarray: feat with delta-delta, (T, 3*D)
"""
# Deltas
d_feat
=
delta
(
feat
,
2
)
# Deltas-Deltas
dd_feat
=
delta
(
feat
,
2
)
# concat above three features
concat_feat
=
np
.
concatenate
((
feat
,
d_feat
,
dd_feat
),
axis
=
1
)
return
concat_feat
def
_compute_mfcc
(
self
,
samples
,
sample_rate
,
feat_dim
=
13
,
stride_ms
=
10.0
,
window_ms
=
25.0
,
max_freq
=
None
,
dither
=
1.0
,
delta_delta
=
True
):
"""Compute mfcc from samples.
Args:
samples (np.ndarray, np.int16): the audio signal from which to compute features.
sample_rate (float): the sample rate of the signal we are working with, in Hz.
feat_dim (int): the number of cepstrum to return, default 13.
stride_ms (float, optional): stride length in ms. Defaults to 10.0.
window_ms (float, optional): window length in ms. Defaults to 25.0.
max_freq ([type], optional): highest band edge of mel filters. In Hz, default is samplerate/2. Defaults to None.
delta_delta (bool, optional): Whether with delta delta. Defaults to False.
Raises:
ValueError: max_freq > samplerate/2
ValueError: stride_ms > window_ms
Returns:
np.ndarray: mfcc feature, (D, T).
"""
if
max_freq
is
None
:
max_freq
=
sample_rate
/
2
if
max_freq
>
sample_rate
/
2
:
raise
ValueError
(
"max_freq must not be greater than half of "
"sample rate."
)
if
stride_ms
>
window_ms
:
raise
ValueError
(
"Stride size must not be greater than "
"window size."
)
# compute the 13 cepstral coefficients, and the first one is replaced
# by log(frame energy), (T, D)
mfcc_feat
=
mfcc
(
signal
=
samples
,
samplerate
=
sample_rate
,
winlen
=
0.001
*
window_ms
,
winstep
=
0.001
*
stride_ms
,
numcep
=
feat_dim
,
nfilt
=
23
,
nfft
=
512
,
lowfreq
=
20
,
highfreq
=
max_freq
,
dither
=
dither
,
remove_dc_offset
=
True
,
preemph
=
0.97
,
ceplifter
=
22
,
useEnergy
=
True
,
winfunc
=
'povey'
)
if
delta_delta
:
mfcc_feat
=
self
.
_concat_delta_delta
(
mfcc_feat
)
return
mfcc_feat
def
_compute_fbank
(
self
,
samples
,
sample_rate
,
feat_dim
=
40
,
stride_ms
=
10.0
,
window_ms
=
25.0
,
max_freq
=
None
,
dither
=
1.0
,
delta_delta
=
False
):
"""Compute logfbank from samples.
Args:
samples (np.ndarray, np.int16): the audio signal from which to compute features. Should be an N*1 array
sample_rate (float): the sample rate of the signal we are working with, in Hz.
feat_dim (int): the number of cepstrum to return, default 13.
stride_ms (float, optional): stride length in ms. Defaults to 10.0.
window_ms (float, optional): window length in ms. Defaults to 20.0.
max_freq (float, optional): highest band edge of mel filters. In Hz, default is samplerate/2. Defaults to None.
delta_delta (bool, optional): Whether with delta delta. Defaults to False.
Raises:
ValueError: max_freq > samplerate/2
ValueError: stride_ms > window_ms
Returns:
np.ndarray: mfcc feature, (D, T).
"""
if
max_freq
is
None
:
max_freq
=
sample_rate
/
2
if
max_freq
>
sample_rate
/
2
:
raise
ValueError
(
"max_freq must not be greater than half of "
"sample rate."
)
if
stride_ms
>
window_ms
:
raise
ValueError
(
"Stride size must not be greater than "
"window size."
)
# (T, D)
fbank_feat
=
logfbank
(
signal
=
samples
,
samplerate
=
sample_rate
,
winlen
=
0.001
*
window_ms
,
winstep
=
0.001
*
stride_ms
,
nfilt
=
feat_dim
,
nfft
=
512
,
lowfreq
=
20
,
highfreq
=
max_freq
,
dither
=
dither
,
remove_dc_offset
=
True
,
preemph
=
0.97
,
wintype
=
'povey'
)
if
delta_delta
:
fbank_feat
=
self
.
_concat_delta_delta
(
fbank_feat
)
return
fbank_feat
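To make the feature_size arithmetic concrete, a small sketch using the defaults above (16 kHz target sample rate, 20 ms window, n_fft unset):

# linear: a 20 ms window at 16 kHz is 320 samples, and the real FFT yields
# 320 / 2 + 1 = 161 frequency bins
linear = AudioFeaturizer(spectrum_type='linear')
assert linear.feature_size == 161

# mfcc with delta-delta: 13 cepstra + 13 deltas + 13 delta-deltas
cepstral = AudioFeaturizer(spectrum_type='mfcc', feat_dim=13,
                           delta_delta=True)
assert cepstral.feature_size == 39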
deepspeech2/s2t/frontend/featurizer/speech_featurizer.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the speech featurizer class."""
from
.audio_featurizer
import
AudioFeaturizer
from
.text_featurizer
import
TextFeaturizer
class
SpeechFeaturizer
():
def
__init__
(
self
,
unit_type
,
vocab_filepath
,
spm_model_prefix
=
None
,
spectrum_type
=
'linear'
,
feat_dim
=
None
,
delta_delta
=
False
,
stride_ms
=
10.0
,
window_ms
=
20.0
,
n_fft
=
None
,
max_freq
=
None
,
target_sample_rate
=
16000
,
use_dB_normalization
=
True
,
target_dB
=-
20
,
dither
=
1.0
,
maskctc
=
False
):
self
.
stride_ms
=
stride_ms
self
.
window_ms
=
window_ms
self
.
audio_feature
=
AudioFeaturizer
(
spectrum_type
=
spectrum_type
,
feat_dim
=
feat_dim
,
delta_delta
=
delta_delta
,
stride_ms
=
stride_ms
,
window_ms
=
window_ms
,
n_fft
=
n_fft
,
max_freq
=
max_freq
,
target_sample_rate
=
target_sample_rate
,
use_dB_normalization
=
use_dB_normalization
,
target_dB
=
target_dB
,
dither
=
dither
)
self
.
feature_size
=
self
.
audio_feature
.
feature_size
self
.
text_feature
=
TextFeaturizer
(
unit_type
=
unit_type
,
vocab
=
vocab_filepath
,
spm_model_prefix
=
spm_model_prefix
,
maskctc
=
maskctc
)
self
.
vocab_size
=
self
.
text_feature
.
vocab_size
def
featurize
(
self
,
speech_segment
,
keep_transcription_text
):
"""Extract features for speech segment.
1. For audio parts, extract the audio features.
2. For transcript parts, keep the original text or convert text string
to a list of token indices in char-level.
Args:
speech_segment (SpeechSegment): Speech segment to extract features from.
keep_transcription_text (bool): True, keep transcript text, False, token ids
Returns:
tuple: 1) spectrogram audio feature in 2darray, 2) list oftoken indices.
"""
spec_feature
=
self
.
audio_feature
.
featurize
(
speech_segment
)
if
keep_transcription_text
:
return
spec_feature
,
speech_segment
.
transcript
if
speech_segment
.
has_token
:
text_ids
=
speech_segment
.
token_ids
else
:
text_ids
=
self
.
text_feature
.
featurize
(
speech_segment
.
transcript
)
return
spec_feature
,
text_ids
def
text_featurize
(
self
,
text
,
keep_transcription_text
):
"""Extract features for speech segment.
1. For audio parts, extract the audio features.
2. For transcript parts, keep the original text or convert text string
to a list of token indices in char-level.
Args:
text (str): text.
keep_transcription_text (bool): True, keep transcript text, False, token ids
Returns:
(str|List[int]): text, or list of token indices.
"""
if
keep_transcription_text
:
return
text
text_ids
=
self
.
text_feature
.
featurize
(
text
)
return
text_ids
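A sketch of the combined featurizer; the vocabulary path is hypothetical, and speech_segment is assumed to be a SpeechSegment from this commit's frontend code.

featurizer = SpeechFeaturizer(
    unit_type='char',
    vocab_filepath='data/vocab.txt',  # hypothetical path
    spectrum_type='linear')

feat, text_ids = featurizer.featurize(
    speech_segment, keep_transcription_text=False)
# feat: (T, featurizer.feature_size) log spectrogram
# text_ids: List[int] token indices from the char-level TextFeaturizer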
deepspeech2/s2t/frontend/featurizer/text_featurizer.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
from
typing
import
Union
import
sentencepiece
as
spm
from
..utility
import
BLANK
,
EOS
,
MASKCTC
,
SOS
,
SPACE
,
UNK
,
load_dict
__all__
=
[
"TextFeaturizer"
]
class
TextFeaturizer
():
def
__init__
(
self
,
unit_type
,
vocab
,
spm_model_prefix
=
None
,
maskctc
=
False
):
"""Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
a list of token indices. Note that the token indexing order follows the
given vocabulary file.
Args:
unit_type (str): unit type, e.g. char, word, spm
vocab Option[str, list]: Filepath to load vocabulary for token indices conversion, or vocab list.
spm_model_prefix (str, optional): spm model prefix. Defaults to None.
"""
assert
unit_type
in
(
'char'
,
'spm'
,
'word'
)
self
.
unit_type
=
unit_type
self
.
unk
=
UNK
self
.
maskctc
=
maskctc
if
vocab
:
self
.
vocab_dict
,
self
.
_id2token
,
self
.
vocab_list
,
\
self
.
unk_id
,
self
.
eos_id
,
\
self
.
blank_id
=
self
.
_load_vocabulary_from_file
(
vocab
,
maskctc
)
self
.
vocab_size
=
len
(
self
.
vocab_list
)
else
:
print
(
"TextFeaturizer: not have vocab file or vocab list."
)
if
unit_type
==
'spm'
:
spm_model
=
spm_model_prefix
+
'.model'
self
.
sp
=
spm
.
SentencePieceProcessor
()
self
.
sp
.
Load
(
spm_model
)
def
tokenize
(
self
,
text
,
replace_space
=
True
):
if
self
.
unit_type
==
'char'
:
tokens
=
self
.
char_tokenize
(
text
,
replace_space
)
elif
self
.
unit_type
==
'word'
:
tokens
=
self
.
word_tokenize
(
text
)
else
:
# spm
tokens
=
self
.
spm_tokenize
(
text
)
return
tokens
def
detokenize
(
self
,
tokens
):
if
self
.
unit_type
==
'char'
:
text
=
self
.
char_detokenize
(
tokens
)
elif
self
.
unit_type
==
'word'
:
text
=
self
.
word_detokenize
(
tokens
)
else
:
# spm
text
=
self
.
spm_detokenize
(
tokens
)
return
text
def
featurize
(
self
,
text
):
"""Convert text string to a list of token indices.
Args:
text (str): Text to process.
Returns:
List[int]: List of token indices.
"""
tokens
=
self
.
tokenize
(
text
)
ids
=
[]
for
token
in
tokens
:
if
token
not
in
self
.
vocab_dict
:
token
=
self
.
unk
ids
.
append
(
self
.
vocab_dict
[
token
])
return
ids
def
defeaturize
(
self
,
idxs
):
"""Convert a list of token indices to text string,
ignore index after eos_id.
Args:
idxs (List[int]): List of token indices.
Returns:
str: Text.
"""
tokens
=
[]
for
idx
in
idxs
:
if
idx
==
self
.
eos_id
:
break
tokens
.
append
(
self
.
_id2token
[
idx
])
text
=
self
.
detokenize
(
tokens
)
return
text
def
char_tokenize
(
self
,
text
,
replace_space
=
True
):
"""Character tokenizer.
Args:
text (str): text string.
replace_space (bool): False only used by build_vocab.py.
Returns:
List[str]: tokens.
"""
text
=
text
.
strip
()
if
replace_space
:
text_list
=
[
SPACE
if
item
==
" "
else
item
for
item
in
list
(
text
)]
else
:
text_list
=
list
(
text
)
return
text_list
def
char_detokenize
(
self
,
tokens
):
"""Character detokenizer.
Args:
tokens (List[str]): tokens.
Returns:
str: text string.
"""
tokens
=
[
t
.
replace
(
SPACE
,
" "
)
for
t
in
tokens
]
return
""
.
join
(
tokens
)
def
word_tokenize
(
self
,
text
):
"""Word tokenizer, separate by <space>."""
return
text
.
strip
().
split
()
def
word_detokenize
(
self
,
tokens
):
"""Word detokenizer, separate by <space>."""
return
" "
.
join
(
tokens
)
def
spm_tokenize
(
self
,
text
):
"""spm tokenize.
Args:
text (str): text string.
Returns:
List[str]: sentence pieces str code
"""
stats
=
{
"num_empty"
:
0
,
"num_filtered"
:
0
}
def
valid
(
line
):
return
True
def
encode
(
l
):
return
self
.
sp
.
EncodeAsPieces
(
l
)
def
encode_line
(
line
):
line
=
line
.
strip
()
if
len
(
line
)
>
0
:
line
=
encode
(
line
)
if
valid
(
line
):
return
line
else
:
stats
[
"num_filtered"
]
+=
1
else
:
stats
[
"num_empty"
]
+=
1
return
None
enc_line
=
encode_line
(
text
)
return
enc_line
def
spm_detokenize
(
self
,
tokens
,
input_format
=
'piece'
):
"""spm detokenize.
Args:
ids (List[str]): tokens.
Returns:
str: text
"""
if
input_format
==
"piece"
:
def
decode
(
l
):
return
""
.
join
(
self
.
sp
.
DecodePieces
(
l
))
elif
input_format
==
"id"
:
def
decode
(
l
):
return
""
.
join
(
self
.
sp
.
DecodeIds
(
l
))
return
decode
(
tokens
)
def
_load_vocabulary_from_file
(
self
,
vocab
:
Union
[
str
,
list
],
maskctc
:
bool
):
"""Load vocabulary from file."""
if
isinstance
(
vocab
,
list
):
vocab_list
=
vocab
else
:
vocab_list
=
load_dict
(
vocab
,
maskctc
)
assert
vocab_list
is
not
None
id2token
=
dict
(
[(
idx
,
token
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
token2id
=
dict
(
[(
token
,
idx
)
for
(
idx
,
token
)
in
enumerate
(
vocab_list
)])
blank_id
=
vocab_list
.
index
(
BLANK
)
if
BLANK
in
vocab_list
else
-
1
unk_id
=
vocab_list
.
index
(
UNK
)
if
UNK
in
vocab_list
else
-
1
eos_id
=
vocab_list
.
index
(
EOS
)
if
EOS
in
vocab_list
else
-
1
return
token2id
,
id2token
,
vocab_list
,
unk_id
,
eos_id
,
blank_id
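A round-trip sketch of char-level featurization, passing the vocabulary as an inline list; it assumes the ..utility constants are the conventional '<blank>', '<unk>', '<space>' and '<eos>' strings.

# assumed to match the BLANK/UNK/SPACE/EOS constants in ..utility
vocab = ['<blank>', '<unk>', 'a', 'b', 'c', '<space>', '<eos>']
tf = TextFeaturizer(unit_type='char', vocab=vocab)

ids = tf.featurize("ab c")   # [2, 3, 5, 4]: the space maps to <space>
text = tf.defeaturize(ids)   # "ab c": stops at eos_id, restores spaces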
deepspeech2/s2t/frontend/normalizer.py  0 → 100644
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains feature normalizers."""
import
numpy
as
np
from
.utility
import
load_cmvn
__all__
=
[
"FeatureNormalizer"
]
class
FeatureNormalizer
(
object
):
def
__init__
(
self
,
mean_std_filepath
):
mean_std
=
mean_std_filepath
self
.
_read_mean_std_from_file
(
mean_std
)
def
apply
(
self
,
features
):
"""Normalize features to be of zero mean and unit stddev.
:param features: Input features to be normalized.
:type features: ndarray, shape (T, D)
:param eps: added to stddev to provide numerical stablibity.
:type eps: float
:return: Normalized features.
:rtype: ndarray
"""
return
(
features
-
self
.
_mean
)
*
self
.
_istd
def
_read_mean_std_from_file
(
self
,
mean_std
,
eps
=
1e-20
):
"""Load mean and std from file."""
if
isinstance
(
mean_std
,
list
):
mean
=
mean_std
[
0
][
'cmvn_stats'
][
'mean'
]
istd
=
mean_std
[
0
][
'cmvn_stats'
][
'istd'
]
else
:
filetype
=
mean_std
.
split
(
"."
)[
-
1
]
mean
,
istd
=
load_cmvn
(
mean_std
,
filetype
=
filetype
)
self
.
_mean
=
np
.
expand_dims
(
mean
,
axis
=
0
)
self
.
_istd
=
np
.
expand_dims
(
istd
,
axis
=
0
)
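A self-contained sketch of what apply computes, feeding the constructor hypothetical in-memory CMVN stats in the list format it accepts.

import numpy as np

# hypothetical CMVN stats for a 3-dimensional feature
stats = [{'cmvn_stats': {'mean': [1.0, 2.0, 3.0],
                         'istd': [0.5, 0.5, 0.5]}}]
normalizer = FeatureNormalizer(stats)

feats = np.array([[1.0, 2.0, 3.0],
                  [3.0, 4.0, 5.0]])  # (T, D) = (2, 3)
print(normalizer.apply(feats))
# row 0 maps to zeros; row 1 to (feats - mean) * istd = [1.0, 1.0, 1.0]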