convert.py 900 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torchvision
from torchvision import models



model = models.segmentation.deeplabv3_resnet101(pretrained=True)
model.eval()  # 必须切换到推理模式,关闭 dropout/batchnorm 等训练特有的层

# 2. 定义输入张量(需与模型期望的输入尺寸匹配,DeepLabv3通常为513x513)
input_tensor = torch.randn(1, 3, 513, 513)  # N=1, C=3, H=513, W=513(NCHW格式)

# 3. 导出ONNX模型
onnx_file = "../Resource/Models/deeplabv3_resnet101.onnx"
torch.onnx.export(
    model,                  # 待导出的模型
    input_tensor,           # 示例输入(用于确定计算图结构)
    onnx_file,              # 输出文件路径
    opset_version=12,       # ONNX算子集版本(建议≥11,支持更多算子)
    input_names=["images"], # 输入节点名称(需与后续推理时一致)
    output_names=["output"]# 输出节点名称
)