infer.bak.py 7.75 KB
Newer Older
sunzhq2's avatar
sunzhq2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# import librosa
# from espnet_onnx import Speech2Text


# PROVIDERS = ['CPUExecutionProvider']
# tag_name = 'transformer_lm'

# speech2text = Speech2Text(
#   tag_name,
#   providers=PROVIDERS,
# #   cache_dir="/root/.cache/espnet_onnx/transformer_lm/full/"
# )
# y, sr = librosa.load('/data/datasets/data_aishell/wav/test/S0764/BAC009S0764W0150.wav', sr=16000)

# nbest = speech2text(y) # runs on GPU.



import librosa
import numpy as np
import onnxruntime as ort
import os

def direct_solution(audio_path, session):
    """Run the ONNX encoder on one audio file using a fixed 100-frame input.

    Loads the wav at 16 kHz, computes an 80-bin log-mel spectrogram,
    pads (with a silence-like value) or truncates the time axis to exactly
    100 frames, and feeds the resulting (1, 100, 80) float32 tensor to the
    encoder session under the input name 'feats'.

    Args:
        audio_path: Path to a wav file readable by librosa.
        session: onnxruntime.InferenceSession for the encoder model;
            assumed to take one input named 'feats' — confirm against the
            exported model.

    Returns:
        Tuple (encoder_out, encoder_out_lens) on success;
        (None, None) if the encoder run raises.
    """
    print("=== 直接解决方案:使用100帧 ===")

    # Load audio resampled to a fixed 16 kHz rate.
    y, sr = librosa.load(audio_path, sr=16000)

    # STFT / mel-filterbank parameters.
    n_fft = 512
    hop_length = 128
    n_mels = 80

    mel_spec = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=n_fft,
        hop_length=hop_length, n_mels=n_mels
    )
    log_mel = librosa.power_to_db(mel_spec)
    features = log_mel.T  # transpose to (time, n_mels)

    print(f"原始特征: {features.shape}")

    # Force the time axis to exactly 100 frames (known-working model shape).
    target_frames = 100

    if features.shape[0] < target_frames:
        # Pad short utterances at the end of the time axis.
        feats_adjusted = np.pad(
            features,
            ((0, target_frames - features.shape[0]), (0, 0)),
            mode='constant',
            constant_values=-20  # log-mel value approximating silence
        )
    else:
        # Truncate long utterances.
        feats_adjusted = features[:target_frames]

    # Add the batch dimension expected by the ONNX model.
    feats_input = np.expand_dims(feats_adjusted, axis=0).astype(np.float32)

    print(f"调整后: {feats_input.shape}")

    # Run the encoder; the session is created and owned by the caller.
    try:
        encoder_outputs = session.run(None, {'feats': feats_input})
        print(f"✅ 编码器成功!")

        encoder_out = encoder_outputs[0]
        encoder_out_lens = encoder_outputs[1]

        print(f"编码输出: {encoder_out.shape}")
        print(f"输出长度: {encoder_out_lens}")

        return encoder_out, encoder_out_lens

    except Exception as e:
        print(f"❌ 编码器失败: {e}")
        return None, None

def check_model_inputs(model_path):
    """Load an ONNX model on CPU and print its declared input/output specs.

    Args:
        model_path: Filesystem path to an .onnx model file.

    Returns:
        The created onnxruntime.InferenceSession, or None if loading or
        inspection fails.
    """
    try:
        session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
        print(f"\n模型: {os.path.basename(model_path)}")
        print("输入:")
        for tensor in session.get_inputs():
            print(f"  {tensor.name}: {tensor.shape}")
        print("输出:")
        for tensor in session.get_outputs():
            # Some output metadata may not expose a shape attribute.
            shape_str = tensor.shape if hasattr(tensor, 'shape') else '未知'
            print(f"  {tensor.name}: {shape_str}")
        return session
    except Exception as e:
        print(f"加载模型失败: {e}")
        return None

