Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SpeechT5_pytorch
Commits
a0faaefd
Commit
a0faaefd
authored
Sep 11, 2024
by
“change”
Browse files
add benchmark
parent
c90f7a12
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
238 additions
and
1 deletion
+238
-1
README.md
README.md
+10
-1
benchmark.py
benchmark.py
+97
-0
librispeech_asr_test.py
librispeech_asr_test.py
+131
-0
No files found.
README.md
View file @
a0faaefd
...
...
@@ -202,6 +202,15 @@ python speech_asr.py -hip 7 -m model/speecht5_asr -is ../data/librispeech/dev-cl
-
输入:./data/librispeech/dev-clean/1272/128104/1272-128104-0000.flac
-
输出:./res/asr.txt
#### benchmark 计算
```
cd speecht5_pytorch
python benchmark.py -m model/speecht5_asr -ds librispeech_asr_test.py -b 32
```
-
-m: asr模型路径
-
-ds: 测试数据的处理脚本,默认为同级目录下的librispeech_asr_test.py
-
-dr: 数据集路径,默认为speech_pytorch
-
-b: 测试batch_size,默认为32(最大为128)
## 应用场景
### 算法分类
...
...
benchmark.py
0 → 100644
View file @
a0faaefd
import
torch
from
torch.utils.data
import
DataLoader
,
Dataset
from
transformers
import
SpeechT5Processor
,
SpeechT5ForSpeechToText
from
datasets
import
load_dataset
,
Audio
import
time
import
argparse
import
os
current_directory
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
class
AudioDataset
(
Dataset
):
def
__init__
(
self
,
dataset
,
processor
,
sampling_rate
):
self
.
dataset
=
dataset
self
.
processor
=
processor
self
.
sampling_rate
=
sampling_rate
def
__len__
(
self
):
return
len
(
self
.
dataset
)
def
__getitem__
(
self
,
idx
):
audio
=
self
.
dataset
[
idx
][
"audio"
][
"array"
]
sample
=
self
.
processor
(
audio
=
audio
,
sampling_rate
=
self
.
sampling_rate
,
return_tensors
=
"pt"
)
return
{
"input_values"
:
sample
[
"input_values"
].
squeeze
(
0
)}
# 移除多余的维度
def
collate_fn
(
batch
):
# 自动填充序列,确保每个批次中的音频长度相同
input_values
=
[
item
[
"input_values"
]
for
item
in
batch
]
input_values
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
input_values
,
batch_first
=
True
)
return
{
"input_values"
:
input_values
}
def
main
(
opt
):
# 加载数据集
dataset
=
load_dataset
(
opt
.
dataset_script
,
'clean'
,
cache_dir
=
opt
.
dataset_dir
,
split
=
"test"
)
dataset
=
dataset
.
cast_column
(
"audio"
,
Audio
(
sampling_rate
=
16000
))
# 确保音频数据格式正确
# 获取采样率
sampling_rate
=
16000
# 初始化处理器和模型
processor
=
SpeechT5Processor
.
from_pretrained
(
opt
.
model_path
)
model
=
SpeechT5ForSpeechToText
.
from_pretrained
(
opt
.
model_path
).
to
(
'cuda'
)
# 将模型移动到GPU上
# 设置批次大小
batch_size
=
opt
.
batch_size
# 创建数据加载器
dataloader
=
DataLoader
(
AudioDataset
(
dataset
,
processor
,
sampling_rate
),
batch_size
=
batch_size
,
shuffle
=
False
,
collate_fn
=
collate_fn
)
# 进行推理
all_transcriptions
=
[]
with
torch
.
no_grad
():
for
batch
in
dataloader
:
size
=
batch
[
'input_values'
].
size
()
inputs
=
{
k
:
v
.
to
(
'cuda'
)
for
k
,
v
in
batch
.
items
()}
# 将输入数据移动到GPU上
#开始计时
start
=
time
.
time
()
predicted_ids
=
model
.
generate
(
**
inputs
,
max_length
=
400
)
transcription_batch
=
processor
.
batch_decode
(
predicted_ids
,
skip_special_tokens
=
True
)
#结束计时
end
=
time
.
time
()
all_transcriptions
.
extend
(
transcription_batch
)
break
resume_time
=
end
-
start
samples_per_second
=
batch_size
/
resume_time
# 输出结果
# for idx, transcription in enumerate(all_transcriptions):
# print(f"Sample {idx}: {transcription}")
print
(
f
"resume_time:
{
resume_time
:
.
2
f
}
,
\n
samples_per_second:
{
samples_per_second
:
.
2
f
}
"
)
def
parse_opt
(
known
=
False
):
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'-m'
,
'--model-path'
,
type
=
str
,
default
=
"/public/home/changhl/py_project/speecht5_pytorch/speecht5_asr"
,
help
=
"initial model path"
)
parser
.
add_argument
(
'-ds'
,
'--dataset_script'
,
type
=
str
,
default
=
os
.
path
.
join
(
current_directory
,
"librispeech_asr_test.py"
),
help
=
"speech scriot"
)
parser
.
add_argument
(
'-dr'
,
'--dataset_dir'
,
type
=
str
,
default
=
current_directory
,
help
=
"speech scriot"
)
parser
.
add_argument
(
'-b'
,
'--batch_size'
,
type
=
int
,
default
=
32
,
help
=
"the batch_size of speech"
)
opt
=
parser
.
parse_known_args
()[
0
]
if
known
else
parser
.
parse_args
()
return
opt
if
__name__
==
"__main__"
:
main
(
parse_opt
())
\ No newline at end of file
librispeech_asr_test.py
0 → 100644
View file @
a0faaefd
# coding=utf-8
# Copyright 2021 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Librispeech automatic speech recognition dataset."""
from
__future__
import
absolute_import
,
division
,
print_function
import
glob
import
os
import
datasets
_CITATION
=
"""
\
@inproceedings{panayotov2015librispeech,
title={Librispeech: an ASR corpus based on public domain audio books},
author={Panayotov, Vassil and Chen, Guoguo and Povey, Daniel and Khudanpur, Sanjeev},
booktitle={Acoustics, Speech and Signal Processing (ICASSP), 2015 IEEE International Conference on},
pages={5206--5210},
year={2015},
organization={IEEE}
}
"""
_DESCRIPTION
=
"""
\
LibriSpeech is a corpus of approximately 1000 hours of read English speech with sampling rate of 16 kHz,
prepared by Vassil Panayotov with the assistance of Daniel Povey. The data is derived from read
audiobooks from the LibriVox project, and has been carefully segmented and aligned.
Note that in order to limit the required storage for preparing this dataset, the audio
is stored in the .flac format and is not converted to a float32 array. To convert, the audio
file to a float32 array, please make use of the `.map()` function as follows:
```python
import soundfile as sf
def map_to_array(batch):
speech_array, _ = sf.read(batch["file"])
batch["speech"] = speech_array
return batch
dataset = dataset.map(map_to_array, remove_columns=["file"])
```
"""
_URL
=
"http://www.openslr.org/12"
_DL_URL
=
"https://www.openslr.org/resources/12/"
_DL_URLS
=
{
"clean"
:
{
"test"
:
_DL_URL
+
"test-clean.tar.gz"
,
}
}
class
LibrispeechASRConfig
(
datasets
.
BuilderConfig
):
"""BuilderConfig for LibriSpeechASR."""
def
__init__
(
self
,
**
kwargs
):
"""
Args:
data_dir: `string`, the path to the folder containing the files in the
downloaded .tar
citation: `string`, citation for the data set
url: `string`, url for information about the data set
**kwargs: keyword arguments forwarded to super.
"""
super
(
LibrispeechASRConfig
,
self
).
__init__
(
version
=
datasets
.
Version
(
"2.1.0"
,
""
),
**
kwargs
)
class
LibrispeechASR
(
datasets
.
GeneratorBasedBuilder
):
"""Librispeech dataset."""
BUILDER_CONFIGS
=
[
LibrispeechASRConfig
(
name
=
"clean"
,
description
=
"'Clean' speech."
),
LibrispeechASRConfig
(
name
=
"other"
,
description
=
"'Other', more challenging, speech."
),
]
def
_info
(
self
):
return
datasets
.
DatasetInfo
(
description
=
_DESCRIPTION
,
features
=
datasets
.
Features
(
{
"file"
:
datasets
.
Value
(
"string"
),
"audio"
:
datasets
.
features
.
Audio
(
sampling_rate
=
16_000
),
"text"
:
datasets
.
Value
(
"string"
),
"speaker_id"
:
datasets
.
Value
(
"int64"
),
"chapter_id"
:
datasets
.
Value
(
"int64"
),
"id"
:
datasets
.
Value
(
"string"
),
}
),
supervised_keys
=
(
"speech"
,
"text"
),
homepage
=
_URL
,
citation
=
_CITATION
,
)
def
_split_generators
(
self
,
dl_manager
):
archive_path
=
dl_manager
.
download_and_extract
(
_DL_URLS
[
self
.
config
.
name
])
return
[
datasets
.
SplitGenerator
(
name
=
datasets
.
Split
.
TEST
,
gen_kwargs
=
{
"archive_path"
:
archive_path
[
"test"
],
"split_name"
:
f
"test-
{
self
.
config
.
name
}
"
}),
]
def
_generate_examples
(
self
,
archive_path
,
split_name
):
"""Generate examples from a Librispeech archive_path."""
transcripts_glob
=
os
.
path
.
join
(
archive_path
,
"LibriSpeech"
,
split_name
,
"*/*/*.txt"
)
for
transcript_file
in
glob
.
glob
(
transcripts_glob
):
path
=
os
.
path
.
dirname
(
transcript_file
)
with
open
(
os
.
path
.
join
(
path
,
transcript_file
))
as
f
:
for
line
in
f
:
line
=
line
.
strip
()
key
,
transcript
=
line
.
split
(
" "
,
1
)
audio_file
=
f
"
{
key
}
.flac"
speaker_id
,
chapter_id
=
[
int
(
el
)
for
el
in
key
.
split
(
"-"
)[:
2
]]
example
=
{
"id"
:
key
,
"speaker_id"
:
speaker_id
,
"chapter_id"
:
chapter_id
,
"file"
:
os
.
path
.
join
(
path
,
audio_file
),
"audio"
:
os
.
path
.
join
(
path
,
audio_file
),
"text"
:
transcript
,
}
yield
key
,
example
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment