import onnx

# Inspect an ONNX model: validate it, then print its producer/IR metadata
# and the name, element type and shape of every graph input and output.


def _elem_type_name(elem_type):
    """Return a readable name for a TensorProto.DataType enum value.

    The numeric value is kept in the result so the output can still be
    matched against the raw enum; unknown values fall back to the number.
    """
    try:
        return f"{onnx.TensorProto.DataType.Name(elem_type)} ({elem_type})"
    except ValueError:
        # Enum value not known to the installed onnx version.
        return str(elem_type)


def _shape_of(tensor_type):
    """Return the dimensions of a TypeProto.Tensor as a list.

    Each entry is the symbolic name (dim_param), the fixed size (dim_value),
    or "?" when the dimension carries neither. Returns None when the tensor
    has no shape field at all (unknown rank).
    """
    if not tensor_type.HasField("shape"):
        return None
    dims = []
    for dim in tensor_type.shape.dim:
        if dim.dim_param:
            dims.append(dim.dim_param)
        elif dim.WhichOneof("value") == "dim_value":
            dims.append(dim.dim_value)
        else:
            # Neither field is set: reading dim_value here would yield the
            # protobuf default 0 and look like a real zero-sized axis.
            dims.append("?")
    return dims


def _print_value_infos(title, value_infos):
    """Print name, element type and shape for each ValueInfoProto."""
    print(f"\n--- {title} Information ---")
    for info in value_infos:
        print(f"Name: {info.name}")
        kind = info.type.WhichOneof("value")
        if kind == "tensor_type":
            tensor_type = info.type.tensor_type
            print(f"Type: {_elem_type_name(tensor_type.elem_type)}")
            shape = _shape_of(tensor_type)
            print(f"Shape: {shape if shape is not None else 'unknown rank'}")
        else:
            # Sequence / map / optional / untyped values have no tensor
            # element type or shape; report the kind instead of defaults.
            print(f"Type: {kind or 'unspecified'} (non-tensor)")
            print("Shape: n/a")
        print("-" * 20)


# Load the ONNX model.
model_path = "/home/sunzhq/workspace/yidong-infer/facenet/facenet/tools/onnx-models/facenet_static_bs64.onnx"  # replace with the path to your actual .onnx file
model = onnx.load(model_path)

# Validate the model; check_model raises if the model is not well-formed.
onnx.checker.check_model(model)

# The graph holds the inputs, outputs and nodes of the model.
graph = model.graph

print("--- Model Info ---")
# producer_name identifies the tool that exported the model, and ir_version
# is the ONNX IR format version, so label them as such.
print(f"Producer: {model.producer_name or 'Unknown'}")
print(f"IR Version: {model.ir_version}")

_print_value_infos("Input", graph.input)
_print_value_infos("Output", graph.output)

# Optional: uncomment the lines below to print an overview of the graph's nodes.
# print("\n--- Node Overview ---")
# for i, node in enumerate(graph.node[:5]):  # print only the first 5 nodes as a sample
#     print(f"Node {i}: {node.op_type} -> {node.output[0]} (inputs: {node.input})")
# if len(graph.node) > 5:
#     print(f"... and {len(graph.node) - 5} more nodes")