"""Transcribe a speech audio file to text with a SpeechT5 speech-to-text (ASR) model.

The transcription is printed and written to ``<result_path>/asr.txt``.
"""
from transformers import SpeechT5Processor, SpeechT5ForSpeechToText
from transformers import logging
from datasets import load_dataset
import torch
import argparse
import librosa
import numpy as np
import os


def parse_opt(known=False):
    """Parse command-line options.

    Args:
        known: If True, use ``parse_known_args`` so unrecognized arguments are
            ignored instead of causing an error exit.

    Returns:
        argparse.Namespace with ``hip_device``, ``model_path``,
        ``input_speech``, ``result_path`` and ``max_length``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-hip', '--hip-device', type=int, default=0, help="initial hip devices")
    parser.add_argument('-m', '--model-path', type=str, default="", help="initial model path")
    # NOTE(review): this default is a sentence, not an audio file path (it looks
    # copied from a text-to-speech script); main() reports a clear error if the
    # path does not exist. Pass a real audio file with -is.
    parser.add_argument('-is', '--input_speech', type=str, default="Autumn, the season of change.",
                        help="path to the input speech audio file")
    parser.add_argument('-res', '--result_path', type=str, default="../res",
                        help="the directory in which to save asr.txt")
    # Optional; the default matches the previously hard-coded generation limit.
    parser.add_argument('--max-length', type=int, default=100,
                        help="maximum length passed to model.generate")
    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    return opt


def _select_device(hip_device):
    """Return ``cuda:<hip_device>`` if torch sees a GPU, otherwise the CPU."""
    # ROCm builds of PyTorch expose HIP devices through the torch.cuda API.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{int(hip_device)}")
    print("Warning: no CUDA/HIP device available, falling back to CPU")
    return torch.device("cpu")


def main(opt):
    """Transcribe ``opt.input_speech`` and save the text to ``opt.result_path/asr.txt``.

    Args:
        opt: Namespace as returned by ``parse_opt``. ``max_length`` is optional
            and defaults to 100.

    Raises:
        FileNotFoundError: If ``opt.input_speech`` is not an existing file.
    """
    # Fail fast, before the slow model load, if the audio file is missing.
    if not os.path.isfile(opt.input_speech):
        raise FileNotFoundError(f"input speech file not found: {opt.input_speech!r}")

    device = _select_device(opt.hip_device)
    print(f"Using device: {device}")

    # Initialize the SpeechT5 ASR model.
    logging.set_verbosity_warning()
    processor = SpeechT5Processor.from_pretrained(opt.model_path)
    model = SpeechT5ForSpeechToText.from_pretrained(opt.model_path).to(device)

    # Encoder input: the speech waveform, loaded as mono and resampled to 16 kHz.
    # float32 is librosa's default dtype; float64 only added an extra conversion.
    example_speech, sampling_rate = librosa.load(opt.input_speech, sr=16000, mono=True, dtype=np.float32)
    inputs = processor(audio=example_speech, sampling_rate=sampling_rate, return_tensors="pt").to(device)

    # Decoder output: generated token ids, decoded back to text.
    predicted_ids = model.generate(**inputs, max_length=getattr(opt, "max_length", 100))
    transcription = processor.batch_decode(predicted_ids.cpu(), skip_special_tokens=True)
    text = transcription[0]
    print("text: {}".format(text))

    # Make sure the output directory exists; "w" truncates any previous result.
    if opt.result_path:
        os.makedirs(opt.result_path, exist_ok=True)
    with open(os.path.join(opt.result_path, "asr.txt"), "w", encoding="utf-8") as f:
        f.write("text: {}".format(text))


if __name__ == "__main__":
    main(parse_opt())