export.py 1.41 KB
Newer Older
change's avatar
change committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from ultralytics import YOLO
import torch
import onnx

# 加载预训练模型
model = YOLO('./model/yolov8n.pt')

# 设置导出参数
batch_size = 4  # 批量大小
input_shape = (batch_size, 3, 640, 640)  # 输入图像尺寸 (C, H, W),这里以 640x640 为例
opset_version = 16  # ONNX 操作集版本
output_file = './model/yolov8n.pt'  # 输出 ONNX 文件路径
# dynamic_axes = {'images': {0: 'batch'}, 'output': {0: 'batch'}}  # 动态批次大小

# 创建随机输入张量
dummy_input = torch.randn(input_shape)

# 导出模型为 ONNX 格式
torch.onnx.export(
    model.model,  # 要导出的模型
    dummy_input,  # 模型的输入
    output_file,  # 输出文件路径
    verbose=False,
    export_params=True,  # 存储已训练参数的值
    opset_version=11,  # ONNX版本,YOLO通常使用11或更高
    do_constant_folding=True,  # 是否执行常量折叠优化
    input_names=['input'],  # 输入节点名称
    output_names=['output'],  # 输出节点名称
    # dynamic_axes={'input': {0: 'batch_size'},  # 批次大小可变
    #               'output': {0: 'batch_size'}}
)

model = onnx.load(output_file)
graph = model.graph
output_to_keep = 'output'
outputs_to_remove = [output for output in graph.output if output.name != output_to_keep]
for output in outputs_to_remove:
    graph.output.remove(output)
onnx.save(model, output_file)
print(f"Model has been successfully exported to {output_file}")