def run_full_pipeline_fixed(audio_path):
    """End-to-end ASR pipeline: inspect the ONNX models, run the encoder,
    then greedy CTC decoding to text.

    Args:
        audio_path: Path to the wav file to transcribe.

    Returns:
        The decoded text string when a tokens.txt vocabulary is found,
        the raw CTC outputs when it is not, or None on any failure.
    """
    print("=" * 60)
    print("ASR完整管道")
    print("=" * 60)

    # 1. Probe every expected model file and cache a session for each.
    model_dir = "/home/sunzhq/workspace/yidong-infer/conformer/onnx_models/transformer_lm/full/"

    print("\n=== 检查所有模型 ===")

    models_to_check = [
        ("encoder", "default_encoder.onnx"),
        ("CTC", "ctc.onnx"),
        ("decoder", "xformer_decoder.onnx"),
        ("transformer_lm", "transformer_lm.onnx")
    ]

    model_sessions = {}

    for model_name, model_file in models_to_check:
        model_path = os.path.join(model_dir, model_file)
        if os.path.exists(model_path):
            session = check_model_inputs(model_path)
            if session:
                model_sessions[model_name] = session
        else:
            print(f"\n{model_name}模型不存在: {model_path}")

    # 2. Run the encoder.
    print("\n=== 运行编码器 ===")
    # Guard against a missing/unloadable encoder instead of raising KeyError.
    encode_session = model_sessions.get('encoder')
    if encode_session is None:
        return None
    encoder_out, encoder_out_lens = direct_solution(audio_path, encode_session)

    if encoder_out is None:
        return None

    # 3. CTC decoding — input names are discovered from the session itself.
    print("\n=== 运行CTC解码 ===")

    if 'CTC' in model_sessions:
        ctc_session = model_sessions['CTC']

        # Build the CTC feed dict by matching the model's declared inputs:
        # 'x' takes the encoder output; any length-like name takes the lengths.
        ctc_inputs = {}

        for inp in ctc_session.get_inputs():
            if inp.name == 'x':
                # Encoder output feeds the input named 'x'.
                ctc_inputs['x'] = encoder_out
                print(f"CTC输入 'x': {encoder_out.shape}")
            elif 'length' in inp.name.lower() or 'lens' in inp.name.lower():
                # Sequence-length input.
                ctc_inputs[inp.name] = encoder_out_lens
                print(f"CTC输入 '{inp.name}': {encoder_out_lens.shape}")
            else:
                print(f"警告: 未知CTC输入 '{inp.name}'")
                # Fall back to zeros shaped from the declared input shape,
                # mapping symbolic 'batch'/'sequence' dims to 1.
                if hasattr(inp, 'shape'):
                    dummy_shape = [1 if dim == 'batch' or dim == 'sequence' else dim
                                  for dim in inp.shape]
                    dummy_data = np.zeros(dummy_shape, dtype=np.float32)
                    ctc_inputs[inp.name] = dummy_data

        try:
            ctc_outputs = ctc_session.run(None, ctc_inputs)
            print(f"✅ CTC解码成功!")

            for i, output in enumerate(ctc_outputs):
                print(f"  输出{i}: {output.shape}")

            # First CTC output is assumed to be per-frame log-probabilities
            # over the vocabulary — confirm against the exported model.
            ctc_log_probs = ctc_outputs[0]

            # 4. Greedy decoding: argmax over the vocabulary axis.
            print("\n=== 简单解码 ===")
            predicted_tokens = np.argmax(ctc_log_probs, axis=-1)
            print(f"预测的token索引形状: {predicted_tokens.shape}")

            # Map token ids to symbols via the tokens.txt vocabulary.
            vocab_path = os.path.join(model_dir, "..", "tokens.txt")
            if os.path.exists(vocab_path):
                print(f"从 {vocab_path} 加载词汇表...")
                with open(vocab_path, 'r', encoding='utf-8') as f:
                    vocab = [line.strip().split()[0] for line in f]

                print(f"词汇表大小: {len(vocab)}")

                # Look up each predicted id (first batch item only).
                predicted_text = []
                for token_idx in predicted_tokens[0]:
                    if token_idx < len(vocab):
                        predicted_text.append(vocab[token_idx])

                # Simple CTC collapse: drop blanks and adjacent repeats,
                # and skip the sos/eos marker.
                filtered_text = []
                prev_token = None
                for token in predicted_text:
                    if token != '<blank>' and token != prev_token:
                        if token != '<sos/eos>':
                            filtered_text.append(token)
                    prev_token = token

                final_text = ''.join(filtered_text)
                print(f"解码文本: {final_text}")

                return final_text

            return ctc_outputs

        except Exception as e:
            print(f"❌ CTC失败: {e}")
            import traceback
            traceback.print_exc()

    return None

# Script entry point: run the fixed pipeline on a sample AISHELL utterance.
# Guarded so importing this module does not trigger inference.
if __name__ == "__main__":
    result = run_full_pipeline_fixed('/data/datasets/1/data_aishell/wav/test/S0916/BAC009S0916W0314.wav')

    if result:
        print(f"\n" + "=" * 60)
        print("🎉 ASR识别完成!")
        print(f"结果: {result}")
        print("=" * 60)