# TF to ONNX

## Environment setup

- Install TensorFlow

```bash
pip install tensorflow
```

- Install tf2onnx (version >= 1.5.5)

```bash
pip install tf2onnx
```
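The version requirement can also be checked from Python. This is a minimal sketch that assumes Python 3.8+ (for `importlib.metadata`); `parse_release` and `version_at_least` are hypothetical helpers that compare only the numeric release part of a version, not full PEP 440 versions:

```python
import re
from importlib.metadata import PackageNotFoundError, version

def parse_release(ver):
    # Keep only the leading numeric part: "1.16.1" -> (1, 16, 1), "2.15.0rc1" -> (2, 15, 0)
    m = re.match(r"\d+(?:\.\d+)*", ver)
    return tuple(int(p) for p in m.group(0).split(".")) if m else ()

def version_at_least(installed, required):
    # Compare numerically; a plain string comparison would rank "1.16.1" below "1.5.5"
    a, b = parse_release(installed), parse_release(required)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))

try:
    installed = version("tf2onnx")
except PackageNotFoundError:
    print("tf2onnx is not installed")
else:
    status = "OK" if version_at_least(installed, "1.5.5") else "too old, need >= 1.5.5"
    print(f"tf2onnx {installed}: {status}")
```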

## Identify the model format

First determine the format of your model files. In general:

1. **SavedModel file structure**

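   - `saved_model.pb` or `saved_model.pbtxt`: the core SavedModel file, containing the model graph and metadata.
   - `variables/`: this directory holds `variables.index` and one or more `variables.data-?????-of-?????` files, which store the model's variables.
   - `assets/` (optional): this directory stores any additional resource files.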

   If your `.pb` file sits in a directory with the structure above, it is most likely a SavedModel.

2. **Checkpoint file structure**

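   - A checkpoint usually consists of an `.index` file, one or more `.data-?????-of-?????` files, and a `checkpoint` file. The `.index` and `.data` files store the variable values, and the `checkpoint` file records the most recent checkpoint prefix. TF1-style checkpoints also include a `.meta` file that holds the graph.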

   If your model directory contains the files above (and no `saved_model.pb`), the model is most likely a checkpoint.

3. **GraphDef file structure**

   - If you have only a single `.pb` file with no related files or directory structure, it is most likely a GraphDef (usually a frozen graph).
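The layout rules above can be turned into a quick check that only looks at file names and does not load TensorFlow. This is a heuristic sketch; `guess_format_from_layout` is a hypothetical helper, and a positive result only means the files look right, not that their contents are valid.

```python
import os

def guess_format_from_layout(path):
    """Guess the TF model format from the files on disk (heuristic only)."""
    if os.path.isfile(path):
        # A lone .pb file is most likely a (frozen) GraphDef
        return "GraphDef" if path.endswith(".pb") else "unknown"
    if not os.path.isdir(path):
        return "unknown"
    names = set(os.listdir(path))
    # SavedModel: saved_model.pb or saved_model.pbtxt, usually next to variables/
    if {"saved_model.pb", "saved_model.pbtxt"} & names:
        return "SavedModel"
    # Checkpoint: <prefix>.index plus <prefix>.data-?????-of-????? shard files
    has_index = any(n.endswith(".index") for n in names)
    has_data = any(".data-" in n for n in names)
    if has_index and has_data:
        return "Checkpoint"
    # A directory holding nothing but one .pb file is treated like a lone GraphDef
    if len(names) == 1 and next(iter(names)).endswith(".pb"):
        return "GraphDef"
    return "unknown"
```

For example, `guess_format_from_layout("/path/to/model")` returns one of `"SavedModel"`, `"Checkpoint"`, `"GraphDef"` or `"unknown"`.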

You can use the following code to check the model format by loading it with TensorFlow:

```python
import os

import tensorflow as tf

def is_saved_model(model_dir):
    # A SavedModel is a directory that tf.saved_model.load can open
    if not os.path.isdir(model_dir):
        return False
    try:
        tf.saved_model.load(model_dir)
        return True
    except Exception:
        return False

def is_graph_def(pb_file):
    # A GraphDef is a single serialized protobuf file
    if not os.path.isfile(pb_file):
        return False
    try:
        with tf.io.gfile.GFile(pb_file, "rb") as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        # Unrelated bytes may still parse into an empty message, so require at least one node
        return len(graph_def.node) > 0
    except Exception:
        return False

def is_checkpoint(path):
    # Accepts a checkpoint directory or a checkpoint prefix (e.g. "model.ckpt-100");
    # tf.train.list_variables raises if no checkpoint is found there
    try:
        return len(tf.train.list_variables(path)) > 0
    except Exception:
        return False

model_path = "/path/to/model"

if is_saved_model(model_path):
    print(f"{model_path} contains a SavedModel.")
elif is_graph_def(model_path):
    print(f"{model_path} contains a GraphDef.")
elif is_checkpoint(model_path):
    print(f"{model_path} contains a Checkpoint.")
else:
    print(f"{model_path} format is unknown.")
```

## Check the names and shapes of the model inputs and outputs

Use the following code to inspect a GraphDef-format model:

```python
import tensorflow as tf

def load_graph_def(pb_file):
    # Read a serialized (frozen) GraphDef from disk
    with tf.io.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    return graph_def

def get_graph_inputs_outputs(graph_def):
    # Import into a fresh graph so TensorFlow can infer each tensor's shape
    graph = tf.Graph()
    with graph.as_default():
        tf.compat.v1.import_graph_def(graph_def, name="")

    # Collect every node that another node consumes. Input references may carry an
    # output index ("node:1") or a control-dependency prefix ("^node"), so strip both.
    consumed = set()
    for node in graph_def.node:
        for ref in node.input:
            consumed.add(ref.lstrip("^").split(":")[0])

    inputs, outputs = [], []
    for op in graph.get_operations():
        if op.type == "Placeholder":
            tensor = op.outputs[0]
            inputs.append({"name": tensor.name, "shape": tensor.shape})
        # Heuristic, not a strict rule: nodes that nothing consumes are treated as outputs
        # (unused constants or variables may also match)
        elif op.name not in consumed:
            for tensor in op.outputs:
                outputs.append({"name": tensor.name, "shape": tensor.shape})
    return inputs, outputs

def print_graph_info(inputs, outputs):
    print("Inputs:")
    for input_info in inputs:
        print(f"Name: {input_info['name']}, Shape: {input_info['shape']}")
    print("\nOutputs:")
    for output_info in outputs:
        print(f"Name: {output_info['name']}, Shape: {output_info['shape']}")

# Path to your .pb file
pb_file_path = "resnet50v15_tf.pb"

graph_def = load_graph_def(pb_file_path)
inputs, outputs = get_graph_inputs_outputs(graph_def)
# Tensor names are printed as "node:index", the form tf2onnx expects for --inputs/--outputs
print_graph_info(inputs, outputs)
```
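The output-detection heuristic (a node that no other node consumes is treated as an output) can be shown on a hand-written toy graph without TensorFlow. The graph and node names below are made up for illustration; each entry maps a node name to its GraphDef-style input references.

```python
def node_name(tensor_ref):
    """Map a GraphDef input reference to the name of the node that produces it.

    "conv1:0" -> "conv1", "^init" (control dependency) -> "init", "relu" -> "relu"
    """
    return tensor_ref.lstrip("^").split(":")[0]

def find_output_nodes(graph):
    """graph: dict mapping node name -> list of input references."""
    consumed = {node_name(ref) for refs in graph.values() for ref in refs}
    return [name for name in graph if name not in consumed]

# Hypothetical toy graph: input -> conv -> split (two outputs) -> softmax / argmax
toy_graph = {
    "input": [],
    "conv": ["input"],
    "split": ["conv:0"],
    "softmax": ["split:0"],
    "argmax": ["split:1", "^conv"],
}
print(find_output_nodes(toy_graph))  # ['softmax', 'argmax']
```

With exact string matching instead of `node_name`, `conv` and `split` would also be reported as outputs, because they are only referenced as `conv:0`, `^conv`, `split:0` and `split:1`.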

## Model conversion

Convert the model with the tf2onnx tool. The project is hosted at https://github.com/onnx/tensorflow-onnx, and its README.md explains each command-line option in detail. The tool's help output lists the following options:

```
options:
  -h, --help            show this help message and exit
  --input INPUT         input from graphdef
  --graphdef GRAPHDEF   input from graphdef
  --saved-model SAVED_MODEL
                        input from saved model
  --tag TAG             tag to use for saved_model
  --signature_def SIGNATURE_DEF
                        signature_def from saved_model to use
  --concrete_function CONCRETE_FUNCTION
                        For TF2.x saved_model, index of func signature in __call__ (--signature_def is ignored)
  --checkpoint CHECKPOINT
                        input from checkpoint
  --keras KERAS         input from keras model
  --tflite TFLITE       input from tflite model
  --tfjs TFJS           input from tfjs model
  --large_model         use the large model format (for models > 2GB)
  --output OUTPUT       output model file
  --inputs INPUTS       model input_names (optional for saved_model, keras, and tflite)
  --outputs OUTPUTS     model output_names (optional for saved_model, keras, and tflite)
  --ignore_default IGNORE_DEFAULT
                        comma-separated list of names of PlaceholderWithDefault ops to change into Placeholder ops
  --use_default USE_DEFAULT
                        comma-separated list of names of PlaceholderWithDefault ops to change into Identity ops using
                        their default value
  --rename-inputs RENAME_INPUTS
                        input names to use in final model (optional)
  --rename-outputs RENAME_OUTPUTS
                        output names to use in final model (optional)
  --use-graph-names     (saved model only) skip renaming io using signature names
  --opset OPSET         opset version to use for onnx domain
  --dequantize          remove quantization from model. Only supported for tflite currently.
  --custom-ops CUSTOM_OPS
                        comma-separated map of custom ops to domains in format OpName:domain. Domain
                        'ai.onnx.converters.tensorflow' is used by default.
  --extra_opset EXTRA_OPSET
                        extra opset with format like domain:version, e.g. com.microsoft:1
  --load_op_libraries LOAD_OP_LIBRARIES
                        comma-separated list of tf op library paths to register before loading model
  --target {rs4,rs5,rs6,caffe2,tensorrt,nhwc}
                        target platform
  --continue_on_error   continue_on_error
  --verbose, -v         verbose output, option is additive
  --debug               debug mode
  --output_frozen_graph OUTPUT_FROZEN_GRAPH
                        output frozen tf graph to file
  --inputs-as-nchw INPUTS_AS_NCHW
                        transpose inputs as from nhwc to nchw
  --outputs-as-nchw OUTPUTS_AS_NCHW
                        transpose outputs as from nhwc to nchw

Usage Examples:

python -m tf2onnx.convert --saved-model saved_model_dir --output model.onnx
python -m tf2onnx.convert --input frozen_graph.pb  --inputs X:0 --outputs output:0 --output model.onnx
python -m tf2onnx.convert --checkpoint checkpoint.meta  --inputs X:0 --outputs output:0 --output model.onnx
```
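If the conversion is driven from a Python script, the command line can be assembled as a list and passed to `subprocess`. This is a sketch: `build_tf2onnx_command` is a hypothetical helper (not part of tf2onnx) that uses only the `--graphdef`, `--output`, `--inputs`, `--outputs`, `--inputs-as-nchw` and `--opset` flags from the help text above, and the example arguments mirror the ResNet50 GraphDef command in this README.

```python
import shlex
import sys

def build_tf2onnx_command(graphdef, output, inputs, outputs, inputs_as_nchw=None, opset=None):
    """Assemble a `python -m tf2onnx.convert` argument list for a GraphDef model.

    Tensor names use the "node:index" form, joined with commas as the CLI expects.
    """
    cmd = [sys.executable, "-m", "tf2onnx.convert",
           "--graphdef", graphdef,
           "--output", output,
           "--inputs", ",".join(inputs),
           "--outputs", ",".join(outputs)]
    if inputs_as_nchw:
        cmd += ["--inputs-as-nchw", ",".join(inputs_as_nchw)]
    if opset is not None:
        cmd += ["--opset", str(opset)]
    return cmd

cmd = build_tf2onnx_command(
    "resnet50v15_tf.pb", "model_nchw.onnx",
    inputs=["input_tensor:0"],
    outputs=["global_step:0", "ArgMax:0", "softmax_tensor:0"],
    inputs_as_nchw=["input_tensor:0"],
)
print(shlex.join(cmd))
# To actually run the conversion (requires tf2onnx and the .pb file):
# import subprocess; subprocess.run(cmd, check=True)
```

Passing a list rather than a single shell string avoids quoting problems when paths contain spaces.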

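The following example converts the GraphDef model resnet50v15_tf.pb to ONNX and uses `--inputs-as-nchw` to transpose the input from NHWC to NCHW. The model can be downloaded [here](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/003_Atc_Models/modelzoo/Official/cv/Resnet50v1.5_for_ACL/resnet50v15_tf.pb).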

```bash
python -m tf2onnx.convert --graphdef  resnet50v15_tf.pb --output model_nchw.onnx --inputs input_tensor:0 --outputs global_step:0,ArgMax:0,softmax_tensor:0 --inputs-as-nchw input_tensor:0
```
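With `--inputs-as-nchw`, the converted model expects NCHW input instead of TensorFlow's default NHWC layout, so data prepared in NHWC order must be transposed with axes `(0, 3, 1, 2)` before inference. A NumPy sketch; the 1x224x224x3 input size is an assumption for illustration and is not read from the model:

```python
import numpy as np

# NHWC -> NCHW axis order: batch, channels, height, width
NHWC_TO_NCHW = (0, 3, 1, 2)

def to_nchw(batch_nhwc):
    """Transpose an NHWC batch to NCHW, e.g. for a model converted with --inputs-as-nchw."""
    # ascontiguousarray gives a C-contiguous copy, which inference runtimes usually prefer
    return np.ascontiguousarray(np.transpose(batch_nhwc, NHWC_TO_NCHW))

# Assumed input size for illustration: one 224x224 RGB image
x_nhwc = np.random.rand(1, 224, 224, 3).astype(np.float32)
x_nchw = to_nchw(x_nhwc)
print(x_nhwc.shape, "->", x_nchw.shape)  # (1, 224, 224, 3) -> (1, 3, 224, 224)
```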

# TFLite to ONNX

TFLite models are also converted with tf2onnx. For example, to convert the ResNet50.tflite model ([download link](https://hf-mirror.com/qualcomm/ResNet50/resolve/main/ResNet50.tflite?download=true)) to ONNX:

```bash
python -m tf2onnx.convert --opset 16 --tflite ResNet50.tflite --output model.onnx
```
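Before converting, it can help to confirm that the downloaded file really is a TFLite model, since a failed download may save an HTML page under the `.tflite` name. TFLite models are FlatBuffers whose file identifier is `TFL3`, stored at bytes 4 to 8 of the file. A minimal sketch; `is_tflite_flatbuffer` is a hypothetical helper:

```python
def is_tflite_flatbuffer(path):
    """Return True if the file carries the TFLite FlatBuffer identifier "TFL3".

    A FlatBuffer starts with a 4-byte root offset, followed by the 4-byte file identifier.
    """
    with open(path, "rb") as f:
        header = f.read(8)
    return len(header) == 8 and header[4:8] == b"TFL3"
```

For a valid download, `is_tflite_flatbuffer("ResNet50.tflite")` should return `True`.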