# Pytorch模型转换onnx模型 ## 环境准备 确保环境中有pytorch,可以用pip安装pytorch,具体可以参考PyTorch官网https://pytorch.org/get-started/locally/ ```bash pip install torch ``` ## 模型确认 首先,请辨别手里的Pytorch模型是权重数据文件.pth还是包含模型结构以及权重数据的.pt文件,通常惯例是: - **`.pth` 文件** 通常用于保存模型的权重(`state_dict`)。 - **`.pt` 文件** 通常用于保存整个模型(包括模型结构和权重)。 然而,这只是惯例,实际使用中两者可以互换,也取决于保存文件的人如何命名的,甚至可以是其他的后缀名。因此,为了准确加载模型,你需要知道该文件具体保存了什么。下面python脚本可以判断模型文件是否是完整模型文件。 ```python import torch # 尝试加载 .pt 文件 try: model_data = torch.load('resnet50.pt') print(type(model_data)) # 打印数据类型 if isinstance(model_data, dict): print(model_data.keys()) # 如果是字典,打印键 print("The .pt file is weights file) else: print("The .pt file contains the complete model.") except Exception as e: print(f"Error loading model: {e}") ``` ## 导出模型 一般情况下,使用保存的权重文件,这里示例模型[下载地址](https://download.pytorch.org/models/resnet50-0676ba61.pth),转换模型用torch.onnx.export函数,具体操作如下。 其中,如果想导出动态模型,可以通过设置dynamic_axes,这里设置dynamic_axes={'input' : {0 : 'batch_size'}, 'output' : {0 : 'batch_size'}},即为将输入和输出tensor的名称为batch_size的索引为0的维度设置为动态模式,导出的模型输入shape将会变为[batch_size, 3, 224, 224]。同理,如果设置为dynamic_axes={'input' : {0 : 'batch_size', 2: 'height', 3: 'width'}, 'output' : {0 : 'batch_size', 2: 'height', 3: 'width'}},导出的模型输入输出shape将为变为[batch_size, 3, height, width]。 ```python import torch import torchvision.models as models # 定义模型结构(以 ResNet-50 为例) model = models.resnet50() # 或者使用自己定义的模型实例 model.load_state_dict(torch.load('resnet50-0676ba61.pth')) model.eval() # 设置为评估模式 # 示例输入 input_tensor = torch.randn(1, 3, 224, 224) # 修改为你的输入张量形状 # Export the model torch.onnx.export(model, # 需要转换的pytorch模型变量 input_tensor, # 模型示例输入 (多输入情况为多变量tuple) "resnet50.onnx", # 导出模型文件名 export_params=True, # 导出权重参数 opset_version=10, # 指定ONNX的opset版本(可选) do_constant_folding=True, # 是否常量折叠(可选) input_names = ['input'], # 模型输入的names output_names = ['output'], # 模型输出的names dynamic_axes={'input' : {0 : 'batch_size'}, # 动态维度指定 'output' : {0 : 'batch_size'}}) ``` 运行脚本即可得到onnx模型,如果想导出自己的onnx模型,自行做相应修改。 ```bash python export_onnx.py ```