# import onnxruntime as ort # import numpy as np # # 直接加载ONNX模型查看输入要求 # model_path = "/root/.cache/espnet_onnx/transformer_lm/full/default_encoder.onnx" # try: # sess = ort.InferenceSession(model_path, providers=['CPUExecutionProvider']) # input_details = sess.get_inputs() # print("ONNX模型输入要求:") # for inp in input_details: # print(f" 名称: {inp.name}, 形状: {inp.shape}, 类型: {inp.type}") # except Exception as e: # print(f"加载模型失败: {e}") # import os # import onnx # import onnxruntime as ort # import numpy as np # # 检查ONNX模型文件 # model_path = "/root/.cache/espnet_onnx/transformer_lm/full/default_encoder.onnx" # print("检查模型文件...") # if os.path.exists(model_path): # model_size = os.path.getsize(model_path) # print(f"模型大小: {model_size} bytes") # # 加载模型查看结构 # try: # model = onnx.load(model_path) # print(f"模型IR版本: {model.ir_version}") # print(f"生产者: {model.producer_name} {model.producer_version}") # print(f"模型输入: {len(model.graph.input)} 个") # print(f"模型输出: {len(model.graph.output)} 个") # print(f"节点数量: {len(model.graph.node)}") # # 查找Where节点 # where_nodes = [node for node in model.graph.node if node.op_type == "Where"] # print(f"找到 {len(where_nodes)} 个Where节点") # for i, node in enumerate(where_nodes[:3]): # 只显示前3个 # print(f" Where节点 {i}: {node.name}") # print(f" 输入: {[input for input in node.input]}") # print(f" 输出: {[output for output in node.output]}") # except Exception as e: # print(f"加载模型失败: {e}") # else: # print(f"模型文件不存在: {model_path}") import onnxruntime as ort import numpy as np model_path = "/root/.cache/espnet_onnx/transformer_lm/full/default_encoder.onnx" print("=== 检查模型实际输入 ===") sess = ort.InferenceSession(model_path, providers=['CPUExecutionProvider']) # 详细检查输入 print("模型输入详细信息:") for inp in sess.get_inputs(): print(f"\n输入: {inp.name}") print(f" 形状: {inp.shape}") print(f" 类型: {inp.type}") # 打印每个维度 for i, dim in enumerate(inp.shape): print(f" 维度[{i}]: {dim}") # 尝试不同的输入名称 print("\n=== 尝试不同的输入名称 ===") # 创建测试数据 batch_size = 1 time_frames = 100 n_mels = 80 dummy_feats = np.random.randn(batch_size, time_frames, n_mels).astype(np.float32) # 获取所有可能的输入名称 input_names = [inp.name for inp in sess.get_inputs()] print(f"模型接受的输入名称: {input_names}") # 尝试所有可能的输入组合 test_inputs = [] # 常见的输入名称模式 common_names = [ 'feats', 'speech', 'input', 'x', 'feats_length', 'speech_lengths', 'lengths', 'ilens' ] for name in input_names: print(f"\n测试输入: {name}") # 根据名称猜测类型 if 'length' in name.lower() or 'lens' in name.lower(): # 可能是长度输入 dummy_input = np.array([time_frames], dtype=np.int64) else: # 可能是特征输入 dummy_input = dummy_feats try: outputs = sess.run(None, {name: dummy_input}) print(f" 成功! 使用单一输入: {name}") print(f" 输出数量: {len(outputs)}") for i, out in enumerate(outputs): print(f" 输出{i}: {out.shape}") break except: print(f" 失败: 单一输入{name}") # 尝试多输入 if len(input_names) > 1: print(f"\n尝试多输入组合: {input_names}") # 准备输入字典 input_dict = {} for name in input_names: if 'length' in name.lower() or 'lens' in name.lower(): input_dict[name] = np.array([time_frames], dtype=np.int64) else: input_dict[name] = dummy_feats try: outputs = sess.run(None, input_dict) print(f" 成功! 使用多输入") for i, out in enumerate(outputs): print(f" 输出{i}: {out.shape}") except Exception as e: print(f" 失败: {e}")