check_model_format.py 1.01 KB
Newer Older
yaoht's avatar
yaoht committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2

def is_saved_model(model_dir):
    """Return True if *model_dir* is a SavedModel that TensorFlow can load.

    ``tf.saved_model.contains_saved_model`` is used as a cheap pre-check for
    the presence of the SavedModel proto file, so paths that cannot possibly
    be a SavedModel skip the potentially expensive ``tf.saved_model.load``.
    When the pre-check passes, a full load is still attempted, so a directory
    whose SavedModel fails to load reports False.

    Any exception raised while probing is treated as "not a SavedModel".
    """
    try:
        if not tf.saved_model.contains_saved_model(model_dir):
            return False
        # Only success/failure of the load matters; the loaded object is discarded.
        tf.saved_model.load(model_dir)
    except Exception:
        return False
    return True

def is_graph_def(pb_file):
    """Return True if *pb_file* parses as a non-empty binary GraphDef.

    Protobuf parsing is permissive. An empty file, for example, decodes
    successfully into a GraphDef with no nodes, so a clean parse alone is not
    evidence of a serialized graph. Requiring at least one node rejects that
    false positive.

    Any error while reading or parsing (missing file, directory path, decode
    error, ...) is treated as "not a GraphDef".
    """
    try:
        with tf.io.gfile.GFile(pb_file, "rb") as f:
            serialized = f.read()
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(serialized)
    except Exception:
        return False
    # A real graph has at least one NodeDef; an empty message does not.
    return len(graph_def.node) > 0

def is_checkpoint(model_dir):
    """Return True if *model_dir* resolves to a readable TensorFlow checkpoint.

    ``tf.train.load_checkpoint`` accepts either a checkpoint prefix
    (e.g. ``"ckpt/model-1000"``) or a directory containing checkpoints, in
    which case the latest one is used. It only opens a checkpoint reader, so
    no model objects are built while probing.

    Any error (no checkpoint found, unreadable or non-checkpoint file, ...)
    is treated as "not a checkpoint".
    """
    try:
        tf.train.load_checkpoint(model_dir)
    except Exception:
        return False
    return True

model_path = "predict_net.pb"


def _describe_model_format(path):
    """Return a one-line report naming the first format probe that matches *path*.

    Probes run in priority order (SavedModel, GraphDef, Checkpoint), and the
    first match wins; later probes are not evaluated.
    """
    if is_saved_model(path):
        return f"{path} contains a SavedModel."
    if is_graph_def(path):
        return f"{path} contains a GraphDef."
    if is_checkpoint(path):
        return f"{path} contains a Checkpoint."
    return f"{path} format is unknown."


# Runs at import time, reporting on the hard-coded model_path above.
print(_describe_model_format(model_path))