# speech_asr.py — SpeechT5 speech-to-text (ASR) inference script.
from transformers import SpeechT5Processor, SpeechT5ForSpeechToText
from transformers import logging
from datasets import load_dataset
import torch
import argparse
import librosa
import numpy as np
import os

def parse_opt(known=False):
    """Parse command-line options for the ASR script.

    Args:
        known: when True, use ``parse_known_args`` and ignore any
            unrecognized arguments instead of erroring out.

    Returns:
        argparse.Namespace with hip_device, model_path, input_speech,
        result_path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-hip', '--hip-device', type=int, default=0, help="index of the HIP/CUDA device to use")
    parser.add_argument('-m', '--model-path', type=str, default="", help="path to the pretrained SpeechT5 ASR model")
    # NOTE: the value is loaded with librosa.load, so it must be a file
    # path; the previous default was a text sentence (copy-paste from a
    # TTS script) and could never be loaded as audio.
    parser.add_argument('-is', '--input_speech', type=str, default="", help="path to the input audio file (mono wav)")
    parser.add_argument('-res', '--result_path', type=str, default="../res", help="directory where asr.txt is written")
    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    return opt

def main(opt):
    """Run SpeechT5 speech-to-text inference on a single audio file.

    Loads the processor and model from ``opt.model_path``, transcribes
    the audio file at ``opt.input_speech``, prints the transcription and
    writes it to ``<opt.result_path>/asr.txt``.
    """
    device = torch.device(f"cuda:{int(opt.hip_device)}")
    print(f"Using device: {device}")

    # Initialize the SpeechT5 ASR processor and model.
    logging.set_verbosity_warning()
    processor = SpeechT5Processor.from_pretrained(opt.model_path)
    model = SpeechT5ForSpeechToText.from_pretrained(opt.model_path).to(device)
    model.eval()  # inference only — disable dropout etc.

    # Encoder input: mono audio resampled to 16 kHz.
    # float32 matches the model's weight dtype; the original float64
    # load forced an unnecessary double-precision buffer.
    example_speech, sampling_rate = librosa.load(opt.input_speech, sr=16000, dtype=np.float32)
    inputs = processor(audio=example_speech, sampling_rate=sampling_rate, return_tensors="pt").to(device)

    # Decoder output: generated token ids decoded to text.
    # no_grad avoids building autograd state during generation.
    with torch.no_grad():
        predicted_ids = model.generate(**inputs, max_length=100)
    transcription = processor.batch_decode(predicted_ids.cpu(), skip_special_tokens=True)
    print("text: {}".format(transcription[0]))

    # Create the output directory if missing (the original raised
    # FileNotFoundError when result_path did not exist), then write.
    os.makedirs(opt.result_path, exist_ok=True)
    with open(os.path.join(opt.result_path, "asr.txt"), "w") as f:
        f.write("text: {}".format(transcription[0]))

# Script entry point: parse CLI arguments and run ASR inference.
if __name__ == "__main__":
    main(parse_opt())