import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import graph_util
import numpy as np

def _build_false_const_node(placeholder_node):
    """Return a ``Const`` NodeDef holding scalar ``False`` that replaces *placeholder_node*.

    The new node reuses the original node's name (so every consumer input
    that refers to it keeps resolving), its device string when one is set,
    and its ``dtype`` attribute.
    """
    from tensorflow.core.framework import attr_value_pb2, node_def_pb2, tensor_pb2, types_pb2

    const_node = node_def_pb2.NodeDef()
    const_node.op = "Const"
    const_node.name = placeholder_node.name
    if placeholder_node.device:
        const_node.device = placeholder_node.device
    # NOTE(review): the value below is always DT_BOOL, so this assumes the
    # replaced node's dtype attribute is DT_BOOL as well -- confirm for graphs
    # other than the one this script was written for.
    const_node.attr["dtype"].CopyFrom(placeholder_node.attr["dtype"])
    # tensor_shape is intentionally left unset: an empty shape proto has no
    # dimensions, i.e. the constant is a scalar.
    false_tensor = tensor_pb2.TensorProto(dtype=types_pb2.DT_BOOL, bool_val=[False])
    const_node.attr["value"].CopyFrom(attr_value_pb2.AttrValue(tensor=false_tensor))
    return const_node


def fix_phase_train_and_save(input_pb_path, output_pb_path_fixed, *,
                             phase_train_node_name='phase_train',
                             protected_node_names=('embeddings',)):
    """Freeze the ``phase_train`` input of a serialized GraphDef to ``False`` and save it.

    The graph at *input_pb_path* is parsed, every node whose name equals
    *phase_train_node_name* is replaced by a scalar boolean ``Const`` node with
    value ``False``, ``remove_training_nodes`` is applied on a best-effort
    basis, and the result is written to *output_pb_path_fixed*.

    Args:
        input_pb_path: Path of the binary GraphDef (``.pb``) to read.
        output_pb_path_fixed: Path the modified binary GraphDef is written to.
        phase_train_node_name: Name of the node to replace (node names carry
            no ``:0`` port suffix).
        protected_node_names: Node names (without port) that
            ``remove_training_nodes`` must keep unconditionally, typically the
            graph outputs. A single string is accepted as one name.

    Returns:
        None. Progress and warnings are reported with ``print``.
    """
    with tf.io.gfile.GFile(input_pb_path, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    if isinstance(protected_node_names, str):
        protected_node_names = [protected_node_names]

    new_graph_def = tf.compat.v1.GraphDef()
    replaced_count = 0
    for node in graph_def.node:
        if node.name == phase_train_node_name:
            print(f"Found phase_train node: Name='{node.name}', Op='{node.op}', Dtype={node.attr['dtype'].type}")
            print(f"  - Note: Actual output tensor name is likely '{node.name}:0'")
            new_graph_def.node.extend([_build_false_const_node(node)])
            replaced_count += 1
            print(f"Replaced '{node.name}' with a scalar Const node having value False.")
        else:
            new_graph_def.node.extend([node])

    if not replaced_count:
        print(f"Warning: No node named '{phase_train_node_name}' was found; nothing was replaced.")

    # Best-effort pruning. The graph is passed positionally: the function's
    # first parameter is named ``input_graph``, so the former
    # ``input_graph_def=`` keyword raised a TypeError that the handler below
    # swallowed, and the optimization was never actually applied.
    try:
        new_graph_def = graph_util.remove_training_nodes(
            new_graph_def,
            protected_nodes=list(protected_node_names),
        )
        print("Applied remove_training_nodes optimization.")
    except Exception as e:
        print(f"Warning: Could not apply remove_training_nodes: {e}. Proceeding with current graph_def.")

    # Only the node list has been carried over so far; keep the producer
    # version information and the function library of the input graph too.
    new_graph_def.versions.CopyFrom(graph_def.versions)
    new_graph_def.library.CopyFrom(graph_def.library)

    # Save the modified graph
    with tf.io.gfile.GFile(output_pb_path_fixed, 'wb') as f:
        f.write(new_graph_def.SerializeToString())

    print(f"Modified .pb saved to: {output_pb_path_fixed}")


# Location of the frozen FaceNet graph and of the rewritten copy.
input_pb = "/home/sunzhq/workspace/yidong-infer/facenet/facenet/models_m/facenet-tmp/20180408-102900.pb"
fixed_pb = "/home/sunzhq/workspace/yidong-infer/facenet/facenet/models_m/facenet-tmp/20180408-102900_fixed_scalar.pb"

fix_phase_train_and_save(input_pb, fixed_pb)

# Follow-up tf2onnx invocation, one printed line per shell line; the trailing
# backslashes are shell line continuations.
_tf2onnx_command_lines = (
    f"python -m tf2onnx.convert \\",
    f"  --input {fixed_pb} \\",
    f"  --inputs \"input:0[64,160,160,3]\" \\",
    f"  --outputs embeddings:0 \\",  # No more phase_train input needed
    f"  --output ./onnx-models/facenet_static_bs64.onnx \\",
    f"  --opset 11",
)

print("\nNow run tf2onnx on the fixed .pb file (with scalar phase_train):")
for _command_line in _tf2onnx_command_lines:
    print(_command_line)