export_onnx.py 1.11 KB
Newer Older
yaoht's avatar
yaoht committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torchvision.models as models

# 定义模型结构(以 ResNet-50 为例)
model = models.resnet50() # 或者使用自己定义的模型实例
model.load_state_dict(torch.load('resnet50-0676ba61.pth'))
model.eval()  # 设置为评估模式

# 示例输入
input_tensor = torch.randn(1, 3, 224, 224)  # 修改为你的输入张量形状

# Export the model
torch.onnx.export(model,               # 需要转换的pytorch模型变量
                  input_tensor,                         # 模型示例输入 (多输入情况为多变量tuple)
                  "resnet50.onnx",   # 导出模型文件名
                  export_params=True,        # 导出权重参数
                  opset_version=10,          # 指定ONNX的opset版本(可选)
                  do_constant_folding=True,  # 是否常量折叠(可选)
                  input_names = ['input'],   # 模型输入的names
                  output_names = ['output'], # 模型输出的names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # 动态维度指定
                                'output' : {0 : 'batch_size'}})