# import tensorflow as tf
# from tensorflow.python.framework import graph_util
# from tensorflow.python.framework import graph_io

# def load_graph_def(pb_file):
#     with tf.io.gfile.GFile(pb_file, "rb") as f:
#         graph_def = tf.compat.v1.GraphDef()
#         graph_def.ParseFromString(f.read())
#     return graph_def

# def get_input_output_names(graph_def):
#     input_names = []
#     output_names = []
#     for node in graph_def.node:
#         if node.op == 'Placeholder':
#             input_names.append(node.name)
#         # Identify output nodes as those not used as inputs to other nodes
#         is_output = True
#         for n in graph_def.node:
#             if node.name in n.input:
#                 is_output = False
#                 break
#         if is_output:
#             output_names.append(node.name)
#     return input_names, output_names

# # Path to the GraphDef .pb file
# pb_file_path = "resnet50v15_tf.pb"

# # Load the GraphDef
# graph_def = load_graph_def(pb_file_path)

# # Get the input and output node names
# input_names, output_names = get_input_output_names(graph_def)

# print("Input names:", input_names)
# print("Output names:", output_names)

import tensorflow as tf
from tensorflow.python.framework import tensor_util

def load_graph_def(pb_file):
    """Read the file at ``pb_file`` and parse its bytes into a GraphDef.

    Args:
        pb_file: Path of the serialized GraphDef, opened in binary mode
            through ``tf.io.gfile.GFile``.

    Returns:
        A ``tf.compat.v1.GraphDef`` populated from the file contents.
    """
    with tf.io.gfile.GFile(pb_file, "rb") as handle:
        serialized = handle.read()

    parsed = tf.compat.v1.GraphDef()
    parsed.ParseFromString(serialized)
    return parsed

def _producer_name(input_ref):
    """Return the name of the node that an input reference points at.

    Input strings in a GraphDef take the forms ``"node"``,
    ``"node:<output_index>"`` and ``"^node"`` (control dependency); all of
    them refer to the node called ``node``.
    """
    return input_ref.lstrip('^').split(':', 1)[0]


def _shape_to_list(shape_proto):
    """Convert a shape message to a list of dimension sizes.

    Returns ``None`` when the message is flagged ``unknown_rank``, so an
    unknown shape is not confused with a scalar (``[]``).
    """
    if shape_proto.unknown_rank:
        return None
    return [dim.size for dim in shape_proto.dim]


def _node_shape(node):
    """Return the statically recorded shape of ``node``, or ``None``.

    The ``shape`` attribute is preferred; otherwise the first entry of the
    ``_output_shapes`` list attribute is used when the graph carries one.
    """
    attrs = node.attr
    # Test membership before indexing: indexing a protobuf map with a missing
    # key would insert an empty entry and silently modify the graph.
    if 'shape' in attrs and attrs['shape'].HasField('shape'):
        return _shape_to_list(attrs['shape'].shape)
    if '_output_shapes' in attrs:
        recorded = attrs['_output_shapes'].list.shape
        if len(recorded) > 0:
            return _shape_to_list(recorded[0])
    return None


def get_graph_inputs_outputs(graph_def):
    """Find the graph-level input and output nodes of a GraphDef.

    Inputs are all nodes whose op is ``'Placeholder'``. Outputs are the
    remaining nodes that no node consumes, whether through a data edge
    (``"name"`` / ``"name:N"``) or a control edge (``"^name"``). This is a
    heuristic, not a strict rule: any dangling node is reported as an output.

    Args:
        graph_def: Object exposing ``node``, an iterable of node messages
            with ``name``, ``op``, ``input`` and ``attr`` fields. It is
            only read, never modified.

    Returns:
        A tuple ``(inputs, outputs)``; each is a list of
        ``{'name': str, 'shape': list[int] | None}`` dicts in graph order.
        ``shape`` is ``None`` when no shape is recorded or the rank is
        unknown.
    """
    # Build the set of consumed node names once instead of rescanning every
    # node's input list for every candidate node.
    consumed = {
        _producer_name(ref)
        for consumer in graph_def.node
        for ref in consumer.input
    }

    inputs = []
    outputs = []
    for node in graph_def.node:
        if node.op == 'Placeholder':
            inputs.append({'name': node.name, 'shape': _node_shape(node)})
        elif node.name not in consumed:
            outputs.append({'name': node.name, 'shape': _node_shape(node)})

    return inputs, outputs

def print_graph_info(inputs, outputs):
    """Print the graph's inputs and outputs to standard output.

    Emits an ``Inputs:`` heading followed by one ``Name: ..., Shape: ...``
    line per entry of ``inputs``, then a blank line, an ``Outputs:`` heading
    and the same per-entry lines for ``outputs``.

    Args:
        inputs: Iterable of dicts with ``'name'`` and ``'shape'`` keys.
        outputs: Iterable of dicts with ``'name'`` and ``'shape'`` keys.
    """
    sections = (("Inputs:", inputs), ("\nOutputs:", outputs))
    for heading, entries in sections:
        print(heading)
        for entry in entries:
            name = entry['name']
            shape = entry['shape']
            print(f"Name: {name}, Shape: {shape}")

# Script body: load a serialized GraphDef and report its inputs and outputs.
# NOTE(review): these statements run whenever the module is executed or
# imported (there is no `__main__` guard), and the path below is hard-coded
# relative to the current working directory.
pb_file_path = "resnet50v15_tf.pb"

# Parse the GraphDef protobuf stored at pb_file_path.
graph_def = load_graph_def(pb_file_path)

# Collect the Placeholder nodes (inputs) and the unconsumed nodes (outputs).
inputs, outputs = get_graph_inputs_outputs(graph_def)

# Print one "Name: ..., Shape: ..." line per collected input and output.
print_graph_info(inputs, outputs)



