"""Inspect a frozen TensorFlow GraphDef (.pb) and report its likely inputs and outputs.

Inputs are the ``Placeholder`` nodes. Outputs are guessed heuristically as the
non-Placeholder nodes whose output is not consumed by any other node.
"""
import sys

import tensorflow as tf
from tensorflow.python.framework import tensor_util  # noqa: F401  (unused: shapes are read from attrs directly)


DEFAULT_PB_FILE = "resnet50v15_tf.pb"


def load_graph_def(pb_file):
    """Read and parse a binary-serialized ``GraphDef`` from *pb_file*.

    Args:
        pb_file: Path to a serialized GraphDef file (e.g. a frozen ``.pb``).

    Returns:
        The parsed ``tf.compat.v1.GraphDef``.
    """
    with tf.io.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    return graph_def


def _source_node_name(input_ref):
    """Map a NodeDef input reference (``name``, ``name:1``, ``^name``) to the producing node's name."""
    if input_ref.startswith("^"):  # "^name" marks a control-dependency input
        input_ref = input_ref[1:]
    return input_ref.split(":", 1)[0]  # drop the ":N" output-index suffix


def _shape_from_proto(shape_proto):
    """Convert a ``TensorShapeProto`` to a list of dim sizes, or None if its rank is unknown.

    TensorFlow encodes an unknown individual dimension as -1; it is passed through unchanged.
    """
    if shape_proto.unknown_rank:
        return None
    return [dim.size for dim in shape_proto.dim]


def _node_shape(node):
    """Best-effort static shape of *node*'s first output, or None if the GraphDef does not record it.

    Reads the ``shape`` attribute (set on Placeholders) or, failing that, the first entry of
    ``_output_shapes`` (only present when the graph was exported with shape information).
    """
    # Check membership before indexing: indexing a protobuf map of messages with a
    # missing key inserts a default entry, which would silently mutate the graph.
    if "shape" in node.attr:
        return _shape_from_proto(node.attr["shape"].shape)
    if "_output_shapes" in node.attr:
        shapes = node.attr["_output_shapes"].list.shape
        if shapes:
            return _shape_from_proto(shapes[0])
    return None


def get_graph_inputs_outputs(graph_def):
    """Return the (heuristic) input and output nodes of *graph_def*.

    Args:
        graph_def: A parsed ``GraphDef``.

    Returns:
        ``(inputs, outputs)``: two lists of ``{'name': str, 'shape': list[int] | None}``
        dicts, in graph order. Inputs are all ``Placeholder`` nodes. Outputs are the
        non-Placeholder nodes that no other node consumes through a data or control
        edge. This is a heuristic, so any dangling node (e.g. an unused ``Const``) is
        reported as an output too.
    """
    # Collect consumed producer names once (linear in the number of edges), normalising
    # "name:N" / "^name" references so those producers are not mistaken for outputs.
    consumed = {
        _source_node_name(ref) for node in graph_def.node for ref in node.input
    }

    inputs = []
    outputs = []
    for node in graph_def.node:
        if node.op == "Placeholder":
            inputs.append({"name": node.name, "shape": _node_shape(node)})
        elif node.name not in consumed:
            outputs.append({"name": node.name, "shape": _node_shape(node)})
    return inputs, outputs


def print_graph_info(inputs, outputs):
    """Print the name and shape of each input and output entry.

    Args:
        inputs: List of ``{'name', 'shape'}`` dicts describing graph inputs.
        outputs: List of ``{'name', 'shape'}`` dicts describing graph outputs.
    """
    print("Inputs:")
    for input_info in inputs:
        print(f"Name: {input_info['name']}, Shape: {input_info['shape']}")
    print("\nOutputs:")
    for output_info in outputs:
        print(f"Name: {output_info['name']}, Shape: {output_info['shape']}")


def main(argv=None):
    """Load the GraphDef named by the first CLI argument (or the default file) and print its I/O.

    Args:
        argv: Argument list excluding the program name; defaults to ``sys.argv[1:]``.
    """
    argv = sys.argv[1:] if argv is None else argv
    pb_file_path = argv[0] if argv else DEFAULT_PB_FILE
    graph_def = load_graph_def(pb_file_path)
    inputs, outputs = get_graph_inputs_outputs(graph_def)
    print_graph_info(inputs, outputs)


if __name__ == "__main__":
    main()