"""Report whether a path holds a TensorFlow SavedModel, GraphDef, or checkpoint.

Usage: python <this script> [MODEL_PATH]   (MODEL_PATH defaults to "predict_net.pb")
"""
import sys

import tensorflow as tf
from google.protobuf import message
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2

DEFAULT_MODEL_PATH = "predict_net.pb"


def is_saved_model(model_dir):
    """Return True if *model_dir* looks like a SavedModel export directory.

    Uses ``tf.saved_model.contains_saved_model``, which checks for the
    SavedModel marker file instead of loading the full model (graph plus
    variables) just to answer a yes/no question.
    """
    try:
        return bool(tf.saved_model.contains_saved_model(model_dir))
    except tf.errors.OpError:
        # Filesystem errors (e.g. permissions on remote storage): treat as "no".
        return False


def _binary_graph_def_has_nodes(data):
    """Return True if *data* parses as a binary GraphDef containing >= 1 node."""
    graph_def = graph_pb2.GraphDef()
    try:
        graph_def.ParseFromString(data)
    except message.DecodeError:
        return False
    # Protobuf parsing is lenient: empty input (and some arbitrary bytes)
    # parse "successfully" into an empty message, so require real content.
    return len(graph_def.node) > 0


def _text_graph_def_has_nodes(data):
    """Return True if *data* parses as a text-format GraphDef containing >= 1 node."""
    graph_def = graph_pb2.GraphDef()
    try:
        text_format.Parse(data.decode("utf-8"), graph_def)
    except (UnicodeDecodeError, text_format.ParseError):
        return False
    return len(graph_def.node) > 0


def is_graph_def(pb_file):
    """Return True if *pb_file* is a non-empty GraphDef (binary or text format).

    Returns False when the file cannot be read (missing, a directory, ...)
    or when neither encoding yields a GraphDef with at least one node.
    """
    try:
        with tf.io.gfile.GFile(pb_file, "rb") as f:
            data = f.read()
    except (OSError, tf.errors.OpError):
        return False
    return _binary_graph_def_has_nodes(data) or _text_graph_def_has_nodes(data)


def is_checkpoint(model_dir):
    """Return True if *model_dir* is a checkpoint prefix or a directory holding one.

    ``tf.train.load_checkpoint`` resolves a directory to its latest checkpoint
    and opens a reader. It raises ``ValueError`` when no checkpoint is found
    and a ``tf.errors.OpError`` subclass when the path is not a readable
    checkpoint.
    """
    try:
        tf.train.load_checkpoint(model_dir)
    except (ValueError, tf.errors.OpError):
        return False
    return True


def detect_model_format(model_path):
    """Return "SavedModel", "GraphDef", "Checkpoint", or None for *model_path*.

    Checks are tried in that order; the first match wins.
    """
    if is_saved_model(model_path):
        return "SavedModel"
    if is_graph_def(model_path):
        return "GraphDef"
    if is_checkpoint(model_path):
        return "Checkpoint"
    return None


def main(argv=None):
    """Detect the format of the path given in *argv* (or the default) and print it."""
    argv = sys.argv[1:] if argv is None else argv
    model_path = argv[0] if argv else DEFAULT_MODEL_PATH
    model_format = detect_model_format(model_path)
    if model_format is None:
        print(f"{model_path} format is unknown.")
    else:
        print(f"{model_path} contains a {model_format}.")


if __name__ == "__main__":
    main()