import argparse
import os
import time

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import SpeechT5Processor, SpeechT5ForSpeechToText
from datasets import load_dataset, Audio

current_directory = os.path.dirname(os.path.realpath(__file__))


class AudioDataset(Dataset):
    """Wraps a Hugging Face datasets split and runs the SpeechT5 processor per sample."""

    def __init__(self, dataset, processor, sampling_rate):
        self.dataset = dataset
        self.processor = processor
        self.sampling_rate = sampling_rate

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        audio = self.dataset[idx]["audio"]["array"]
        sample = self.processor(audio=audio, sampling_rate=self.sampling_rate, return_tensors="pt")
        # Drop the extra batch dimension added by the processor
        return {"input_values": sample["input_values"].squeeze(0)}


def collate_fn(batch):
    # Pad the audio sequences so every item in the batch has the same length
    input_values = [item["input_values"] for item in batch]
    input_values = torch.nn.utils.rnn.pad_sequence(input_values, batch_first=True)
    return {"input_values": input_values}


def main(opt):
    # Load the dataset
    dataset = load_dataset(opt.dataset_script, 'clean', cache_dir=opt.dataset_dir, split="test")
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))  # make sure audio is decoded at 16 kHz

    # Sampling rate expected by SpeechT5
    sampling_rate = 16000

    # Initialize the processor and the model
    processor = SpeechT5Processor.from_pretrained(opt.model_path)
    model = SpeechT5ForSpeechToText.from_pretrained(opt.model_path).to('cuda')  # move the model to the GPU
    model.eval()

    # Batch size
    batch_size = opt.batch_size

    # Build the data loader
    dataloader = DataLoader(
        AudioDataset(dataset, processor, sampling_rate),
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )

    # Run inference on the first batch only and time it
    all_transcriptions = []
    with torch.no_grad():
        for batch in dataloader:
            num_samples = batch["input_values"].size(0)
            inputs = {k: v.to('cuda') for k, v in batch.items()}  # move the inputs to the GPU

            # Start timing
            start = time.time()
            predicted_ids = model.generate(**inputs, max_length=400)
            transcription_batch = processor.batch_decode(predicted_ids, skip_special_tokens=True)
            # Stop timing
            end = time.time()

            all_transcriptions.extend(transcription_batch)
            break

    elapsed_time = end - start
    samples_per_second = num_samples / elapsed_time

    # Print results
    # for idx, transcription in enumerate(all_transcriptions):
    #     print(f"Sample {idx}: {transcription}")
    print(f"elapsed_time: {elapsed_time:.2f}s\nsamples_per_second: {samples_per_second:.2f}")


def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model-path', type=str,
                        default="/public/home/changhl/py_project/speecht5_pytorch/speecht5_asr",
                        help="initial model path")
    parser.add_argument('-ds', '--dataset_script', type=str,
                        default=os.path.join(current_directory, "librispeech_asr_test.py"),
                        help="dataset loading script")
    parser.add_argument('-dr', '--dataset_dir', type=str, default=current_directory,
                        help="dataset cache directory")
    parser.add_argument('-b', '--batch_size', type=int, default=32,
                        help="batch size for inference")
    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    return opt


if __name__ == "__main__":
    main(parse_opt())
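
# A minimal usage sketch. The script filename below is a placeholder, and the model
# path is just the default from parse_opt(); adjust both for your environment. It
# assumes librispeech_asr_test.py sits next to this file, as the defaults expect.
#
#   python speecht5_asr_infer.py \
#       -m /public/home/changhl/py_project/speecht5_pytorch/speecht5_asr \
#       -b 32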