# infer.py
# --- Earlier single-shot attempt, kept for reference ---
# import librosa
# from espnet_onnx import Speech2Text
# import numpy as np

# PROVIDERS = ['CPUExecutionProvider']
# tag_name = 'transformer_lm'

# speech2text = Speech2Text(
#     providers=PROVIDERS,
#     model_dir="/home/sunzhq/workspace/yidong-infer/conformer/onnx_models/transformer_lm"
# )

# # Load the audio
# audio_path = '/data/datasets/1/data_aishell/wav/test/S0916/BAC009S0916W0314.wav'
# y, sr = librosa.load(audio_path, sr=16000)

# print(f"Original audio length: {len(y)} samples, {len(y)/sr:.2f} s")

# # Per the error message, index 778 was out of range, i.e. about 512 frames;
# # assuming a 10 ms frame shift, 512 frames correspond to 5.12 s of audio.
# max_seconds = 4  # adjust based on your own tests
# max_samples = int(max_seconds * sr)

# # Trim the audio if it is too long
# if len(y) > max_samples:
#     print(f"Audio too long, trimming to the first {max_seconds} s ({max_samples} samples)")
#     y = y[:max_samples]

# # Now try inference
# try:
#     nbest = speech2text(y)
#     # import pdb;pdb.set_trace()
#     print(f"Transcription: {nbest[0][0]}")
# except Exception as e:
#     print(f"Error: {e}")


import librosa
from espnet_onnx import Speech2Text
import numpy as np

def process_long_audio(audio_path, model_dir, max_chunk_seconds=10, sr=16000):
    """Transcribe a long audio file, falling back to chunked processing."""
    # Load the model
    speech2text = Speech2Text(
        providers=['ROCMExecutionProvider'],
        model_dir=model_dir
    )

    # Load the audio
    y, sr = librosa.load(audio_path, sr=sr)
    print(f"Total audio length: {len(y)/sr:.2f}s ({len(y)} samples)")

    # Try the whole utterance in one pass first
    try:
        nbest = speech2text(y)
        print(f"Whole-utterance transcription succeeded: {nbest[0][0]}")
        return nbest[0][0]
    except Exception as e:
        print(f"Whole-utterance processing failed, falling back to chunks: {e}")

    # Determine the longest input length the model accepts
    max_samples = find_max_safe_length(speech2text, sr, max_chunk_seconds)

    # Chunked processing
    results = []
    overlap = int(1 * sr)  # 1 s of overlap between consecutive chunks

    step = max(max_samples - overlap, int(0.5 * sr))  # guard against a non-positive stride
    for start in range(0, len(y), step):
        end = min(start + max_samples, len(y))
        chunk = y[start:end]

        # Skip chunks shorter than 0.5 s
        if len(chunk) < 0.5 * sr:
            continue

        duration = len(chunk) / sr
        print(f"Processing {start/sr:.1f}s-{end/sr:.1f}s ({duration:.1f}s)...")

        try:
            nbest = speech2text(chunk)
            if nbest and nbest[0]:
                results.append(nbest[0][0])
        except Exception as e:
            print(f"Chunk processing failed: {e}")
            # Retry this chunk with smaller pieces
            sub_results = process_with_smaller_chunks(chunk, speech2text, sr)
            results.extend(sub_results)

    # Merge results (plain concatenation; smarter seam-aware merging can be added)
    full_text = " ".join(results)

    # Post-process: remove repeated words
    full_text = post_process_text(full_text)

    return full_text
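
# Fixed-size windows can cut through the middle of a word. A minimal sketch
# of an alternative that chunks at low-energy (silence) boundaries instead,
# using librosa.effects.split; top_db=30 is an untested assumption to tune:
def split_on_silence(y, sr, top_db=30, max_chunk_seconds=10):
    """Yield chunks bounded by silence, each at most max_chunk_seconds long."""
    max_samples = int(max_chunk_seconds * sr)
    for start, end in librosa.effects.split(y, top_db=top_db):
        # A voiced region longer than the limit still needs fixed slicing
        for s in range(start, end, max_samples):
            yield y[s:min(s + max_samples, end)]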

def find_max_safe_length(model, sr, initial_max=10):
    """Find the longest safe input length by scanning down in 0.5 s steps."""
    max_samples = int(initial_max * sr)
    min_samples = int(1 * sr)

    # Probe with synthetic test audio (low-volume noise stands in for speech)
    test_audio = np.random.randn(max_samples) * 0.01

    for test_len in range(max_samples, min_samples, -int(0.5 * sr)):
        try:
            model(test_audio[:test_len])
            print(f"Safe length: {test_len/sr:.1f}s")
            return test_len
        except Exception:
            continue

    return min_samples
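
# A true binary search converges in O(log n) probes instead of a linear scan;
# it assumes failures are monotonic in length (if L samples fail, anything
# longer fails too). A hypothetical sketch under that assumption:
def find_max_safe_length_bisect(model, sr, initial_max=10, tol_seconds=0.5):
    """Bisect between 1 s and initial_max s for the longest passing length."""
    lo, hi = int(1 * sr), int(initial_max * sr)
    test_audio = np.random.randn(hi) * 0.01  # low-volume noise probe
    best = lo
    while hi - lo > int(tol_seconds * sr):
        mid = (lo + hi) // 2
        try:
            model(test_audio[:mid])
            best, lo = mid, mid  # mid works: search the upper half
        except Exception:
            hi = mid  # mid fails: search the lower half
    return best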

def process_with_smaller_chunks(audio, model, sr, chunk_size=5):
    """Process the audio with smaller chunks."""
    results = []
    chunk_samples = int(chunk_size * sr)

    for i in range(0, len(audio), chunk_samples):
        chunk = audio[i:i+chunk_samples]
        if len(chunk) > 0.5 * sr:
            try:
                nbest = model(chunk)
                if nbest and nbest[0]:
                    results.append(nbest[0][0])
            except Exception:
                pass
    return results

def post_process_text(text):
    """Post-process the text, removing likely duplicates."""
    # Simple adjacent-duplicate removal (can be strengthened as needed)
    words = text.split()
    cleaned = []
    for i, word in enumerate(words):
        if i == 0 or word != words[i-1]:
            cleaned.append(word)
    return " ".join(cleaned)

# Usage example
if __name__ == "__main__":
    audio_path = '/data/datasets/1/data_aishell/wav/test/S0768/BAC009S0768W0452.wav'
    model_dir = "/home/sunzhq/workspace/yidong-infer/conformer/onnx_models/transformer_lm"

    result = process_long_audio(audio_path, model_dir)
    print(f"\nFinal transcription: {result}")