# TF to ONNX

## Environment Setup

- Install TensorFlow

  ```bash
  pip install tensorflow
  ```

- Install tf2onnx (version >= 1.5.5)

  ```bash
  pip install tf2onnx
  ```

## Confirming the Model Format

First confirm the format of the model files you have. In general:

1. **SavedModel file structure**:
   - `saved_model.pb` or `saved_model.pbtxt`: the core SavedModel file, containing the model's graph and metadata.
   - `variables/`: this directory contains the files `variables.data-?????-of-?????` and `variables.index`, which store the model's variables.
   - `assets/` (optional): this directory stores any additional resource files.

   If your `.pb` file sits in a directory with this structure, it is most likely a SavedModel.

2. **Checkpoint file structure**:
   - A checkpoint usually consists of three kinds of files: an `.index` file, one or more `.data-?????-of-?????` files, and a `checkpoint` file; together they store the model's variables.

   If your `.pb` file sits in a directory with this structure, it is most likely a Checkpoint.

3. **GraphDef file structure**:
   - If there is only a single `.pb` file, with no other associated files or directory structure, it is most likely a GraphDef.

The following code can be used to check the model file format:

```python
import tensorflow as tf


def is_saved_model(model_dir):
    """Return True if model_dir loads as a SavedModel."""
    try:
        tf.saved_model.load(model_dir)
        return True
    except Exception:
        return False


def is_graph_def(pb_file):
    """Return True if pb_file parses as a serialized GraphDef."""
    try:
        with tf.io.gfile.GFile(pb_file, "rb") as f:
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(f.read())
        return True
    except Exception:
        return False


def is_checkpoint(ckpt_path):
    """Return True if ckpt_path (the checkpoint file prefix) can be restored."""
    try:
        checkpoint = tf.train.Checkpoint()
        checkpoint.restore(ckpt_path).expect_partial()
        return True
    except Exception:
        return False


model_path = "/path/to/model"

if is_saved_model(model_path):
    print(f"{model_path} contains a SavedModel.")
elif is_graph_def(model_path):
    print(f"{model_path} contains a GraphDef.")
elif is_checkpoint(model_path):
    print(f"{model_path} contains a Checkpoint.")
else:
    print(f"{model_path} format is unknown.")
```

## Confirming Input/Output Names and Shapes

Use the code below to inspect a GraphDef-format model:

```python
import tensorflow as tf


def load_graph_def(pb_file):
    """Parse a frozen .pb file into a GraphDef."""
    with tf.io.gfile.GFile(pb_file, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    return graph_def


def get_graph_inputs_outputs(graph_def):
    inputs = []
    outputs = []

    # Collect every node name consumed by another node
    # (strip "^" control-dependency prefixes and ":N" output indexes).
    consumed = set()
    for node in graph_def.node:
        for inp in node.input:
            consumed.add(inp.lstrip('^').split(':')[0])

    for node in graph_def.node:
        if node.op == 'Placeholder':
            shape = None
            if 'shape' in node.attr:
                shape = [dim.size for dim in node.attr['shape'].shape.dim]
            inputs.append({'name': node.name, 'shape': shape})
        # Treat nodes whose output is never consumed as graph outputs.
        # This is a heuristic, not a strict rule.
        elif node.name not in consumed and node.op not in ('Const', 'NoOp'):
            shape = None
            # '_output_shapes' is only present if the graph was exported with shape info.
            if '_output_shapes' in node.attr and node.attr['_output_shapes'].list.shape:
                shape = [dim.size for dim in node.attr['_output_shapes'].list.shape[0].dim]
            outputs.append({'name': node.name, 'shape': shape})

    return inputs, outputs


def print_graph_info(inputs, outputs):
    print("Inputs:")
    for input_info in inputs:
        print(f"Name: {input_info['name']}, Shape: {input_info['shape']}")

    print("\nOutputs:")
    for output_info in outputs:
        print(f"Name: {output_info['name']}, Shape: {output_info['shape']}")


# Path to your .pb file
pb_file_path = "resnet50v15_tf.pb"

# Load the GraphDef
graph_def = load_graph_def(pb_file_path)

# Get inputs and outputs
inputs, outputs = get_graph_inputs_outputs(graph_def)

# Print inputs and outputs
print_graph_info(inputs, outputs)
```
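The script above only handles the GraphDef format. For a SavedModel, a minimal sketch along the following lines prints the same information (it assumes the model exposes the default `serving_default` signature, and the directory path is a placeholder); TensorFlow's bundled `saved_model_cli show --dir <model_dir> --all` command reports this as well:

```python
import tensorflow as tf

saved_model_dir = "/path/to/saved_model"  # placeholder path

# Load the SavedModel and inspect its default serving signature.
model = tf.saved_model.load(saved_model_dir)
sig = model.signatures["serving_default"]  # assumes the default signature name

print("Inputs:")
for name, spec in sig.structured_input_signature[1].items():
    print(f"  {name}: shape={spec.shape}, dtype={spec.dtype.name}")

print("Outputs:")
for name, spec in sig.structured_outputs.items():
    print(f"  {name}: shape={spec.shape}, dtype={spec.dtype.name}")
```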
## Model Conversion

Use the tf2onnx tool to perform the conversion. Detailed documentation is available on the project page: https://github.com/onnx/tensorflow-onnx

It is recommended to read the tf2onnx README.md, which explains each of the tool's parameters in detail. For reference, the command-line options are:

```
options:
  -h, --help            show this help message and exit
  --input INPUT         input from graphdef
  --graphdef GRAPHDEF   input from graphdef
  --saved-model SAVED_MODEL
                        input from saved model
  --tag TAG             tag to use for saved_model
  --signature_def SIGNATURE_DEF
                        signature_def from saved_model to use
  --concrete_function CONCRETE_FUNCTION
                        For TF2.x saved_model, index of func signature in __call__ (--signature_def is ignored)
  --checkpoint CHECKPOINT
                        input from checkpoint
  --keras KERAS         input from keras model
  --tflite TFLITE       input from tflite model
  --tfjs TFJS           input from tfjs model
  --large_model         use the large model format (for models > 2GB)
  --output OUTPUT       output model file
  --inputs INPUTS       model input_names (optional for saved_model, keras, and tflite)
  --outputs OUTPUTS     model output_names (optional for saved_model, keras, and tflite)
  --ignore_default IGNORE_DEFAULT
                        comma-separated list of names of PlaceholderWithDefault ops to change into Placeholder ops
  --use_default USE_DEFAULT
                        comma-separated list of names of PlaceholderWithDefault ops to change into Identity ops using their default value
  --rename-inputs RENAME_INPUTS
                        input names to use in final model (optional)
  --rename-outputs RENAME_OUTPUTS
                        output names to use in final model (optional)
  --use-graph-names     (saved model only) skip renaming io using signature names
  --opset OPSET         opset version to use for onnx domain
  --dequantize          remove quantization from model. Only supported for tflite currently.
  --custom-ops CUSTOM_OPS
                        comma-separated map of custom ops to domains in format OpName:domain. Domain 'ai.onnx.converters.tensorflow' is used by default.
  --extra_opset EXTRA_OPSET
                        extra opset with format like domain:version, e.g. com.microsoft:1
  --load_op_libraries LOAD_OP_LIBRARIES
                        comma-separated list of tf op library paths to register before loading model
  --target {rs4,rs5,rs6,caffe2,tensorrt,nhwc}
                        target platform
  --continue_on_error   continue_on_error
  --verbose, -v         verbose output, option is additive
  --debug               debug mode
  --output_frozen_graph OUTPUT_FROZEN_GRAPH
                        output frozen tf graph to file
  --inputs-as-nchw INPUTS_AS_NCHW
                        transpose inputs as from nhwc to nchw
  --outputs-as-nchw OUTPUTS_AS_NCHW
                        transpose outputs as from nhwc to nchw

Usage Examples:

python -m tf2onnx.convert --saved-model saved_model_dir --output model.onnx
python -m tf2onnx.convert --input frozen_graph.pb --inputs X:0 --outputs output:0 --output model.onnx
python -m tf2onnx.convert --checkpoint checkpoint.meta --inputs X:0 --outputs output:0 --output model.onnx
```

Below is an example of converting a GraphDef-format model to ONNX. The resnet50v15_tf.pb model can be downloaded [here](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/003_Atc_Models/modelzoo/Official/cv/Resnet50v1.5_for_ACL/resnet50v15_tf.pb):

```bash
python -m tf2onnx.convert --graphdef resnet50v15_tf.pb --output model_nchw.onnx --inputs input_tensor:0 --outputs global_step:0,ArgMax:0,softmax_tensor:0 --inputs-as-nchw input_tensor:0
```

# TFLite to ONNX

The same tf2onnx tool is used. For example, to convert the ResNet50.tflite model ([download link](https://hf-mirror.com/qualcomm/ResNet50/resolve/main/ResNet50.tflite?download=true)) to ONNX:

```bash
python -m tf2onnx.convert --opset 16 --tflite ResNet50.tflite --output model.onnx
```
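After either conversion, it is worth confirming that the exported ONNX file loads and runs. The following is a minimal smoke-test sketch using onnx and onnxruntime (the file name `model.onnx` matches the commands above; it assumes float32 inputs and substitutes 1 for any dynamic dimensions):

```python
import numpy as np
import onnx
import onnxruntime as ort

onnx_path = "model.onnx"  # file produced by tf2onnx above

# Structural check of the exported graph.
onnx.checker.check_model(onnx_path)

# Load the model and run one inference on random data to confirm it executes.
session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
feeds = {}
for inp in session.get_inputs():
    # Replace dynamic dimensions (None or symbolic names) with 1 for the smoke test.
    shape = [d if isinstance(d, int) else 1 for d in inp.shape]
    # Assumes float32 inputs; adjust the dtype if your model differs.
    feeds[inp.name] = np.random.rand(*shape).astype(np.float32)

outputs = session.run(None, feeds)
for meta, out in zip(session.get_outputs(), outputs):
    print(f"{meta.name}: shape={out.shape}, dtype={out.dtype}")
```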