Commit a0faaefd authored by “change”'s avatar “change”
Browse files

add benchmark

parent c90f7a12
......@@ -202,6 +202,15 @@ python speech_asr.py -hip 7 -m model/speecht5_asr -is ../data/librispeech/dev-cl
- 输入:./data/librispeech/dev-clean/1272/128104/1272-128104-0000.flac
- 输出:./res/asr.txt
#### benchmark 计算
```
cd speecht5_pytorch
python benchmark.py -m model/speecht5_asr -ds librispeech_asr_test.py -b 32
```
- -m: asr模型路径
- -ds: 测试数据的处理脚本,默认为同级目录下的librispeech_asr_test.py
- -dr: 数据集路径,默认为speech_pytorch
- -b: 测试batch_size,默认为32(最大为128)
## 应用场景
### 算法分类
......
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import SpeechT5Processor, SpeechT5ForSpeechToText
from datasets import load_dataset, Audio
import time
import argparse
import os
current_directory = os.path.dirname(os.path.realpath(__file__))
class AudioDataset(Dataset):
def __init__(self, dataset, processor, sampling_rate):
self.dataset = dataset
self.processor = processor
self.sampling_rate = sampling_rate
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
audio = self.dataset[idx]["audio"]["array"]
sample = self.processor(audio=audio, sampling_rate=self.sampling_rate, return_tensors="pt")
return {"input_values": sample["input_values"].squeeze(0)} # 移除多余的维度
def collate_fn(batch):
# 自动填充序列,确保每个批次中的音频长度相同
input_values = [item["input_values"] for item in batch]
input_values = torch.nn.utils.rnn.pad_sequence(input_values, batch_first=True)
return {"input_values": input_values}
def main(opt):
# 加载数据集
dataset = load_dataset(opt.dataset_script, 'clean',
cache_dir=opt.dataset_dir, split="test")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000)) # 确保音频数据格式正确
# 获取采样率
sampling_rate = 16000
# 初始化处理器和模型
processor = SpeechT5Processor.from_pretrained(opt.model_path)
model = SpeechT5ForSpeechToText.from_pretrained(opt.model_path).to('cuda') # 将模型移动到GPU上
# 设置批次大小
batch_size = opt.batch_size
# 创建数据加载器
dataloader = DataLoader(
AudioDataset(dataset, processor, sampling_rate),
batch_size=batch_size,
shuffle=False,
collate_fn=collate_fn
)
# 进行推理
all_transcriptions = []
with torch.no_grad():
for batch in dataloader:
size = batch['input_values'].size()
inputs = {k: v.to('cuda') for k, v in batch.items()} # 将输入数据移动到GPU上
#开始计时
start = time.time()
predicted_ids = model.generate(**inputs, max_length=400)
transcription_batch = processor.batch_decode(predicted_ids, skip_special_tokens=True)
#结束计时
end = time.time()
all_transcriptions.extend(transcription_batch)
break
resume_time = end - start
samples_per_second = batch_size / resume_time
# 输出结果
# for idx, transcription in enumerate(all_transcriptions):
# print(f"Sample {idx}: {transcription}")
print(f"resume_time: {resume_time: .2f}, \nsamples_per_second: {samples_per_second: .2f}")
def parse_opt(known=False):
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model-path', type=str,
default="/public/home/changhl/py_project/speecht5_pytorch/speecht5_asr", help="initial model path")
parser.add_argument('-ds', '--dataset_script', type=str, default=os.path.join(current_directory, "librispeech_asr_test.py"), help="speech scriot")
parser.add_argument('-dr', '--dataset_dir', type=str, default=current_directory, help="speech scriot")
parser.add_argument('-b', '--batch_size', type=int, default=32, help="the batch_size of speech")
opt = parser.parse_known_args()[0] if known else parser.parse_args()
return opt
if __name__ == "__main__":
main(parse_opt())
\ No newline at end of file
# coding=utf-8
# Copyright 2021 The TensorFlow Datasets Authors and the HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Librispeech automatic speech recognition dataset."""
from __future__ import absolute_import, division, print_function
import glob
import os
import datasets
_CITATION = """\
@inproceedings{panayotov2015librispeech,
title={Librispeech: an ASR corpus based on public domain audio books},
author={Panayotov, Vassil and Chen, Guoguo and Povey, Daniel and Khudanpur, Sanjeev},
booktitle={Acoustics, Speech and Signal Processing (ICASSP), 2015 IEEE International Conference on},
pages={5206--5210},
year={2015},
organization={IEEE}
}
"""
_DESCRIPTION = """\
LibriSpeech is a corpus of approximately 1000 hours of read English speech with sampling rate of 16 kHz,
prepared by Vassil Panayotov with the assistance of Daniel Povey. The data is derived from read
audiobooks from the LibriVox project, and has been carefully segmented and aligned.
Note that in order to limit the required storage for preparing this dataset, the audio
is stored in the .flac format and is not converted to a float32 array. To convert, the audio
file to a float32 array, please make use of the `.map()` function as follows:
```python
import soundfile as sf
def map_to_array(batch):
speech_array, _ = sf.read(batch["file"])
batch["speech"] = speech_array
return batch
dataset = dataset.map(map_to_array, remove_columns=["file"])
```
"""
_URL = "http://www.openslr.org/12"
_DL_URL = "https://www.openslr.org/resources/12/"
_DL_URLS = {
"clean": {
"test": _DL_URL + "test-clean.tar.gz",
}
}
class LibrispeechASRConfig(datasets.BuilderConfig):
"""BuilderConfig for LibriSpeechASR."""
def __init__(self, **kwargs):
"""
Args:
data_dir: `string`, the path to the folder containing the files in the
downloaded .tar
citation: `string`, citation for the data set
url: `string`, url for information about the data set
**kwargs: keyword arguments forwarded to super.
"""
super(LibrispeechASRConfig, self).__init__(version=datasets.Version("2.1.0", ""), **kwargs)
class LibrispeechASR(datasets.GeneratorBasedBuilder):
"""Librispeech dataset."""
BUILDER_CONFIGS = [
LibrispeechASRConfig(name="clean", description="'Clean' speech."),
LibrispeechASRConfig(name="other", description="'Other', more challenging, speech."),
]
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"file": datasets.Value("string"),
"audio": datasets.features.Audio(sampling_rate=16_000),
"text": datasets.Value("string"),
"speaker_id": datasets.Value("int64"),
"chapter_id": datasets.Value("int64"),
"id": datasets.Value("string"),
}
),
supervised_keys=("speech", "text"),
homepage=_URL,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
archive_path = dl_manager.download_and_extract(_DL_URLS[self.config.name])
return [
datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"archive_path": archive_path["test"], "split_name": f"test-{self.config.name}"}),
]
def _generate_examples(self, archive_path, split_name):
"""Generate examples from a Librispeech archive_path."""
transcripts_glob = os.path.join(archive_path, "LibriSpeech", split_name, "*/*/*.txt")
for transcript_file in glob.glob(transcripts_glob):
path = os.path.dirname(transcript_file)
with open(os.path.join(path, transcript_file)) as f:
for line in f:
line = line.strip()
key, transcript = line.split(" ", 1)
audio_file = f"{key}.flac"
speaker_id, chapter_id = [int(el) for el in key.split("-")[:2]]
example = {
"id": key,
"speaker_id": speaker_id,
"chapter_id": chapter_id,
"file": os.path.join(path, audio_file),
"audio": os.path.join(path, audio_file),
"text": transcript,
}
yield key, example
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment