show_pb_model_name.py 2.84 KB
Newer Older
yaoht's avatar
yaoht committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# import tensorflow as tf
# from tensorflow.python.framework import graph_util
# from tensorflow.python.framework import graph_io

# def load_graph_def(pb_file):
#     with tf.io.gfile.GFile(pb_file, "rb") as f:
#         graph_def = tf.compat.v1.GraphDef()
#         graph_def.ParseFromString(f.read())
#     return graph_def

# def get_input_output_names(graph_def):
#     input_names = []
#     output_names = []
#     for node in graph_def.node:
#         if node.op == 'Placeholder':
#             input_names.append(node.name)
#         # Identify output nodes as those not used as inputs to other nodes
#         is_output = True
#         for n in graph_def.node:
#             if node.name in n.input:
#                 is_output = False
#                 break
#         if is_output:
#             output_names.append(node.name)
#     return input_names, output_names

# # 指定GraphDef pb文件路径
# pb_file_path = "resnet50v15_tf.pb"

# # 加载GraphDef
# graph_def = load_graph_def(pb_file_path)

# # 获取输入和输出的节点名称
# input_names, output_names = get_input_output_names(graph_def)

# print("Input names:", input_names)
# print("Output names:", output_names)

import tensorflow as tf
from tensorflow.python.framework import tensor_util

def load_graph_def(pb_file):
    """Read ``pb_file`` and parse it as a binary-serialized GraphDef.

    Args:
        pb_file: Path to the serialized ``tf.compat.v1.GraphDef`` (``.pb``).

    Returns:
        The parsed ``tf.compat.v1.GraphDef`` message.

    Errors from opening/reading the file or from ``ParseFromString``
    propagate unchanged to the caller.
    """
    graph_def = tf.compat.v1.GraphDef()
    # Open in binary mode: the file holds raw protobuf bytes, not text.
    with tf.io.gfile.GFile(pb_file, "rb") as handle:
        serialized = handle.read()
    graph_def.ParseFromString(serialized)
    return graph_def

def _producer_name(input_ref):
    """Return the node name referenced by a NodeDef input string.

    Input strings may be ``"name"``, ``"name:<output_index>"`` or ``"^name"``
    (a control dependency); both decorations are stripped.
    """
    return input_ref.lstrip('^').split(':', 1)[0]


def _shape_to_list(shape_proto):
    """Convert a TensorShapeProto into a list of dim sizes, or None if rank is unknown."""
    if shape_proto.unknown_rank:
        return None
    return [dim.size for dim in shape_proto.dim]


def _declared_shape(node):
    """Return the shape stored in ``node``'s ``shape`` attr, or None if it has none."""
    # Check membership before indexing: indexing a protobuf message map with a
    # missing key inserts a default entry, silently mutating the GraphDef.
    if 'shape' not in node.attr:
        return None
    return _shape_to_list(node.attr['shape'].shape)


def _output_shape(node):
    """Best-effort shape of ``node``'s first output, or None if none is recorded."""
    # ``_output_shapes`` is only present when the graph was exported with
    # shape annotations; fall back to a declared ``shape`` attr otherwise.
    if '_output_shapes' in node.attr:
        shapes = node.attr['_output_shapes'].list.shape
        if shapes:
            return _shape_to_list(shapes[0])
    return _declared_shape(node)


def get_graph_inputs_outputs(graph_def):
    """Heuristically find the input and output nodes of a GraphDef.

    Inputs are all ``Placeholder`` nodes. Outputs are all other nodes whose
    results are never consumed by another node. This is a heuristic, not a
    strict rule: unused constants or ``NoOp`` nodes also qualify.

    Args:
        graph_def: A parsed ``GraphDef`` (e.g. from ``load_graph_def``).

    Returns:
        An ``(inputs, outputs)`` pair. Each is a list of
        ``{'name': str, 'shape': list[int] | None}`` dicts in graph order.
        ``shape`` lists dimension sizes (``-1`` marks an unknown dimension).
        It is ``None`` when the rank is unknown or no shape is recorded on
        the node.
    """
    # Build the set of consumed node names once. This keeps the scan O(n)
    # instead of rescanning every node's inputs for each candidate.
    consumed = {
        _producer_name(ref) for node in graph_def.node for ref in node.input
    }

    inputs = []
    outputs = []
    for node in graph_def.node:
        if node.op == 'Placeholder':
            inputs.append({'name': node.name, 'shape': _declared_shape(node)})
        elif node.name not in consumed:
            outputs.append({'name': node.name, 'shape': _output_shape(node)})

    return inputs, outputs

def print_graph_info(inputs, outputs):
    """Print the name and shape of each input node, then of each output node.

    Args:
        inputs: Iterable of dicts with ``'name'`` and ``'shape'`` keys.
        outputs: Iterable of dicts with ``'name'`` and ``'shape'`` keys.
    """
    # The outputs header starts with a newline so a blank line separates it
    # from the inputs section.
    sections = (("Inputs:", inputs), ("\nOutputs:", outputs))
    for header, entries in sections:
        print(header)
        for entry in entries:
            print(f"Name: {entry['name']}, Shape: {entry['shape']}")

# Serialized GraphDef (.pb) file to inspect.
pb_file_path = "resnet50v15_tf.pb"

# Parse the file, derive the graph's probable inputs/outputs, then report them.
graph_def = load_graph_def(pb_file_path)
inputs, outputs = get_graph_inputs_outputs(graph_def)
print_graph_info(inputs, outputs)