# import librosa
# from espnet_onnx import Speech2Text

# PROVIDERS = ['CPUExecutionProvider']
# tag_name = 'transformer_lm'

# speech2text = Speech2Text(
#     tag_name,
#     providers=PROVIDERS,
#     # cache_dir="/root/.cache/espnet_onnx/transformer_lm/full/"
# )
# y, sr = librosa.load('/data/datasets/data_aishell/wav/test/S0764/BAC009S0764W0150.wav', sr=16000)
# nbest = speech2text(y)
# runs on GPU.

import os

import librosa
import numpy as np
import onnxruntime as ort


def direct_solution(audio_path, session):
    """Run the encoder directly, using the known-working 100-frame input."""
    print("=== Direct solution: use 100 frames ===")

    # Load audio at 16 kHz.
    y, sr = librosa.load(audio_path, sr=16000)

    # Extract log-mel features.
    n_fft = 512
    hop_length = 128
    n_mels = 80
    mel_spec = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
    )
    log_mel = librosa.power_to_db(mel_spec)
    features = log_mel.T  # (frames, n_mels)
    print(f"Raw features: {features.shape}")

    # Pad or truncate to exactly 100 frames.
    target_frames = 100
    if features.shape[0] < target_frames:
        # Pad with a low log-mel value that stands in for silence.
        feats_adjusted = np.pad(
            features,
            ((0, target_frames - features.shape[0]), (0, 0)),
            mode='constant',
            constant_values=-20,
        )
    else:
        # Truncate.
        feats_adjusted = features[:target_frames]

    # Add the batch dimension.
    feats_input = np.expand_dims(feats_adjusted, axis=0).astype(np.float32)
    print(f"After adjustment: {feats_input.shape}")

    # Run the encoder.
    try:
        encoder_outputs = session.run(None, {'feats': feats_input})
        print("✅ Encoder succeeded!")
        encoder_out = encoder_outputs[0]
        encoder_out_lens = encoder_outputs[1]
        print(f"Encoder output: {encoder_out.shape}")
        print(f"Output lengths: {encoder_out_lens}")
        return encoder_out, encoder_out_lens
    except Exception as e:
        print(f"❌ Encoder failed: {e}")
        return None, None


def check_model_inputs(model_path):
    """Load an ONNX model and print its declared inputs and outputs."""
    try:
        session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
        print(f"\nModel: {os.path.basename(model_path)}")
        print("Inputs:")
        for inp in session.get_inputs():
            print(f"  {inp.name}: {inp.shape}")
        print("Outputs:")
        for out in session.get_outputs():
            shape_str = out.shape if hasattr(out, 'shape') else 'unknown'
            print(f"  {out.name}: {shape_str}")
        return session
    except Exception as e:
        print(f"Failed to load model: {e}")
        return None
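

# A minimal sketch, not used by the script below: the ONNX sessions above are created on
# CPU, but the same models can be run on GPU by listing CUDAExecutionProvider first with
# CPU as a fallback, mirroring the "runs on GPU" espnet_onnx example commented out at the
# top of this file. This assumes the onnxruntime-gpu package and a visible CUDA device.
def create_session(model_path, use_gpu=False):
    """Create an ONNX Runtime session on GPU (with CPU fallback) or on CPU only."""
    providers = (['CUDAExecutionProvider', 'CPUExecutionProvider']
                 if use_gpu else ['CPUExecutionProvider'])
    return ort.InferenceSession(model_path, providers=providers)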


def run_full_pipeline_fixed(audio_path):
    print("=" * 60)
    print("Full ASR pipeline")
    print("=" * 60)

    # 1. First, inspect all the models.
    model_dir = "/home/sunzhq/workspace/yidong-infer/conformer/onnx_models/transformer_lm/full/"
    print("\n=== Checking all models ===")
    models_to_check = [
        ("encoder", "default_encoder.onnx"),
        ("CTC", "ctc.onnx"),
        ("decoder", "xformer_decoder.onnx"),
        ("transformer_lm", "transformer_lm.onnx"),
    ]
    model_sessions = {}
    for model_name, model_file in models_to_check:
        model_path = os.path.join(model_dir, model_file)
        if os.path.exists(model_path):
            session = check_model_inputs(model_path)
            if session:
                model_sessions[model_name] = session
        else:
            print(f"\n{model_name} model does not exist: {model_path}")

    # 2. Run the encoder.
    print("\n=== Running encoder ===")
    if 'encoder' not in model_sessions:
        print("Encoder model is missing; aborting.")
        return None
    encode_session = model_sessions['encoder']
    encoder_out, encoder_out_lens = direct_solution(audio_path, encode_session)
    if encoder_out is None:
        return None

    # 3. Run CTC decoding (using the correct input names).
    print("\n=== Running CTC decoding ===")
    if 'CTC' in model_sessions:
        ctc_session = model_sessions['CTC']

        # Prepare the CTC inputs. Per the inspection above, the main input is named 'x';
        # check whether the CTC model expects any other inputs.
        ctc_inputs = {}
        for inp in ctc_session.get_inputs():
            if inp.name == 'x':
                # The input named 'x' takes the encoder output.
                ctc_inputs['x'] = encoder_out
                print(f"CTC input 'x': {encoder_out.shape}")
            elif 'length' in inp.name.lower() or 'lens' in inp.name.lower():
                # Length input.
                ctc_inputs[inp.name] = encoder_out_lens
                print(f"CTC input '{inp.name}': {encoder_out_lens.shape}")
            else:
                print(f"Warning: unknown CTC input '{inp.name}'")
                # Try to provide a default value based on the declared shape;
                # symbolic (string) or unknown dimensions are replaced with 1.
                if hasattr(inp, 'shape'):
                    dummy_shape = [dim if isinstance(dim, int) else 1 for dim in inp.shape]
                    dummy_data = np.zeros(dummy_shape, dtype=np.float32)
                    ctc_inputs[inp.name] = dummy_data

        try:
            ctc_outputs = ctc_session.run(None, ctc_inputs)
            print("✅ CTC decoding succeeded!")
            for i, output in enumerate(ctc_outputs):
                print(f"  Output {i}: {output.shape}")

            # The CTC output is normally log-probabilities and still needs decoding.
            ctc_log_probs = ctc_outputs[0]

            # 4. Simple decoding: take the argmax.
            print("\n=== Simple decoding ===")
            # Argmax along the vocabulary dimension.
            predicted_tokens = np.argmax(ctc_log_probs, axis=-1)
            print(f"Predicted token index shape: {predicted_tokens.shape}")

            # Load the vocabulary.
            vocab_path = os.path.join(model_dir, "..", "tokens.txt")
            if os.path.exists(vocab_path):
                print(f"Loading vocabulary from {vocab_path}...")
                with open(vocab_path, 'r', encoding='utf-8') as f:
                    vocab = [line.strip().split()[0] for line in f]
                print(f"Vocabulary size: {len(vocab)}")

                # Convert token indices to symbols (first utterance in the batch).
                predicted_text = []
                for token_idx in predicted_tokens[0]:
                    if token_idx < len(vocab):
                        predicted_text.append(vocab[token_idx])

                # Remove repeated tokens and the blank symbol (a minimal CTC collapse).
                # ESPnet token lists normally write the blank as '<blank>'.
                filtered_text = []
                prev_token = None
                for token in predicted_text:
                    if token != prev_token and token not in ('', '<blank>'):
                        filtered_text.append(token)
                    prev_token = token

                final_text = ''.join(filtered_text)
                print(f"Decoded text: {final_text}")
                return final_text

            return ctc_outputs
        except Exception as e:
            print(f"❌ CTC failed: {e}")
            import traceback
            traceback.print_exc()
            return None

    return None


# Run the fixed pipeline.
result = run_full_pipeline_fixed('/data/datasets/1/data_aishell/wav/test/S0916/BAC009S0916W0314.wav')
if result:
    print("\n" + "=" * 60)
    print("🎉 ASR recognition complete!")
    print(f"Result: {result}")
    print("=" * 60)
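

# A minimal sketch of the same greedy CTC decoding used in the "Simple decoding" step of
# run_full_pipeline_fixed, factored into a reusable function. It assumes the ESPnet
# convention that the blank token is written as '<blank>' in tokens.txt; adjust
# blank_token if your token list names it differently.
# Hypothetical usage: final_text = ctc_greedy_decode(ctc_log_probs, vocab)
def ctc_greedy_decode(ctc_log_probs, vocab, blank_token='<blank>'):
    """Greedy-decode a (batch, frames, vocab) CTC log-probability array to text."""
    token_ids = np.argmax(ctc_log_probs, axis=-1)[0]  # first utterance in the batch
    tokens = []
    prev_id = None
    for token_id in token_ids:
        # Collapse consecutive repeats, then drop the blank symbol.
        if token_id != prev_id and token_id < len(vocab):
            token = vocab[int(token_id)]
            if token != blank_token:
                tokens.append(token)
        prev_id = token_id
    return ''.join(tokens)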