Commit cdab2875 authored by SWHL

Update files
*.pth
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.pytest_cache
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
# *.manifest
# *.spec
*.res
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
#idea
.vs
.vscode
.idea
#models
*.ttf
*.ttc
*.bin
*.mapping
*.xml
*.pdiparams
*.pdiparams.info
*.pdmodel
.DS_Store
#### Inference code for models trained with PaddleSpeech
- Project source: [PaddleSpeech/s2t](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell/asr0)
- Runtime environment: Linux | Python 3.7 | CPU | does not depend on Paddle
#### Usage
1. Download the entire `python/base_paddlespeech` directory
2. Install the dependencies
    - Batch install
```bash
pip install -r requirements.txt -i https://pypi.douban.com/simple/
# CentOS
sudo yum install libsndfile
```
3. Download the `resources` model files into `base_paddlespeech`
    - `resources` download link: [Google Drive](https://drive.google.com/file/d/1MWmKxsfCNQyQ5CPlaYxJKnYfIIC5OO5L/view?usp=sharing)
    - Download the language model file → [download link](https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm) and put it in the `base_paddlespeech/resources/models/language_model` directory
    - The final directory structure is shown below; please verify it yourself:
```text
base_paddlespeech
├── deepspeech2
│ ├── infer.py
│ ├── __init__.py
│ └── s2t
│ ├── decoders
│ ├── deepspeech2.py
│ ├── frontend
│ ├── io
│ ├── modules
│ ├── __pycache__
│ ├── transform
│ └── utils
├── main.py
├── requirements.txt
├── resources
│ └── models
│ ├── asr0_deepspeech2_online_aishell_ckpt_0.2.0.onnx
│ ├── language_model
│ │ └── zh_giga.no_cna_cmn.prune01244.klm
│ └── model.yaml
└── test_wav
└── zh.wav
```
4. Run `python main.py`
5. The output should look like this:
```text
checking the audio file format......
The sample rate is 16000
The audio file format is right
Preprocess audio_file:/da2/SWHL/test_wav/zh.wav
audio feat shape: (1, 498, 161)
ASR Result: 我认为跑步最重要的就是给我们带来了身体健康
```
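A minimal sketch of what `main.py` might look like (the constructor arguments and `resources` paths below are assumptions based on the directory layout in step 3, not a verified copy of the actual script):
```python
# Hypothetical main.py sketch; paths follow the resources layout from step 3.
from deepspeech2 import ASRExecutor

asr = ASRExecutor(
    sample_rate=16000,
    config_path="resources/models/model.yaml",
    onnx_path="resources/models/asr0_deepspeech2_online_aishell_ckpt_0.2.0.onnx",
    lan_model_path="resources/models/language_model/zh_giga.no_cna_cmn.prune01244.klm",
)
print("ASR Result:", asr("test_wav/zh.wav"))
```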
#### Converting the model to ONNX
```bash
model_dir="pretrained_models/deepspeech2online_aishell-zh-16k/asr0_deepspeech2_online_aishell_ckpt_0.1.1.model.tar/exp/deepspeech2_online/checkpoints"
pdmodel="avg_1.jit.pdmodel"
params_file="avg_1.jit.pdiparams"
save_onnx="pretrained_models/onnx/asr0_deepspeech2_online_aishell_ckpt_0.1.1.onnx"
paddle2onnx --model_dir ${model_dir} \
--model_filename ${pdmodel} \
--params_filename ${params_file} \
--save_file ${save_onnx} \
--opset_version 12
```
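To sanity-check the exported model, you can load it with `onnxruntime` and list its inputs (a quick verification sketch, not part of the original conversion steps):
```python
import onnxruntime as ort

# Load the exported model and print its input signature; the streaming encoder
# should expose the audio features, their length, and two RNN init-state inputs.
sess = ort.InferenceSession("pretrained_models/onnx/asr0_deepspeech2_online_aishell_ckpt_0.1.1.onnx")
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)
```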
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @File: __init__.py
# @Author: SWHL
# @Contact: liekkaskono@163.com
from .infer import ASRExecutor
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from collections import OrderedDict
from typing import Union
import numpy as np
import soundfile
from yacs.config import CfgNode
from .s2t.deepspeech2 import DeepSpeech2ModelOnline
from .s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from .s2t.io.collator import SpeechCollator
from .s2t.utils.utility import UpdateConfig
class ASRExecutor(object):
def __init__(self,
sample_rate: int = 16000,
config_path: os.PathLike = None,
onnx_path: os.PathLike = None,
decode_method: str = 'attention_rescoring',
lan_model_path=None):
self.sample_rate = sample_rate
self.config_path = config_path
self.onnx_path = onnx_path
self.decode_method = decode_method
self.lan_model_path = lan_model_path
self._inputs = OrderedDict()
self._outputs = OrderedDict()
self.config_path = os.path.abspath(self.config_path)
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.config_path)))
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.config_path)
with UpdateConfig(self.config):
self.vocab = self.config.vocab_filepath
self.config.decode.lang_model_path = self.lan_model_path
self.collate_fn_test = SpeechCollator.from_config(self.config)
self.text_feature = TextFeaturizer(unit_type=self.config.unit_type,
vocab=self.vocab)
self.model = DeepSpeech2ModelOnline(encoder_onnx_path=self.onnx_path)
def __call__(self, audio_file, force_yes: bool = False):
audio_file = os.path.abspath(audio_file)
if not self._check(audio_file, self.sample_rate, force_yes):
sys.exit(-1)
self.preprocess(audio_file)
res = self.infer()
return res
def preprocess(self, input: Union[str, os.PathLike]):
audio_file = input
if isinstance(audio_file, (str, os.PathLike)):
print("Preprocess audio_file:" + audio_file)
# Get the object for feature extraction
audio, _ = self.collate_fn_test.process_utterance(
audio_file=audio_file, transcript=" ")
audio_len = audio.shape[0]
audio = audio[np.newaxis, ...]
self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len
print(f"audio feat shape: {audio.shape}")
def infer(self):
"""
Model inference and result stored in self.output.
"""
cfg = self.config.decode
audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"]
decode_batch_size = audio.shape[0]
self.model.decoder.init_decoder(
decode_batch_size, self.text_feature.vocab_list,
cfg.decoding_method, cfg.lang_model_path, cfg.alpha, cfg.beta,
cfg.beam_size, cfg.cutoff_prob, cfg.cutoff_top_n,
cfg.num_proc_bsearch)
result_transcripts = self.model.decode(audio, audio_len)
self.model.decoder.del_decoder()
return result_transcripts[0]
def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000:
print(
"invalid sample rate, please input --sr 8000 or --sr 16000")
return False
if isinstance(audio_file, (str, os.PathLike)):
if not os.path.isfile(audio_file):
print("Please input the right audio file path")
return False
print("checking the audio file format......")
try:
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
except Exception as e:
print(
"can not open the audio file, please check the audio file format is 'wav'. \n \
you can try to use sox to change the file format.\n \
For example: \n \
sample rate: 16k \n \
sox input_audio.xx --rate 16k --bits 16 --channels 1 output_audio.wav \n \
sample rate: 8k \n \
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
")
return False
print("The sample rate is %d" % audio_sample_rate)
if audio_sample_rate != self.sample_rate:
print("The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \
If the result does not meet your expectations,\n \
Please input the 16k 16 bit 1 channel wav file. \
".format(self.sample_rate, self.sample_rate))
if force_yes is False:
while (True):
print(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
)
content = input("Input(Y/N):")
if content.strip() in ["Y", "y", "yes", "Yes"]:
print("change the sampele rate, channel to 16k and 1 channel")
break
elif content.strip() in ["N", "n", "no", "No"]:
print("Exit the program")
exit(1)
else:
print("Not regular input, please input again")
self.change_format = True
else:
print("The audio file format is right")
self.change_format = False
return True
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .swig_wrapper import ctc_beam_search_decoding
from .swig_wrapper import ctc_beam_search_decoding_batch
from .swig_wrapper import ctc_greedy_decoding
from .swig_wrapper import CTCBeamSearchDecoder
from .swig_wrapper import Scorer
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
import multiprocessing
from itertools import groupby
from math import log
import numpy as np
def ctc_greedy_decoder(probs_seq, vocabulary):
"""CTC greedy (best path) decoder.
    The path consisting of the most probable tokens is further post-processed
    to remove consecutive repetitions and all blanks.
:param probs_seq: 2-D list of probabilities over the vocabulary for each
character. Each element is a list of float probabilities
for one character.
:type probs_seq: list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:return: Decoding result string.
    :rtype: str
"""
# dimension verification
for probs in probs_seq:
if not len(probs) == len(vocabulary) + 1:
raise ValueError("probs_seq dimension mismatchedd with vocabulary")
# argmax to get the best index for each time step
max_index_list = list(np.array(probs_seq).argmax(axis=1))
# remove consecutive duplicate indexes
index_list = [index_group[0] for index_group in groupby(max_index_list)]
# remove blank indexes
blank_index = len(vocabulary)
index_list = [index for index in index_list if index != blank_index]
# convert index list to string
return ''.join([vocabulary[index] for index in index_list])
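# Illustrative example (not from the original file): with vocabulary ['a', 'b']
# the blank index is 2, so the framewise argmax path [0, 0, 2, 1] collapses
# repeats and drops blanks, yielding 'ab':
#     ctc_greedy_decoder([[0.6, 0.3, 0.1],
#                         [0.6, 0.3, 0.1],
#                         [0.1, 0.2, 0.7],
#                         [0.1, 0.8, 0.1]], ['a', 'b'])  # -> 'ab'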
def ctc_beam_search_decoder(probs_seq,
beam_size,
vocabulary,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
nproc=False):
"""CTC Beam search decoder.
    It utilizes beam search to approximately select the top-best decoding
    labels and returns results in descending order of probability.
    The implementation is based on Prefix Beam Search
    (https://arxiv.org/abs/1408.2873), with the unclear parts redesigned.
    Two important modifications: 1) in the iterative computation of
    probabilities, the assignment operation is changed to accumulation,
    because one prefix may come from different paths; 2) the condition
    "if l^+ not in A_prev then" after the probability computation is dropped,
    as it is hard to understand and seems unnecessary.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param beam_size: Width for beam search.
:type beam_size: int
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
    :type ext_scoring_func: callable
    :param nproc: Whether the decoder is used with multiprocessing.
:type nproc: bool
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
# dimension check
for prob_list in probs_seq:
if not len(prob_list) == len(vocabulary) + 1:
raise ValueError("The shape of prob_seq does not match with the "
"shape of the vocabulary.")
# blank_id assign
blank_id = len(vocabulary)
# If the decoder called in the multiprocesses, then use the global scorer
# instantiated in ctc_beam_search_decoder_batch().
if nproc is True:
global ext_nproc_scorer
ext_scoring_func = ext_nproc_scorer
# initialize
# prefix_set_prev: the set containing selected prefixes
# probs_b_prev: prefixes' probability ending with blank in previous step
# probs_nb_prev: prefixes' probability ending with non-blank in previous
# step
prefix_set_prev = {'\t': 1.0}
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
# extend prefix in loop
for time_step in range(len(probs_seq)):
# prefix_set_next: the set containing candidate prefixes
# probs_b_cur: prefixes' probability ending with blank in current step
# probs_nb_cur: prefixes' probability ending with non-blank in current
# step
prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}
prob_idx = list(enumerate(probs_seq[time_step]))
cutoff_len = len(prob_idx)
# If pruning is enabled
if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len, cum_prob = 0, 0.0
for i in range(len(prob_idx)):
cum_prob += prob_idx[i][1]
cutoff_len += 1
if cum_prob >= cutoff_prob:
break
cutoff_len = min(cutoff_len, cutoff_top_n)
prob_idx = prob_idx[0:cutoff_len]
for l in prefix_set_prev:
if l not in prefix_set_next:
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
            # extend prefix by traversing prob_idx
for index in range(cutoff_len):
c, prob_c = prob_idx[index][0], prob_idx[index][1]
if c == blank_id:
probs_b_cur[l] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
last_char = l[-1]
new_char = vocabulary[c]
l_plus = l + new_char
if l_plus not in prefix_set_next:
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
if new_char == last_char:
probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
probs_nb_cur[l] += prob_c * probs_nb_prev[l]
elif new_char == ' ':
if (ext_scoring_func is None) or (len(l) == 1):
score = 1.0
else:
prefix = l[1:]
score = ext_scoring_func(prefix)
probs_nb_cur[l_plus] += score * prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
else:
probs_nb_cur[l_plus] += prob_c * (
probs_b_prev[l] + probs_nb_prev[l])
# add l_plus into prefix_set_next
prefix_set_next[l_plus] = probs_nb_cur[
l_plus] + probs_b_cur[l_plus]
# add l into prefix_set_next
prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
# update probs
probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur
# store top beam_size prefixes
prefix_set_prev = sorted(
prefix_set_next.items(), key=lambda asd: asd[1], reverse=True)
if beam_size < len(prefix_set_prev):
prefix_set_prev = prefix_set_prev[:beam_size]
prefix_set_prev = dict(prefix_set_prev)
beam_result = []
for seq, prob in prefix_set_prev.items():
if prob > 0.0 and len(seq) > 1:
result = seq[1:]
# score last word by external scorer
if (ext_scoring_func is not None) and (result[-1] != ' '):
prob = prob * ext_scoring_func(result)
log_prob = log(prob)
beam_result.append((log_prob, result))
else:
beam_result.append((float('-inf'), ''))
# output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
return beam_result
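# Illustrative example (not from the original file): decode a single utterance's
# probability matrix with a beam of 4; the result is a list of
# (log_prob, sentence) tuples sorted best-first, optionally rescored by
# ext_scoring_func:
#     beams = ctc_beam_search_decoder(probs_seq, beam_size=4, vocabulary=['a', 'b'])
#     best_log_prob, best_sentence = beams[0]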
def ctc_beam_search_decoder_batch(probs_split,
beam_size,
vocabulary,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None):
"""CTC beam search decoder using multiple processes.
    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
"""
if not num_processes > 0:
raise ValueError("Number of processes must be positive!")
    # use a global variable to pass the external scorer to the beam search decoder
global ext_nproc_scorer
ext_nproc_scorer = ext_scoring_func
nproc = True
pool = multiprocessing.Pool(processes=num_processes)
results = []
for i, probs_list in enumerate(probs_split):
args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n,
None, nproc)
results.append(pool.apply_async(ctc_beam_search_decoder, args))
pool.close()
pool.join()
beam_search_results = [result.get() for result in results]
return beam_search_results
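# Illustrative example (not from the original file): decode two utterances in
# parallel; each element of the returned list is the beam list for one utterance:
#     batch_beams = ctc_beam_search_decoder_batch(
#         [probs_utt1, probs_utt2], beam_size=10, vocabulary=vocab, num_processes=2)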
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper for various CTC decoders in SWIG."""
import paddlespeech_ctcdecoders
class Scorer(paddlespeech_ctcdecoders.Scorer):
"""Wrapper for Scorer.
:param alpha: Parameter associated with language model. Don't use
language model when alpha = 0.
:type alpha: float
:param beta: Parameter associated with word count. Don't use word
count when beta = 0.
:type beta: float
    :param model_path: Path to the language model.
    :type model_path: str
:param vocabulary: Vocabulary list.
:type vocabulary: list
"""
def __init__(self, alpha, beta, model_path, vocabulary):
paddlespeech_ctcdecoders.Scorer.__init__(self, alpha, beta, model_path,
vocabulary)
def ctc_greedy_decoding(probs_seq, vocabulary, blank_id):
"""Wrapper for ctc best path decodeing function in swig.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:return: Decoding result string.
:rtype: str
"""
result = paddlespeech_ctcdecoders.ctc_greedy_decoding(probs_seq.tolist(),
vocabulary, blank_id)
return result
def ctc_beam_search_decoding(probs_seq,
vocabulary,
beam_size,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the CTC Beam Search Decoding function.
:param probs_seq: 2-D list of probability distributions over each time
step, with each element being a list of normalized
probabilities over vocabulary and blank.
:type probs_seq: 2-D list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param cutoff_prob: Cutoff probability in pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
    :type ext_scoring_func: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding(
probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
ext_scoring_func, blank_id)
beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
return beam_results
def ctc_beam_search_decoding_batch(probs_split,
vocabulary,
beam_size,
num_processes,
cutoff_prob=1.0,
cutoff_top_n=40,
ext_scoring_func=None,
blank_id=0):
"""Wrapper for the batched CTC beam search decodeing batch function.
:param probs_seq: 3-D list with each element as an instance of 2-D list
of probabilities used by ctc_beam_search_decoder().
:type probs_seq: 3-D list
:param vocabulary: Vocabulary list.
:type vocabulary: list
:param beam_size: Width for beam search.
:type beam_size: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param cutoff_prob: Cutoff probability in vocabulary pruning,
default 1.0, no pruning.
:type cutoff_prob: float
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
characters with highest probs in vocabulary will be
used in beam search, default 40.
:type cutoff_top_n: int
:param num_processes: Number of parallel processes.
:type num_processes: int
:param ext_scoring_func: External scoring function for
partially decoded sentence, e.g. word count
or language model.
:type external_scoring_function: callable
:return: List of tuples of log probability and sentence as decoding
results, in descending order of the probability.
:rtype: list
"""
probs_split = [probs_seq.tolist() for probs_seq in probs_split]
batch_beam_results = paddlespeech_ctcdecoders.ctc_beam_search_decoding_batch(
probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
cutoff_top_n, ext_scoring_func, blank_id)
batch_beam_results = [[(res[0], res[1]) for res in beam_results]
for beam_results in batch_beam_results]
return batch_beam_results
class CTCBeamSearchDecoder(paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch):
"""Wrapper for CtcBeamSearchDecoderBatch.
Args:
vocab_list (list): Vocabulary list.
beam_size (int): Width for beam search.
num_processes (int): Number of parallel processes.
        cutoff_prob (float): Cutoff probability in vocabulary pruning,
            default 1.0, no pruning.
        cutoff_top_n (int): Cutoff number in pruning, only top cutoff_top_n
            characters with highest probs in vocabulary will be
            used in beam search, default 40.
        ext_scorer (Scorer): External scorer for partially decoded sentence,
            e.g. word count or language model.
"""
def __init__(self, vocab_list, batch_size, beam_size, num_processes,
cutoff_prob, cutoff_top_n, _ext_scorer, blank_id):
paddlespeech_ctcdecoders.CtcBeamSearchDecoderBatch.__init__(
self, vocab_list, batch_size, beam_size, num_processes, cutoff_prob,
cutoff_top_n, _ext_scorer, blank_id)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import numpy as np
__all__ = ["end_detect", "parse_hypothesis", "add_results_to_json"]
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
"""End detection.
described in Eq. (50) of S. Watanabe et al
"Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
:param ended_hyps: dict
:param i: int
:param M: int
:param D_end: float
:return: bool
"""
if len(ended_hyps) == 0:
return False
count = 0
best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0]
for m in range(M):
# get ended_hyps with their length is i - m
hyp_length = i - m
hyps_same_length = [
x for x in ended_hyps if len(x["yseq"]) == hyp_length
]
if len(hyps_same_length) > 0:
best_hyp_same_length = sorted(
hyps_same_length, key=lambda x: x["score"], reverse=True)[0]
if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
count += 1
if count == M:
return True
else:
return False
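# Illustrative note (not from the original file): with the defaults M=3 and
# D_end = log(exp(-10)) = -10, decoding at step i ends once, for each of the
# last M hypothesis lengths, the best ended hypothesis of that length scores
# more than 10 log-probability units below the overall best ended hypothesis.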
# * ------------------ recognition related ------------------ *
def parse_hypothesis(hyp, char_list):
"""Parse hypothesis.
Args:
hyp (list[dict[str, Any]]): Recognition hypothesis.
char_list (list[str]): List of characters.
Returns:
tuple(str, str, str, float)
"""
# remove sos and get results
tokenid_as_list = list(map(int, hyp["yseq"][1:]))
token_as_list = [char_list[idx] for idx in tokenid_as_list]
score = float(hyp["score"])
# convert to string
tokenid = " ".join([str(idx) for idx in tokenid_as_list])
token = " ".join(token_as_list)
text = "".join(token_as_list).replace("<space>", " ")
return text, token, tokenid, score
def add_results_to_json(js, nbest_hyps, char_list):
"""Add N-best results to json.
Args:
js (dict[str, Any]): Groundtruth utterance dict.
nbest_hyps_sd (list[dict[str, Any]]):
List of hypothesis for multi_speakers: nutts x nspkrs.
char_list (list[str]): List of characters.
Returns:
dict[str, Any]: N-best results added utterance dict.
"""
# copy old json info
new_js = dict()
new_js["utt2spk"] = js["utt2spk"]
new_js["output"] = []
for n, hyp in enumerate(nbest_hyps, 1):
# parse hypothesis
rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp,
char_list)
# copy ground-truth
if len(js["output"]) > 0:
out_dic = dict(js["output"][0].items())
else:
# for no reference case (e.g., speech translation)
out_dic = {"name": ""}
# update name
out_dic["name"] += "[%d]" % n
# add recognition results
out_dic["rec_text"] = rec_text
out_dic["rec_token"] = rec_token
out_dic["rec_tokenid"] = rec_tokenid
out_dic["score"] = score
# add to list of N-best result dicts
new_js["output"].append(out_dic)
# show 1-best result
if n == 1:
if "text" in out_dic.keys():
print("groundtruth: %s" % out_dic["text"])
print("prediction : %s" % out_dic["rec_text"])
return new_js
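# Illustrative example (not from the original file): js carries the ground-truth
# utterance info and nbest_hyps the recognizer output; the 1-best result is also
# printed as a side effect:
#     new_js = add_results_to_json({"utt2spk": "spk1", "output": []},
#                                  nbest_hyps, char_list)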
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deepspeech2 ASR Online Model"""
import numpy as np
import onnxruntime as ort
from .modules.ctc import CTCDecoder
class DeepSpeech2ModelOnline(object):
def __init__(self, encoder_onnx_path):
self.encoder_sess = ort.InferenceSession(encoder_onnx_path)
self.decoder = CTCDecoder()
def decode(self, audio, audio_len):
onnx_inputs_name = self.encoder_sess.get_inputs()
ort_inputs = {
onnx_inputs_name[0].name: np.array(audio).astype(np.float32),
onnx_inputs_name[1].name: np.array([audio_len]).astype(np.int64),
onnx_inputs_name[2].name: np.zeros([5, 1, 1024]).astype(np.float32),
onnx_inputs_name[3].name: np.zeros([5, 1, 1024]).astype(np.float32)
}
ort_outputs = self.encoder_sess.run(None, ort_inputs)
probs, eouts_len, _, _ = ort_outputs
batch_size = probs.shape[0]
self.decoder.reset_decoder(batch_size=batch_size)
self.decoder.next(probs, eouts_len)
trans_best, trans_beam = self.decoder.decode()
return trans_best
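# Illustrative example (not from the original file; the model path follows the
# README's resources layout). Note that decode() assumes the CTC decoder has
# already been configured via model.decoder.init_decoder(...), as done in
# ASRExecutor.infer:
#     model = DeepSpeech2ModelOnline(
#         "resources/models/asr0_deepspeech2_online_aishell_ckpt_0.2.0.onnx")
#     text = model.decode(audio, audio_len)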
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the audio segment class."""
import copy
import io
import random
import re
import struct
import numpy as np
import resampy
import soundfile
from scipy import signal
from .utility import convert_samples_from_float32
from .utility import convert_samples_to_float32
from .utility import subfile_from_tar
class AudioSegment():
"""Monaural audio segment abstraction.
:param samples: Audio samples [num_samples x num_channels].
:type samples: ndarray.float32
:param sample_rate: Audio sample rate.
:type sample_rate: int
:raises TypeError: If the sample data type is not float or int.
"""
def __init__(self, samples, sample_rate):
"""Create audio segment from samples.
        Samples are converted to float32 internally, with ints scaled to [-1, 1].
"""
self._samples = self._convert_samples_to_float32(samples)
self._sample_rate = sample_rate
if self._samples.ndim >= 2:
self._samples = np.mean(self._samples, 1)
def __eq__(self, other):
"""Return whether two objects are equal."""
if not isinstance(other, type(self)):
return False
if self._sample_rate != other._sample_rate:
return False
if self._samples.shape != other._samples.shape:
return False
if np.any(self.samples != other._samples):
return False
return True
def __ne__(self, other):
"""Return whether two objects are unequal."""
return not self.__eq__(other)
def __str__(self):
"""Return human-readable representation of segment."""
return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
"rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate,
self.duration, self.rms_db))
@classmethod
def from_file(cls, file, infos=None):
"""Create audio segment from audio file.
Args:
filepath (str|file): Filepath or file object to audio file.
infos (TarLocalData, optional): tar2obj and tar2infos. Defaults to None.
Returns:
AudioSegment: Audio segment instance.
"""
if isinstance(file, str) and re.findall(r".seqbin_\d+$", file):
return cls.from_sequence_file(file)
elif isinstance(file, str) and file.startswith('tar:'):
return cls.from_file(subfile_from_tar(file, infos))
else:
samples, sample_rate = soundfile.read(file, dtype='float32')
return cls(samples, sample_rate)
@classmethod
def slice_from_file(cls, file, start=None, end=None):
"""Loads a small section of an audio without having to load
the entire file into the memory which can be incredibly wasteful.
:param file: Input audio filepath or file object.
:type file: str|file
:param start: Start time in seconds. If start is negative, it wraps
around from the end. If not provided, this function
reads from the very beginning.
:type start: float
:param end: End time in seconds. If end is negative, it wraps around
                    from the end. If not provided, the default behavior is
to read to the end of the file.
:type end: float
:return: AudioSegment instance of the specified slice of the input
audio file.
:rtype: AudioSegment
:raise ValueError: If start or end is incorrectly set, e.g. out of
bounds in time.
"""
sndfile = soundfile.SoundFile(file)
sample_rate = sndfile.samplerate
duration = float(len(sndfile)) / sample_rate
start = 0. if start is None else start
end = duration if end is None else end
if start < 0.0:
start += duration
if end < 0.0:
end += duration
if start < 0.0:
raise ValueError("The slice start position (%f s) is out of "
"bounds." % start)
if end < 0.0:
raise ValueError("The slice end position (%f s) is out of bounds." %
end)
if start > end:
raise ValueError("The slice start position (%f s) is later than "
"the slice end position (%f s)." % (start, end))
if end > duration:
raise ValueError("The slice end position (%f s) is out of bounds "
"(> %f s)" % (end, duration))
start_frame = int(start * sample_rate)
end_frame = int(end * sample_rate)
sndfile.seek(start_frame)
data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
return cls(data, sample_rate)
@classmethod
def from_sequence_file(cls, filepath):
"""Create audio segment from sequence file. Sequence file is a binary
file containing a collection of multiple audio files, with several
header bytes in the head indicating the offsets of each audio byte data
chunk.
The format is:
4 bytes (int, version),
4 bytes (int, num of utterance),
4 bytes (int, bytes per header),
[bytes_per_header*(num_utterance+1)] bytes (offsets for each audio),
audio_bytes_data_of_1st_utterance,
audio_bytes_data_of_2nd_utterance,
......
Sequence file name must end with ".seqbin". And the filename of the 5th
utterance's audio file in sequence file "xxx.seqbin" must be
"xxx.seqbin_5", with "5" indicating the utterance index within this
sequence file (starting from 1).
:param filepath: Filepath of sequence file.
:type filepath: str
:return: Audio segment instance.
:rtype: AudioSegment
"""
# parse filepath
matches = re.match(r"(.+\.seqbin)_(\d+)", filepath)
if matches is None:
raise IOError("File type of %s is not supported" % filepath)
filename = matches.group(1)
fileno = int(matches.group(2))
# read headers
        f = io.open(filename, mode='rb')  # binary mode must not take an encoding argument
version = f.read(4)
num_utterances = struct.unpack("i", f.read(4))[0]
bytes_per_header = struct.unpack("i", f.read(4))[0]
header_bytes = f.read(bytes_per_header * (num_utterances + 1))
header = [
struct.unpack("i", header_bytes[bytes_per_header * i:
bytes_per_header * (i + 1)])[0]
for i in range(num_utterances + 1)
]
# read audio bytes
f.seek(header[fileno - 1])
audio_bytes = f.read(header[fileno] - header[fileno - 1])
f.close()
# create audio segment
try:
return cls.from_bytes(audio_bytes)
except Exception as e:
samples = np.frombuffer(audio_bytes, dtype='int16')
return cls(samples=samples, sample_rate=8000)
@classmethod
def from_bytes(cls, bytes):
"""Create audio segment from a byte string containing audio samples.
:param bytes: Byte string containing audio samples.
:type bytes: str
:return: Audio segment instance.
:rtype: AudioSegment
"""
samples, sample_rate = soundfile.read(
io.BytesIO(bytes), dtype='float32')
return cls(samples, sample_rate)
@classmethod
def from_pcm(cls, samples, sample_rate):
"""Create audio segment from a byte string containing audio samples.
:param samples: Audio samples [num_samples x num_channels].
:type samples: numpy.ndarray
:param sample_rate: Audio sample rate.
:type sample_rate: int
:return: Audio segment instance.
:rtype: AudioSegment
"""
return cls(samples, sample_rate)
@classmethod
def concatenate(cls, *segments):
"""Concatenate an arbitrary number of audio segments together.
:param *segments: Input audio segments to be concatenated.
:type *segments: tuple of AudioSegment
:return: Audio segment instance as concatenating results.
:rtype: AudioSegment
:raises ValueError: If the number of segments is zero, or if the
sample_rate of any segments does not match.
:raises TypeError: If any segment is not AudioSegment instance.
"""
# Perform basic sanity-checks.
if len(segments) == 0:
raise ValueError("No audio segments are given to concatenate.")
sample_rate = segments[0]._sample_rate
for seg in segments:
if sample_rate != seg._sample_rate:
raise ValueError("Can't concatenate segments with "
"different sample rates")
if not isinstance(seg, cls):
raise TypeError("Only audio segments of the same type "
"can be concatenated.")
samples = np.concatenate([seg.samples for seg in segments])
return cls(samples, sample_rate)
@classmethod
def make_silence(cls, duration, sample_rate):
"""Creates a silent audio segment of the given duration and sample rate.
:param duration: Length of silence in seconds.
:type duration: float
:param sample_rate: Sample rate.
:type sample_rate: float
:return: Silent AudioSegment instance of the given duration.
:rtype: AudioSegment
"""
samples = np.zeros(int(duration * sample_rate))
return cls(samples, sample_rate)
def to_wav_file(self, filepath, dtype='float32'):
"""Save audio segment to disk as wav file.
:param filepath: WAV filepath or file object to save the
audio segment.
:type filepath: str|file
:param dtype: Subtype for audio file. Options: 'int16', 'int32',
'float32', 'float64'. Default is 'float32'.
:type dtype: str
:raises TypeError: If dtype is not supported.
"""
samples = self._convert_samples_from_float32(self._samples, dtype)
subtype_map = {
'int16': 'PCM_16',
'int32': 'PCM_32',
'float32': 'FLOAT',
'float64': 'DOUBLE'
}
soundfile.write(
filepath,
samples,
self._sample_rate,
format='WAV',
subtype=subtype_map[dtype])
def superimpose(self, other):
"""Add samples from another segment to those of this segment
(sample-wise addition, not segment concatenation).
Note that this is an in-place transformation.
:param other: Segment containing samples to be added in.
:type other: AudioSegments
:raise TypeError: If type of two segments don't match.
:raise ValueError: If the sample rates of the two segments are not
equal, or if the lengths of segments don't match.
"""
        if not isinstance(other, type(self)):
raise TypeError("Cannot add segments of different types: %s "
"and %s." % (type(self), type(other)))
if self._sample_rate != other._sample_rate:
raise ValueError("Sample rates must match to add segments.")
if len(self._samples) != len(other._samples):
raise ValueError("Segment lengths must match to add segments.")
self._samples += other._samples
def to_bytes(self, dtype='float32'):
"""Create a byte string containing the audio content.
:param dtype: Data type for export samples. Options: 'int16', 'int32',
'float32', 'float64'. Default is 'float32'.
:type dtype: str
:return: Byte string containing audio content.
:rtype: str
"""
samples = self._convert_samples_from_float32(self._samples, dtype)
        return samples.tobytes()
def to(self, dtype='int16'):
"""Create a `dtype` audio content.
        :param dtype: Data type for export samples. Options: 'int16', 'int32',
                      'float32', 'float64'. Default is 'int16'.
        :type dtype: str
        :return: np.ndarray containing `dtype` audio content.
        :rtype: numpy.ndarray
"""
samples = self._convert_samples_from_float32(self._samples, dtype)
return samples
def gain_db(self, gain):
"""Apply gain in decibels to samples.
Note that this is an in-place transformation.
:param gain: Gain in decibels to apply to samples.
:type gain: float|1darray
"""
self._samples *= 10.**(gain / 20.)
def change_speed(self, speed_rate):
"""Change the audio speed by linear interpolation.
Note that this is an in-place transformation.
:param speed_rate: Rate of speed change:
speed_rate > 1.0, speed up the audio;
speed_rate = 1.0, unchanged;
speed_rate < 1.0, slow down the audio;
speed_rate <= 0.0, not allowed, raise ValueError.
:type speed_rate: float
:raises ValueError: If speed_rate <= 0.0.
"""
if speed_rate == 1.0:
return
if speed_rate <= 0:
raise ValueError("speed_rate should be greater than zero.")
# numpy
# old_length = self._samples.shape[0]
# new_length = int(old_length / speed_rate)
# old_indices = np.arange(old_length)
# new_indices = np.linspace(start=0, stop=old_length, num=new_length)
# self._samples = np.interp(new_indices, old_indices, self._samples)
# sox, slow
try:
import soxbindings as sox
except ImportError:
try:
from paddlespeech.s2t.utils import dynamic_pip_install
package = "sox"
dynamic_pip_install.install(package)
package = "soxbindings"
dynamic_pip_install.install(package)
import soxbindings as sox
except Exception:
raise RuntimeError(
"Can not install soxbindings on your system.")
tfm = sox.Transformer()
tfm.set_globals(multithread=False)
tfm.speed(speed_rate)
self._samples = tfm.build_array(
input_array=self._samples,
sample_rate_in=self._sample_rate).squeeze(-1).astype(
np.float32).copy()
def normalize(self, target_db=-20, max_gain_db=300.0):
"""Normalize audio to be of the desired RMS value in decibels.
Note that this is an in-place transformation.
:param target_db: Target RMS value in decibels. This value should be
less than 0.0 as 0.0 is full-scale audio.
:type target_db: float
:param max_gain_db: Max amount of gain in dB that can be applied for
normalization. This is to prevent nans when
attempting to normalize a signal consisting of
all zeros.
:type max_gain_db: float
:raises ValueError: If the required gain to normalize the segment to
the target_db value exceeds max_gain_db.
"""
gain = target_db - self.rms_db
if gain > max_gain_db:
raise ValueError(
"Unable to normalize segment to %f dB because the "
"the probable gain have exceeds max_gain_db (%f dB)" %
(target_db, max_gain_db))
self.gain_db(min(max_gain_db, target_db - self.rms_db))
def normalize_online_bayesian(self,
target_db,
prior_db,
prior_samples,
startup_delay=0.0):
"""Normalize audio using a production-compatible online/causal
algorithm. This uses an exponential likelihood and gamma prior to
make online estimates of the RMS even when there are very few samples.
Note that this is an in-place transformation.
:param target_db: Target RMS value in decibels.
        :type target_db: float
:param prior_db: Prior RMS estimate in decibels.
:type prior_db: float
:param prior_samples: Prior strength in number of samples.
:type prior_samples: float
:param startup_delay: Default 0.0s. If provided, this function will
accrue statistics for the first startup_delay
seconds before applying online normalization.
:type startup_delay: float
"""
# Estimate total RMS online.
startup_sample_idx = min(self.num_samples - 1,
int(self.sample_rate * startup_delay))
prior_mean_squared = 10.**(prior_db / 10.)
prior_sum_of_squares = prior_mean_squared * prior_samples
cumsum_of_squares = np.cumsum(self.samples**2)
sample_count = np.arange(self.num_samples) + 1
if startup_sample_idx > 0:
cumsum_of_squares[:startup_sample_idx] = \
cumsum_of_squares[startup_sample_idx]
sample_count[:startup_sample_idx] = \
sample_count[startup_sample_idx]
mean_squared_estimate = ((cumsum_of_squares + prior_sum_of_squares) /
(sample_count + prior_samples))
rms_estimate_db = 10 * np.log10(mean_squared_estimate)
# Compute required time-varying gain.
gain_db = target_db - rms_estimate_db
self.gain_db(gain_db)
def resample(self, target_sample_rate, filter='kaiser_best'):
"""Resample the audio to a target sample rate.
Note that this is an in-place transformation.
:param target_sample_rate: Target sample rate.
:type target_sample_rate: int
        :param filter: The resampling filter to use, one of {'kaiser_best',
'kaiser_fast'}.
:type filter: str
"""
self._samples = resampy.resample(
self.samples, self.sample_rate, target_sample_rate, filter=filter)
self._sample_rate = target_sample_rate
def pad_silence(self, duration, sides='both'):
"""Pad this audio sample with a period of silence.
Note that this is an in-place transformation.
:param duration: Length of silence in seconds to pad.
:type duration: float
:param sides: Position for padding:
'beginning' - adds silence in the beginning;
'end' - adds silence in the end;
'both' - adds silence in both the beginning and the end.
:type sides: str
:raises ValueError: If sides is not supported.
"""
if duration == 0.0:
return self
cls = type(self)
silence = self.make_silence(duration, self._sample_rate)
if sides == "beginning":
padded = cls.concatenate(silence, self)
elif sides == "end":
padded = cls.concatenate(self, silence)
elif sides == "both":
padded = cls.concatenate(silence, self, silence)
else:
raise ValueError("Unknown value for the sides %s" % sides)
self._samples = padded._samples
def shift(self, shift_ms):
"""Shift the audio in time. If `shift_ms` is positive, shift with time
advance; if negative, shift with time delay. Silence are padded to
keep the duration unchanged.
Note that this is an in-place transformation.
:param shift_ms: Shift time in millseconds. If positive, shift with
time advance; if negative; shift with time delay.
:type shift_ms: float
:raises ValueError: If shift_ms is longer than audio duration.
"""
if abs(shift_ms) / 1000.0 > self.duration:
raise ValueError("Absolute value of shift_ms should be smaller "
"than audio duration.")
shift_samples = int(shift_ms * self._sample_rate / 1000)
if shift_samples > 0:
# time advance
self._samples[:-shift_samples] = self._samples[shift_samples:]
self._samples[-shift_samples:] = 0
elif shift_samples < 0:
# time delay
self._samples[-shift_samples:] = self._samples[:shift_samples]
self._samples[:-shift_samples] = 0
def subsegment(self, start_sec=None, end_sec=None):
"""Cut the AudioSegment between given boundaries.
Note that this is an in-place transformation.
:param start_sec: Beginning of subsegment in seconds.
:type start_sec: float
:param end_sec: End of subsegment in seconds.
:type end_sec: float
:raise ValueError: If start_sec or end_sec is incorrectly set, e.g. out
of bounds in time.
"""
start_sec = 0.0 if start_sec is None else start_sec
end_sec = self.duration if end_sec is None else end_sec
if start_sec < 0.0:
start_sec = self.duration + start_sec
if end_sec < 0.0:
end_sec = self.duration + end_sec
if start_sec < 0.0:
raise ValueError("The slice start position (%f s) is out of "
"bounds." % start_sec)
if end_sec < 0.0:
raise ValueError("The slice end position (%f s) is out of bounds." %
end_sec)
if start_sec > end_sec:
raise ValueError("The slice start position (%f s) is later than "
"the end position (%f s)." % (start_sec, end_sec))
if end_sec > self.duration:
raise ValueError("The slice end position (%f s) is out of bounds "
"(> %f s)" % (end_sec, self.duration))
start_sample = int(round(start_sec * self._sample_rate))
end_sample = int(round(end_sec * self._sample_rate))
self._samples = self._samples[start_sample:end_sample]
def random_subsegment(self, subsegment_length, rng=None):
"""Cut the specified length of the audiosegment randomly.
Note that this is an in-place transformation.
:param subsegment_length: Subsegment length in seconds.
:type subsegment_length: float
:param rng: Random number generator state.
:type rng: random.Random
        :raises ValueError: If the length of subsegment is greater than
                            the original segment.
"""
rng = random.Random() if rng is None else rng
if subsegment_length > self.duration:
raise ValueError("Length of subsegment must not be greater "
"than original segment.")
start_time = rng.uniform(0.0, self.duration - subsegment_length)
self.subsegment(start_time, start_time + subsegment_length)
def convolve(self, impulse_segment, allow_resample=False):
"""Convolve this audio segment with the given impulse segment.
Note that this is an in-place transformation.
:param impulse_segment: Impulse response segments.
:type impulse_segment: AudioSegment
:param allow_resample: Indicates whether resampling is allowed when
the impulse_segment has a different sample
rate from this signal.
:type allow_resample: bool
        :raises ValueError: If the sample rates of the two audio segments do
                            not match when resampling is not allowed.
"""
if allow_resample and self.sample_rate != impulse_segment.sample_rate:
impulse_segment.resample(self.sample_rate)
if self.sample_rate != impulse_segment.sample_rate:
raise ValueError("Impulse segment's sample rate (%d Hz) is not "
"equal to base signal sample rate (%d Hz)." %
(impulse_segment.sample_rate, self.sample_rate))
samples = signal.fftconvolve(self.samples, impulse_segment.samples,
"full")
self._samples = samples
def convolve_and_normalize(self, impulse_segment, allow_resample=False):
"""Convolve and normalize the resulting audio segment so that it
has the same average power as the input signal.
Note that this is an in-place transformation.
:param impulse_segment: Impulse response segments.
:type impulse_segment: AudioSegment
:param allow_resample: Indicates whether resampling is allowed when
the impulse_segment has a different sample
rate from this signal.
:type allow_resample: bool
"""
target_db = self.rms_db
self.convolve(impulse_segment, allow_resample=allow_resample)
self.normalize(target_db)
def add_noise(self,
noise,
snr_dB,
allow_downsampling=False,
max_gain_db=300.0,
rng=None):
"""Add the given noise segment at a specific signal-to-noise ratio.
If the noise segment is longer than this segment, a random subsegment
of matching length is sampled from it and used instead.
Note that this is an in-place transformation.
:param noise: Noise signal to add.
:type noise: AudioSegment
:param snr_dB: Signal-to-Noise Ratio, in decibels.
:type snr_dB: float
:param allow_downsampling: Whether to allow the noise signal to be
downsampled to match the base signal sample
rate.
:type allow_downsampling: bool
:param max_gain_db: Maximum amount of gain to apply to noise signal
before adding it in. This is to prevent attempting
to apply infinite gain to a zero signal.
:type max_gain_db: float
:param rng: Random number generator state.
:type rng: None|random.Random
:raises ValueError: If the sample rate does not match between the two
audio segments when downsampling is not allowed, or
if the duration of noise segments is shorter than
original audio segments.
"""
rng = random.Random() if rng is None else rng
if allow_downsampling and noise.sample_rate > self.sample_rate:
            noise = copy.deepcopy(noise)
            noise.resample(self.sample_rate)  # resample() works in place and returns None
if noise.sample_rate != self.sample_rate:
raise ValueError("Noise sample rate (%d Hz) is not equal to base "
"signal sample rate (%d Hz)." % (noise.sample_rate,
self.sample_rate))
if noise.duration < self.duration:
raise ValueError("Noise signal (%f sec) must be at least as long as"
" base signal (%f sec)." %
(noise.duration, self.duration))
noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
noise_new = copy.deepcopy(noise)
noise_new.random_subsegment(self.duration, rng=rng)
noise_new.gain_db(noise_gain_db)
self.superimpose(noise_new)
@property
def samples(self):
"""Return audio samples.
:return: Audio samples.
:rtype: ndarray
"""
return self._samples.copy()
@property
def sample_rate(self):
"""Return audio sample rate.
:return: Audio sample rate.
:rtype: int
"""
return self._sample_rate
@property
def num_samples(self):
"""Return number of samples.
:return: Number of samples.
:rtype: int
"""
return self._samples.shape[0]
@property
def duration(self):
"""Return audio duration.
:return: Audio duration in seconds.
:rtype: float
"""
return self._samples.shape[0] / float(self._sample_rate)
@property
def rms_db(self):
"""Return root mean square energy of the audio in decibels.
:return: Root mean square energy in decibels.
:rtype: float
"""
        # mean square (power) => multiply by 10 instead of 20 when converting to dB
mean_square = np.mean(self._samples**2)
return 10 * np.log10(mean_square)
def _convert_samples_to_float32(self, samples):
"""Convert sample type to float32.
        Audio sample type is usually integer or floating-point.
Integers will be scaled to [-1, 1] in float32.
"""
return convert_samples_to_float32(samples)
def _convert_samples_from_float32(self, samples, dtype):
"""Convert sample type from float32 to dtype.
        Audio sample type is usually integer or floating-point. For integer
        types, float32 will be rescaled from [-1, 1] to the maximum range
        supported by the integer type.
        This is for writing an audio file.
"""
return convert_samples_from_float32(samples, dtype)
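# Illustrative example (not from the original file): typical use of AudioSegment,
# with the wav path taken from the README's test_wav layout:
#     seg = AudioSegment.from_file("test_wav/zh.wav")
#     seg.resample(16000)              # in-place resample to 16 kHz
#     seg.normalize(target_db=-20)     # in-place RMS normalization
#     seg.to_wav_file("zh_16k.wav", dtype="int16")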
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the data augmentation pipeline."""
import json
import logging
import os
from collections.abc import Sequence
from inspect import signature
import numpy as np

# logger is referenced in __call__ below; the original dump does not show where
# it comes from, so a module-level logger is assumed here
logger = logging.getLogger(__name__)
__all__ = ["AugmentationPipeline"]
class AugmentationPipeline():
"""Build a pre-processing pipeline with various augmentation models.Such a
data augmentation pipeline is oftern leveraged to augment the training
samples to make the model invariant to certain types of perturbations in the
real world, improving model's generalization ability.
The pipeline is built according the the augmentation configuration in json
string, e.g.
.. code-block::
[ {
"type": "noise",
"params": {"min_snr_dB": 10,
"max_snr_dB": 20,
"noise_manifest_path": "datasets/manifest.noise"},
"prob": 0.0
},
{
"type": "speed",
"params": {"min_speed_rate": 0.9,
"max_speed_rate": 1.1},
"prob": 1.0
},
{
"type": "shift",
"params": {"min_shift_ms": -5,
"max_shift_ms": 5},
"prob": 1.0
},
{
"type": "volume",
"params": {"min_gain_dBFS": -10,
"max_gain_dBFS": 10},
"prob": 0.0
},
{
"type": "bayesian_normal",
"params": {"target_db": -20,
"prior_db": -20,
"prior_samples": 100},
"prob": 0.0
}
]
    This example configuration registers five augmentors in the pipeline
    (noise, speed, shift, volume and bayesian_normal). "prob" indicates the
    probability that the corresponding augmentor takes effect. If "prob" is
    zero, the augmentor never takes effect.
Params:
preprocess_conf(str): Augmentation configuration in `json file` or `json string`.
random_seed(int): Random seed.
Raises:
        ValueError: If the augmentation JSON config is in an incorrect format.
"""
SPEC_TYPES = {'specaug'}
def __init__(self, preprocess_conf: str, random_seed: int = 0):
self._rng = np.random.RandomState(random_seed)
self.conf = {'mode': 'sequential', 'process': []}
if preprocess_conf:
if os.path.isfile(preprocess_conf):
with open(preprocess_conf, 'r') as fin:
json_string = fin.read()
else:
json_string = preprocess_conf
process = json.loads(json_string)
self.conf['process'] += process
self._augmentors, self._rates = self._parse_pipeline_from('all')
self._audio_augmentors, self._audio_rates = self._parse_pipeline_from(
'audio')
self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
'feature')
def __call__(self, xs, uttid_list=None, **kwargs):
if not isinstance(xs, Sequence):
is_batch = False
xs = [xs]
else:
is_batch = True
if isinstance(uttid_list, str):
uttid_list = [uttid_list for _ in range(len(xs))]
if self.conf.get("mode", "sequential") == "sequential":
for idx, (func, rate) in enumerate(
zip(self._augmentors, self._rates), 0):
if self._rng.uniform(0., 1.) >= rate:
continue
# Derive only the args which the func has
try:
param = signature(func).parameters
            except ValueError:
                # Some functions (e.g. built-ins) do not support signature
                # introspection; fall back to passing no extra kwargs.
                param = {}
_kwargs = {k: v for k, v in kwargs.items() if k in param}
try:
if uttid_list is not None and "uttid" in param:
xs = [
func(x, u, **_kwargs)
for x, u in zip(xs, uttid_list)
]
else:
xs = [func(x, **_kwargs) for x in xs]
            except Exception:
                logger.fatal("Caught an exception from the {}th func: {}".format(
                    idx, func))
raise
else:
raise NotImplementedError(
"Not supporting mode={}".format(self.conf["mode"]))
if is_batch:
return xs
else:
return xs[0]
def transform_audio(self, audio_segment):
"""Run the pre-processing pipeline for data augmentation.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to process.
        :type audio_segment: AudioSegment|SpeechSegment
"""
for augmentor, rate in zip(self._audio_augmentors, self._audio_rates):
if self._rng.uniform(0., 1.) < rate:
augmentor.transform_audio(audio_segment)
def transform_feature(self, spec_segment):
"""spectrogram augmentation.
Args:
spec_segment (np.ndarray): audio feature, (D, T).
"""
for augmentor, rate in zip(self._spec_augmentors, self._spec_rates):
if self._rng.uniform(0., 1.) < rate:
spec_segment = augmentor.transform_feature(spec_segment)
return spec_segment
def _parse_pipeline_from(self, aug_type='all'):
"""Parse the config json to build a augmentation pipelien."""
assert aug_type in ('audio', 'feature', 'all'), aug_type
audio_confs = []
feature_confs = []
all_confs = []
for config in self.conf['process']:
all_confs.append(config)
if config["type"] in self.SPEC_TYPES:
feature_confs.append(config)
else:
audio_confs.append(config)
if aug_type == 'audio':
aug_confs = audio_confs
elif aug_type == 'feature':
aug_confs = feature_confs
elif aug_type == 'all':
aug_confs = all_confs
else:
raise ValueError(f"Not support: {aug_type}")
augmentors = [
self._get_augmentor(config["type"], config["params"])
for config in aug_confs
]
rates = [config["prob"] for config in aug_confs]
return augmentors, rates
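# A minimal smoke-test sketch for the pipeline above (an illustrative
# addition, assuming nothing beyond this module): an empty JSON config yields
# a no-op pipeline, so a single feature array passes through unchanged.
if __name__ == "__main__":
    pipeline = AugmentationPipeline(preprocess_conf="[]", random_seed=0)
    feat = np.zeros((498, 161), dtype=np.float32)  # (T, D) dummy feature
    out = pipeline(feat)  # non-batch input: returned as a single array
    assert out.shape == feat.shape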
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the abstract base class for augmentation models."""
from abc import ABCMeta
from abc import abstractmethod
class AugmentorBase(metaclass=ABCMeta):
    """Abstract base class for augmentation model (augmentor) classes.
    All augmentor classes should inherit from this class and implement the
    following abstract methods.
    """
@abstractmethod
def __init__(self):
pass
@abstractmethod
def __call__(self, xs):
raise NotImplementedError("AugmentorBase: Not impl __call__")
@abstractmethod
def transform_audio(self, audio_segment):
"""Adds various effects to the input audio segment. Such effects
will augment the training data to make the model invariant to certain
types of perturbations in the real world, improving model's
generalization ability.
Note that this is an in-place transformation.
:param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
"""
raise NotImplementedError("AugmentorBase: Not impl transform_audio")
@abstractmethod
def transform_feature(self, spec_segment):
"""Adds various effects to the input audo feature segment. Such effects
will augment the training data to make the model invariant to certain
types of time_mask or freq_mask in the real world, improving model's
generalization ability.
Args:
spec_segment (Spectrogram): Spectrogram segment to add effects to.
"""
raise NotImplementedError("AugmentorBase: Not impl transform_feature")
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .audio_featurizer import AudioFeaturizer # noqa: F401
from .speech_featurizer import SpeechFeaturizer
from .text_featurizer import TextFeaturizer
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the audio featurizer class."""
import numpy as np
from python_speech_features import delta
from python_speech_features import logfbank
from python_speech_features import mfcc
class AudioFeaturizer():
"""Audio featurizer, for extracting features from audio contents of
AudioSegment or SpeechSegment.
    Currently, it supports feature types of linear spectrogram, mfcc and fbank.
    :param spectrum_type: Spectrogram feature type. Options: 'linear', 'mfcc',
                          'fbank'.
    :type spectrum_type: str
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: When spectrum_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned; when spectrum_type is 'mfcc' or 'fbank',
                     max_freq is the highest band edge of the mel filters.
    :type max_freq: None|float
    :param target_sample_rate: Audio is resampled (if upsampling or
                               downsampling is allowed) to this rate before
                               extracting spectrogram features.
:type target_sample_rate: float
:param use_dB_normalization: Whether to normalize the audio to a certain
decibels before extracting the features.
:type use_dB_normalization: bool
:param target_dB: Target audio decibels for normalization.
:type target_dB: float
"""
def __init__(self,
spectrum_type: str = 'linear',
feat_dim: int = None,
delta_delta: bool = False,
stride_ms=10.0,
window_ms=20.0,
n_fft=None,
max_freq=None,
target_sample_rate=16000,
use_dB_normalization=True,
target_dB=-20,
dither=1.0):
self._spectrum_type = spectrum_type
# mfcc and fbank using `feat_dim`
self._feat_dim = feat_dim
# mfcc and fbank using `delta-delta`
self._delta_delta = delta_delta
self._stride_ms = stride_ms
self._window_ms = window_ms
self._max_freq = max_freq
self._target_sample_rate = target_sample_rate
self._use_dB_normalization = use_dB_normalization
self._target_dB = target_dB
self._fft_point = n_fft
self._dither = dither
def featurize(self,
audio_segment,
allow_downsampling=True,
allow_upsampling=True):
"""Extract audio features from AudioSegment or SpeechSegment.
:param audio_segment: Audio/speech segment to extract features from.
:type audio_segment: AudioSegment|SpeechSegment
:param allow_downsampling: Whether to allow audio downsampling before
featurizing.
:type allow_downsampling: bool
:param allow_upsampling: Whether to allow audio upsampling before
featurizing.
:type allow_upsampling: bool
:return: Spectrogram audio feature in 2darray.
:rtype: ndarray
:raises ValueError: If audio sample rate is not supported.
"""
# upsampling or downsampling
if ((audio_segment.sample_rate > self._target_sample_rate and
allow_downsampling) or
(audio_segment.sample_rate < self._target_sample_rate and
allow_upsampling)):
audio_segment.resample(self._target_sample_rate)
if audio_segment.sample_rate != self._target_sample_rate:
raise ValueError("Audio sample rate is not supported. "
"Turn allow_downsampling or allow up_sampling on.")
# decibel normalization
if self._use_dB_normalization:
audio_segment.normalize(target_db=self._target_dB)
# extract spectrogram
return self._compute_specgram(audio_segment)
@property
def stride_ms(self):
return self._stride_ms
@property
def feature_size(self):
"""audio feature size"""
feat_dim = 0
if self._spectrum_type == 'linear':
fft_point = self._window_ms if self._fft_point is None else self._fft_point
feat_dim = int(fft_point * (self._target_sample_rate / 1000) / 2 +
1)
elif self._spectrum_type == 'mfcc':
# mfcc, delta, delta-delta
feat_dim = int(self._feat_dim *
3) if self._delta_delta else int(self._feat_dim)
elif self._spectrum_type == 'fbank':
# fbank, delta, delta-delta
feat_dim = int(self._feat_dim *
3) if self._delta_delta else int(self._feat_dim)
else:
raise ValueError("Unknown spectrum_type %s. "
"Supported values: linear." % self._spectrum_type)
return feat_dim
def _compute_specgram(self, audio_segment):
"""Extract various audio features."""
sample_rate = audio_segment.sample_rate
if self._spectrum_type == 'linear':
samples = audio_segment.samples
return self._compute_linear_specgram(
samples,
sample_rate,
stride_ms=self._stride_ms,
window_ms=self._window_ms,
max_freq=self._max_freq)
elif self._spectrum_type == 'mfcc':
samples = audio_segment.to('int16')
return self._compute_mfcc(
samples,
sample_rate,
feat_dim=self._feat_dim,
stride_ms=self._stride_ms,
window_ms=self._window_ms,
max_freq=self._max_freq,
dither=self._dither,
delta_delta=self._delta_delta)
elif self._spectrum_type == 'fbank':
samples = audio_segment.to('int16')
return self._compute_fbank(
samples,
sample_rate,
feat_dim=self._feat_dim,
stride_ms=self._stride_ms,
window_ms=self._window_ms,
max_freq=self._max_freq,
dither=self._dither,
delta_delta=self._delta_delta)
else:
raise ValueError("Unknown spectrum_type %s. "
"Supported values: linear." % self._spectrum_type)
def _specgram_real(self, samples, window_size, stride_size, sample_rate):
"""Compute the spectrogram for samples from a real signal."""
# extract strided windows
truncate_size = (len(samples) - window_size) % stride_size
samples = samples[:len(samples) - truncate_size]
nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
nstrides = (samples.strides[0], samples.strides[0] * stride_size)
windows = np.lib.stride_tricks.as_strided(
samples, shape=nshape, strides=nstrides)
assert np.all(
windows[:, 1] == samples[stride_size:(stride_size + window_size)])
# window weighting, squared Fast Fourier Transform (fft), scaling
weighting = np.hanning(window_size)[:, None]
# https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html
fft = np.fft.rfft(windows * weighting, n=None, axis=0)
fft = np.absolute(fft)
fft = fft**2
scale = np.sum(weighting**2) * sample_rate
fft[1:-1, :] *= (2.0 / scale)
fft[(0, -1), :] /= scale
# prepare fft frequency list
freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
return fft, freqs
def _compute_linear_specgram(self,
samples,
sample_rate,
stride_ms=10.0,
window_ms=20.0,
max_freq=None,
eps=1e-14):
"""Compute the linear spectrogram from FFT energy.
Args:
samples ([type]): [description]
sample_rate ([type]): [description]
stride_ms (float, optional): [description]. Defaults to 10.0.
window_ms (float, optional): [description]. Defaults to 20.0.
max_freq ([type], optional): [description]. Defaults to None.
eps ([type], optional): [description]. Defaults to 1e-14.
Raises:
ValueError: [description]
ValueError: [description]
Returns:
np.ndarray: log spectrogram, (time, freq)
"""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
raise ValueError("max_freq must not be greater than half of "
"sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than "
"window size.")
stride_size = int(0.001 * sample_rate * stride_ms)
window_size = int(0.001 * sample_rate * window_ms)
specgram, freqs = self._specgram_real(
samples,
window_size=window_size,
stride_size=stride_size,
sample_rate=sample_rate)
ind = np.where(freqs <= max_freq)[0][-1] + 1
# (freq, time)
spec = np.log(specgram[:ind, :] + eps)
return np.transpose(spec)
    def _concat_delta_delta(self, feat):
        """Append delta and delta-delta features.
        Args:
            feat (np.ndarray): (T, D)
        Returns:
            np.ndarray: feat with delta and delta-delta, (T, 3*D)
        """
        # deltas
        d_feat = delta(feat, 2)
        # delta-deltas are computed from the deltas, not the raw features
        dd_feat = delta(d_feat, 2)
        # concatenate the three feature streams along the feature axis
        concat_feat = np.concatenate((feat, d_feat, dd_feat), axis=1)
        return concat_feat
def _compute_mfcc(self,
samples,
sample_rate,
feat_dim=13,
stride_ms=10.0,
window_ms=25.0,
max_freq=None,
dither=1.0,
delta_delta=True):
"""Compute mfcc from samples.
Args:
samples (np.ndarray, np.int16): the audio signal from which to compute features.
sample_rate (float): the sample rate of the signal we are working with, in Hz.
feat_dim (int): the number of cepstrum to return, default 13.
stride_ms (float, optional): stride length in ms. Defaults to 10.0.
window_ms (float, optional): window length in ms. Defaults to 25.0.
            max_freq (float, optional): highest band edge of mel filters, in Hz. Defaults to None (i.e. samplerate/2).
            dither (float, optional): dithering parameter passed to the feature extractor. Defaults to 1.0.
            delta_delta (bool, optional): whether to append delta and delta-delta features. Defaults to True.
        Raises:
            ValueError: max_freq > samplerate/2
            ValueError: stride_ms > window_ms
        Returns:
            np.ndarray: mfcc feature, (T, D), or (T, 3*D) when delta_delta is True.
"""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
raise ValueError("max_freq must not be greater than half of "
"sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than "
"window size.")
# compute the 13 cepstral coefficients, and the first one is replaced
# by log(frame energy), (T, D)
mfcc_feat = mfcc(
signal=samples,
samplerate=sample_rate,
winlen=0.001 * window_ms,
winstep=0.001 * stride_ms,
numcep=feat_dim,
nfilt=23,
nfft=512,
lowfreq=20,
highfreq=max_freq,
dither=dither,
remove_dc_offset=True,
preemph=0.97,
ceplifter=22,
useEnergy=True,
            wintype='povey')
if delta_delta:
mfcc_feat = self._concat_delta_delta(mfcc_feat)
return mfcc_feat
def _compute_fbank(self,
samples,
sample_rate,
feat_dim=40,
stride_ms=10.0,
window_ms=25.0,
max_freq=None,
dither=1.0,
delta_delta=False):
"""Compute logfbank from samples.
Args:
            samples (np.ndarray, np.int16): the audio signal from which to compute features. Should be an N*1 array.
            sample_rate (float): the sample rate of the signal we are working with, in Hz.
            feat_dim (int): the number of mel filter banks, default 40.
            stride_ms (float, optional): stride length in ms. Defaults to 10.0.
            window_ms (float, optional): window length in ms. Defaults to 25.0.
            max_freq (float, optional): highest band edge of mel filters, in Hz. Defaults to None (i.e. samplerate/2).
            dither (float, optional): dithering parameter passed to the feature extractor. Defaults to 1.0.
            delta_delta (bool, optional): whether to append delta and delta-delta features. Defaults to False.
        Raises:
            ValueError: max_freq > samplerate/2
            ValueError: stride_ms > window_ms
        Returns:
            np.ndarray: fbank feature, (T, D), or (T, 3*D) when delta_delta is True.
"""
if max_freq is None:
max_freq = sample_rate / 2
if max_freq > sample_rate / 2:
raise ValueError("max_freq must not be greater than half of "
"sample rate.")
if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than "
"window size.")
# (T, D)
fbank_feat = logfbank(
signal=samples,
samplerate=sample_rate,
winlen=0.001 * window_ms,
winstep=0.001 * stride_ms,
nfilt=feat_dim,
nfft=512,
lowfreq=20,
highfreq=max_freq,
dither=dither,
remove_dc_offset=True,
preemph=0.97,
wintype='povey')
if delta_delta:
fbank_feat = self._concat_delta_delta(fbank_feat)
return fbank_feat
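# A shape sanity-check sketch for the linear-spectrogram path (illustrative,
# not from the original file). It calls the private method directly on
# synthetic samples to avoid constructing an AudioSegment.
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    sr = 16000
    samples = rng.uniform(-1.0, 1.0, sr).astype(np.float32)  # 1 s of noise
    feat = AudioFeaturizer()._compute_linear_specgram(
        samples, sr, stride_ms=10.0, window_ms=20.0)
    # 20 ms windows with a 10 ms stride over 1 s -> 99 frames;
    # a 320-sample window -> 320 // 2 + 1 = 161 frequency bins.
    print(feat.shape)  # (99, 161)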
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the speech featurizer class."""
from .audio_featurizer import AudioFeaturizer
from .text_featurizer import TextFeaturizer
class SpeechFeaturizer():
    """Speech featurizer, combining an AudioFeaturizer for the audio content
    and a TextFeaturizer for the transcript of a SpeechSegment.
    """
def __init__(self,
unit_type,
vocab_filepath,
spm_model_prefix=None,
spectrum_type='linear',
feat_dim=None,
delta_delta=False,
stride_ms=10.0,
window_ms=20.0,
n_fft=None,
max_freq=None,
target_sample_rate=16000,
use_dB_normalization=True,
target_dB=-20,
dither=1.0,
maskctc=False):
self.stride_ms = stride_ms
self.window_ms = window_ms
self.audio_feature = AudioFeaturizer(
spectrum_type=spectrum_type,
feat_dim=feat_dim,
delta_delta=delta_delta,
stride_ms=stride_ms,
window_ms=window_ms,
n_fft=n_fft,
max_freq=max_freq,
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB,
dither=dither)
self.feature_size = self.audio_feature.feature_size
self.text_feature = TextFeaturizer(
unit_type=unit_type,
vocab=vocab_filepath,
spm_model_prefix=spm_model_prefix,
maskctc=maskctc)
self.vocab_size = self.text_feature.vocab_size
def featurize(self, speech_segment, keep_transcription_text):
"""Extract features for speech segment.
1. For audio parts, extract the audio features.
2. For transcript parts, keep the original text or convert text string
to a list of token indices in char-level.
Args:
speech_segment (SpeechSegment): Speech segment to extract features from.
keep_transcription_text (bool): True, keep transcript text, False, token ids
Returns:
            tuple: 1) spectrogram audio feature in 2darray, 2) list of token indices.
"""
spec_feature = self.audio_feature.featurize(speech_segment)
if keep_transcription_text:
return spec_feature, speech_segment.transcript
if speech_segment.has_token:
text_ids = speech_segment.token_ids
else:
text_ids = self.text_feature.featurize(speech_segment.transcript)
return spec_feature, text_ids
def text_featurize(self, text, keep_transcription_text):
"""Extract features for speech segment.
1. For audio parts, extract the audio features.
2. For transcript parts, keep the original text or convert text string
to a list of token indices in char-level.
Args:
text (str): text.
keep_transcription_text (bool): True, keep transcript text, False, token ids
Returns:
(str|List[int]): text, or list of token indices.
"""
if keep_transcription_text:
return text
text_ids = self.text_feature.featurize(text)
return text_ids
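# A construction sketch (illustrative, not from the original file).
# TextFeaturizer also accepts an in-memory vocab list, so a SpeechFeaturizer
# can be built without a vocab file; the token symbols below are assumptions
# and need not match this module's special-token constants.
if __name__ == "__main__":
    featurizer = SpeechFeaturizer(
        unit_type='char',
        vocab_filepath=['<blank>', '<unk>', '我', '你', '<eos>'],
        spectrum_type='linear',
        window_ms=20.0)
    print(featurizer.feature_size)  # 161 for 16 kHz audio, 20 ms window
    print(featurizer.vocab_size)    # 5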
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
from typing import Union
import sentencepiece as spm
from ..utility import BLANK, EOS, MASKCTC, SOS, SPACE, UNK, load_dict
__all__ = ["TextFeaturizer"]
class TextFeaturizer():
def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False):
"""Text featurizer, for processing or extracting features from text.
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
a list of token indices. Note that the token indexing order follows the
given vocabulary file.
Args:
unit_type (str): unit type, e.g. char, word, spm
            vocab (str|list): filepath of the vocabulary for token-index conversion, or the vocab list itself.
spm_model_prefix (str, optional): spm model prefix. Defaults to None.
"""
assert unit_type in ('char', 'spm', 'word')
self.unit_type = unit_type
self.unk = UNK
self.maskctc = maskctc
if vocab:
self.vocab_dict, self._id2token, self.vocab_list, \
self.unk_id, self.eos_id, \
self.blank_id = self._load_vocabulary_from_file(vocab, maskctc)
self.vocab_size = len(self.vocab_list)
else:
print("TextFeaturizer: not have vocab file or vocab list.")
if unit_type == 'spm':
spm_model = spm_model_prefix + '.model'
self.sp = spm.SentencePieceProcessor()
self.sp.Load(spm_model)
def tokenize(self, text, replace_space=True):
if self.unit_type == 'char':
tokens = self.char_tokenize(text, replace_space)
elif self.unit_type == 'word':
tokens = self.word_tokenize(text)
else: # spm
tokens = self.spm_tokenize(text)
return tokens
def detokenize(self, tokens):
if self.unit_type == 'char':
text = self.char_detokenize(tokens)
elif self.unit_type == 'word':
text = self.word_detokenize(tokens)
else: # spm
text = self.spm_detokenize(tokens)
return text
def featurize(self, text):
"""Convert text string to a list of token indices.
Args:
text (str): Text to process.
Returns:
List[int]: List of token indices.
"""
tokens = self.tokenize(text)
ids = []
for token in tokens:
if token not in self.vocab_dict:
token = self.unk
ids.append(self.vocab_dict[token])
return ids
def defeaturize(self, idxs):
"""Convert a list of token indices to text string,
ignore index after eos_id.
Args:
idxs (List[int]): List of token indices.
Returns:
str: Text.
"""
tokens = []
for idx in idxs:
if idx == self.eos_id:
break
tokens.append(self._id2token[idx])
text = self.detokenize(tokens)
return text
def char_tokenize(self, text, replace_space=True):
"""Character tokenizer.
Args:
text (str): text string.
            replace_space (bool): whether to replace " " with the SPACE token; False is used only by build_vocab.py.
Returns:
List[str]: tokens.
"""
text = text.strip()
if replace_space:
text_list = [SPACE if item == " " else item for item in list(text)]
else:
text_list = list(text)
return text_list
def char_detokenize(self, tokens):
"""Character detokenizer.
Args:
tokens (List[str]): tokens.
Returns:
str: text string.
"""
tokens = [t.replace(SPACE, " ") for t in tokens]
return "".join(tokens)
def word_tokenize(self, text):
"""Word tokenizer, separate by <space>."""
return text.strip().split()
def word_detokenize(self, tokens):
"""Word detokenizer, separate by <space>."""
return " ".join(tokens)
def spm_tokenize(self, text):
"""spm tokenize.
Args:
text (str): text string.
Returns:
List[str]: sentence pieces str code
"""
stats = {"num_empty": 0, "num_filtered": 0}
def valid(line):
return True
def encode(l):
return self.sp.EncodeAsPieces(l)
def encode_line(line):
line = line.strip()
if len(line) > 0:
line = encode(line)
if valid(line):
return line
else:
stats["num_filtered"] += 1
else:
stats["num_empty"] += 1
return None
enc_line = encode_line(text)
return enc_line
    def spm_detokenize(self, tokens, input_format='piece'):
        """spm detokenize.
        Args:
            tokens (List[str]): tokens.
        Returns:
            str: text
        """
        if input_format == "piece":
            def decode(pieces):
                return "".join(self.sp.DecodePieces(pieces))
        elif input_format == "id":
            def decode(ids):
                return "".join(self.sp.DecodeIds(ids))
        return decode(tokens)
def _load_vocabulary_from_file(self, vocab: Union[str, list],
maskctc: bool):
"""Load vocabulary from file."""
if isinstance(vocab, list):
vocab_list = vocab
else:
vocab_list = load_dict(vocab, maskctc)
assert vocab_list is not None
id2token = dict(
[(idx, token) for (idx, token) in enumerate(vocab_list)])
token2id = dict(
[(token, idx) for (idx, token) in enumerate(vocab_list)])
blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1
return token2id, id2token, vocab_list, unk_id, eos_id, blank_id
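# A char-level round-trip sketch (illustrative, not from the original file),
# using an in-memory vocab whose entries cover every character in the sample
# text, so no lookup falls back to the unknown token.
if __name__ == "__main__":
    feat = TextFeaturizer(unit_type='char', vocab=list("我认为跑步"))
    ids = feat.featurize("跑步")
    print(ids)                    # [3, 4]
    print(feat.defeaturize(ids))  # 跑步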
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains feature normalizers."""
import numpy as np
from .utility import load_cmvn
__all__ = ["FeatureNormalizer"]
class FeatureNormalizer(object):
def __init__(self, mean_std_filepath):
mean_std = mean_std_filepath
self._read_mean_std_from_file(mean_std)
def apply(self, features):
"""Normalize features to be of zero mean and unit stddev.
:param features: Input features to be normalized.
:type features: ndarray, shape (T, D)
:param eps: added to stddev to provide numerical stablibity.
:type eps: float
:return: Normalized features.
:rtype: ndarray
"""
return (features - self._mean) * self._istd
def _read_mean_std_from_file(self, mean_std, eps=1e-20):
"""Load mean and std from file."""
if isinstance(mean_std, list):
mean = mean_std[0]['cmvn_stats']['mean']
istd = mean_std[0]['cmvn_stats']['istd']
else:
filetype = mean_std.split(".")[-1]
mean, istd = load_cmvn(mean_std, filetype=filetype)
self._mean = np.expand_dims(mean, axis=0)
self._istd = np.expand_dims(istd, axis=0)
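# An in-memory CMVN sketch (illustrative, not from the original file). The
# constructor also accepts pre-loaded stats via the list branch above, so
# normalization can be checked without a cmvn file; the values are made up.
if __name__ == "__main__":
    dim = 161
    stats = [{"cmvn_stats": {"mean": [0.5] * dim, "istd": [2.0] * dim}}]
    normalizer = FeatureNormalizer(stats)
    feats = np.ones((10, dim), dtype=np.float32)
    print(normalizer.apply(feats)[0, 0])  # (1.0 - 0.5) * 2.0 = 1.0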