# Pytorch模型转换onnx模型

## 环境准备

确保环境中有pytorch，可以用pip安装pytorch，具体可以参考PyTorch官网https://pytorch.org/get-started/locally/

```bash
pip install torch
```

## 模型确认

首先，请辨别手里的Pytorch模型是权重数据文件.pth还是包含模型结构以及权重数据的.pt文件，通常惯例是：

- **`.pth` 文件** 通常用于保存模型的权重（`state_dict`）。
- **`.pt` 文件** 通常用于保存整个模型（包括模型结构和权重）。

然而，这只是惯例，实际使用中两者可以互换，也取决于保存文件的人如何命名的，甚至可以是其他的后缀名。因此，为了准确加载模型，你需要知道该文件具体保存了什么。下面python脚本可以判断模型文件是否是完整模型文件。

```python
import torch

# 尝试加载 .pt 文件
try:
    model_data = torch.load('resnet50.pt')
    print(type(model_data))  # 打印数据类型
    if isinstance(model_data, dict):
        print(model_data.keys())  # 如果是字典，打印键
        print("The .pt file is weights file)
    else:
        print("The .pt file contains the complete model.")
except Exception as e:
    print(f"Error loading model: {e}")

```

## 导出模型

一般情况下，使用保存的权重文件，这里示例模型[下载地址](https://download.pytorch.org/models/resnet50-0676ba61.pth)，转换模型用torch.onnx.export函数，具体操作如下。

其中，如果想导出动态模型，可以通过设置dynamic_axes，这里设置dynamic_axes={'input' : {0 : 'batch_size'}, 'output' : {0 : 'batch_size'}}，即为将输入和输出tensor的名称为batch_size的索引为0的维度设置为动态模式，导出的模型输入shape将会变为[batch_size, 3, 224, 224]。同理，如果设置为dynamic_axes={'input' : {0 : 'batch_size', 2: 'height', 3: 'width'}, 'output' : {0 : 'batch_size', 2: 'height', 3: 'width'}}，导出的模型输入输出shape将为变为[batch_size, 3, height, width]。

```python
import torch
import torchvision.models as models

# 定义模型结构（以 ResNet-50 为例）
model = models.resnet50() # 或者使用自己定义的模型实例
model.load_state_dict(torch.load('resnet50-0676ba61.pth'))
model.eval()  # 设置为评估模式

# 示例输入
input_tensor = torch.randn(1, 3, 224, 224)  # 修改为你的输入张量形状

# Export the model
torch.onnx.export(model,                     # 需要转换的pytorch模型变量
                  input_tensor,              # 模型示例输入 (多输入情况为多变量tuple)
                  "resnet50.onnx",           # 导出模型文件名
                  export_params=True,        # 导出权重参数
                  opset_version=10,          # 指定ONNX的opset版本（可选）
                  do_constant_folding=True,  # 是否常量折叠（可选）
                  input_names = ['input'],   # 模型输入的names
                  output_names = ['output'], # 模型输出的names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # 动态维度指定
                                'output' : {0 : 'batch_size'}})
```

运行脚本即可得到onnx模型，如果想导出自己的onnx模型，自行做相应修改。

```bash
python export_onnx.py
```

