import tensorflow.compat.v1 as tf
import csv
import tf2onnx
import os

def create_directory(path):
    """Create *path* (and any missing parent directories) if it does not exist.

    Prints whether the directory was newly created or was already present.

    Raises:
        FileExistsError: if *path* exists but is not a directory.
        OSError: for any other failure reported by ``os.makedirs``.
    """
    try:
        # EAFP: avoids the race between a separate exists() check and makedirs().
        os.makedirs(path)
    except FileExistsError:
        # makedirs also raises FileExistsError when *path* is a regular file;
        # don't misreport that as an existing directory.
        if not os.path.isdir(path):
            raise
        print(f"Directory '{path}' already exists.")
    else:
        print(f"Directory '{path}' created.")


def read_csv_data(file_path):
    """Read a CSV file, skip its header row, and parse the third column.

    Each non-blank data row is returned as a list of strings, except that the
    third column (index 2) is parsed from a bracketed, comma-separated string
    such as ``"[1, 224, 224, 3]"`` into a list of whitespace-stripped strings
    (``["1", "224", "224", "3"]``). An empty bracket pair ``"[]"`` becomes ``[]``.

    Returns:
        list[list]: the parsed data rows; empty if the file has no data rows.

    Raises:
        IndexError: if a non-blank data row has fewer than three columns.
    """
    # newline="" is required by the csv module so quoted fields parse correctly.
    with open(file_path, "r", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header; tolerate a completely empty file
        rows = [row for row in reader if row]  # csv yields [] for blank lines

    for row in rows:
        # Assumes the third column is wrapped in exactly one pair of
        # brackets/parentheses; the first and last characters are dropped.
        inner = row[2].strip()[1:-1].strip()
        row[2] = [item.strip() for item in inner.split(",")] if inner else []

    return rows

def load_graph(model_file):
    """Deserialize a serialized ``GraphDef`` file into a fresh ``tf.Graph``.

    The file at *model_file* is read as bytes and parsed into a ``GraphDef``,
    which is then imported with an empty name prefix (``name=""``) so node
    names are kept unchanged.

    Returns:
        tf.Graph: a new graph containing the imported nodes.
    """
    parsed_def = tf.GraphDef()
    with tf.gfile.GFile(model_file, "rb") as handle:
        parsed_def.ParseFromString(handle.read())

    imported_graph = tf.Graph()
    with imported_graph.as_default():
        # name="" prevents tf from prefixing every node with "import/".
        tf.import_graph_def(parsed_def, name="")
    return imported_graph

def convert_graph_to_onnx(graph, input_tensors, output_tensors, output_path,
                          graph_doc="test_model", opset=None):
    """Convert a TensorFlow graph to ONNX and write the model to *output_path*.

    Args:
        graph: the ``tf.Graph`` to convert.
        input_tensors: rows (e.g. from ``read_csv_data``) whose index-1 column
            holds a graph input tensor name.
        output_tensors: rows whose index-1 column holds a graph output tensor
            name.
        output_path: destination file for the serialized ONNX ``ModelProto``.
        graph_doc: text passed as the first (``graph_doc``) argument of
            tf2onnx ``make_model``. It defaults to the previously hard-coded
            ``"test_model"``.
        opset: ONNX opset to target. ``None`` lets tf2onnx pick its default,
            which is the previous behavior.
    """
    # tf2onnx only needs the tensor-name column from each row.
    input_names = [row[1] for row in input_tensors]
    output_names = [row[1] for row in output_tensors]

    with graph.as_default():
        # The old code opened a tf.Session only to read sess.graph, which is
        # the same object as `graph`, so the graph is passed directly.
        onnx_graph = tf2onnx.tfonnx.process_tf_graph(
            graph,
            opset=opset,
            input_names=input_names,
            output_names=output_names,
        )
        model_proto = onnx_graph.make_model(graph_doc)

    with open(output_path, "wb") as f:
        f.write(model_proto.SerializeToString())
    print(f"ONNX model saved to {output_path}")
                
 
if __name__ == '__main__':
    import argparse  # only the CLI entry point needs argument parsing

    parser = argparse.ArgumentParser(
        description="Convert a frozen TensorFlow model.pb to ONNX."
    )
    parser.add_argument(
        "--model", default="model_1",
        help="model sub-directory under --model-dir (default: %(default)s)",
    )
    parser.add_argument(
        "--model-dir", default="./models",
        help="root directory containing model folders (default: %(default)s)",
    )
    parser.add_argument(
        "--onnx-subdir", default="onnx-1",
        help="output sub-directory inside the model folder (default: %(default)s)",
    )
    args = parser.parse_args()

    # Join path components individually instead of embedding "/" in one piece.
    model_root = os.path.join(args.model_dir, args.model)
    input_tensors_path = os.path.join(model_root, "input_tensors.csv")
    output_tensors_path = os.path.join(model_root, "output_tensors.csv")
    model_path = os.path.join(model_root, "model.pb")

    onnx_model_dir = os.path.join(model_root, args.onnx_subdir)
    onnx_model_path = os.path.join(onnx_model_dir, "model.onnx")

    create_directory(onnx_model_dir)

    input_tensors = read_csv_data(input_tensors_path)
    output_tensors = read_csv_data(output_tensors_path)
    graph = load_graph(model_path)
    convert_graph_to_onnx(graph, input_tensors, output_tensors, onnx_model_path)
    