check.py 4.2 KB
Newer Older
sunzhq2's avatar
sunzhq2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# import onnxruntime as ort
# import numpy as np

# # 直接加载ONNX模型查看输入要求
# model_path = "/root/.cache/espnet_onnx/transformer_lm/full/default_encoder.onnx"
# try:
#     sess = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
#     input_details = sess.get_inputs()
#     print("ONNX模型输入要求:")
#     for inp in input_details:
#         print(f"  名称: {inp.name}, 形状: {inp.shape}, 类型: {inp.type}")
# except Exception as e:
#     print(f"加载模型失败: {e}")




# import os
# import onnx
# import onnxruntime as ort
# import numpy as np

# # 检查ONNX模型文件
# model_path = "/root/.cache/espnet_onnx/transformer_lm/full/default_encoder.onnx"

# print("检查模型文件...")
# if os.path.exists(model_path):
#     model_size = os.path.getsize(model_path)
#     print(f"模型大小: {model_size} bytes")
    
#     # 加载模型查看结构
#     try:
#         model = onnx.load(model_path)
#         print(f"模型IR版本: {model.ir_version}")
#         print(f"生产者: {model.producer_name} {model.producer_version}")
#         print(f"模型输入: {len(model.graph.input)} 个")
#         print(f"模型输出: {len(model.graph.output)} 个")
#         print(f"节点数量: {len(model.graph.node)}")
        
#         # 查找Where节点
#         where_nodes = [node for node in model.graph.node if node.op_type == "Where"]
#         print(f"找到 {len(where_nodes)} 个Where节点")
        
#         for i, node in enumerate(where_nodes[:3]):  # 只显示前3个
#             print(f"  Where节点 {i}: {node.name}")
#             print(f"    输入: {[input for input in node.input]}")
#             print(f"    输出: {[output for output in node.output]}")
            
#     except Exception as e:
#         print(f"加载模型失败: {e}")
# else:
#     print(f"模型文件不存在: {model_path}")



import onnxruntime as ort
import numpy as np

model_path = "/root/.cache/espnet_onnx/transformer_lm/full/default_encoder.onnx"

print("=== 检查模型实际输入 ===")
sess = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])

# 详细检查输入
print("模型输入详细信息:")
for inp in sess.get_inputs():
    print(f"\n输入: {inp.name}")
    print(f"  形状: {inp.shape}")
    print(f"  类型: {inp.type}")
    
    # 打印每个维度
    for i, dim in enumerate(inp.shape):
        print(f"    维度[{i}]: {dim}")

# 尝试不同的输入名称
print("\n=== 尝试不同的输入名称 ===")

# 创建测试数据
batch_size = 1
time_frames = 100
n_mels = 80

dummy_feats = np.random.randn(batch_size, time_frames, n_mels).astype(np.float32)

# 获取所有可能的输入名称
input_names = [inp.name for inp in sess.get_inputs()]
print(f"模型接受的输入名称: {input_names}")

# 尝试所有可能的输入组合
test_inputs = []

# 常见的输入名称模式
common_names = [
    'feats', 'speech', 'input', 'x',
    'feats_length', 'speech_lengths', 'lengths', 'ilens'
]

for name in input_names:
    print(f"\n测试输入: {name}")
    
    # 根据名称猜测类型
    if 'length' in name.lower() or 'lens' in name.lower():
        # 可能是长度输入
        dummy_input = np.array([time_frames], dtype=np.int64)
    else:
        # 可能是特征输入
        dummy_input = dummy_feats
    
    try:
        outputs = sess.run(None, {name: dummy_input})
        print(f"  成功! 使用单一输入: {name}")
        print(f"  输出数量: {len(outputs)}")
        for i, out in enumerate(outputs):
            print(f"    输出{i}: {out.shape}")
        break
    except:
        print(f"  失败: 单一输入{name}")

# 尝试多输入
if len(input_names) > 1:
    print(f"\n尝试多输入组合: {input_names}")
    
    # 准备输入字典
    input_dict = {}
    for name in input_names:
        if 'length' in name.lower() or 'lens' in name.lower():
            input_dict[name] = np.array([time_frames], dtype=np.int64)
        else:
            input_dict[name] = dummy_feats
    
    try:
        outputs = sess.run(None, input_dict)
        print(f"  成功! 使用多输入")
        for i, out in enumerate(outputs):
            print(f"    输出{i}: {out.shape}")
    except Exception as e:
        print(f"  失败: {e}")