convert_onnx_batch_size.py 5.54 KB
Newer Older
sunzhq2's avatar
sunzhq2 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python3
"""
将已导出的ONNX模型转换为支持指定batch_size的模型
"""

import onnx
import onnx.shape_inference
import argparse
import os

def _print_io_shapes(value_infos, shape_label):
    """Print the name and shape of each entry in *value_infos*.

    Args:
        value_infos: Iterable of graph input/output ValueInfoProto messages.
        shape_label: Label printed before each shape line (kept in the
            script's original language because it is user-facing output).
    """
    for i, info in enumerate(value_infos):
        print(f"  Input {i}: {info.name}")
        if info.type.tensor_type.HasField("shape"):
            shape = info.type.tensor_type.shape
            print(f"    {shape_label}: ", end="")
            for dim in shape.dim:
                if dim.HasField("dim_value"):
                    print(f"{dim.dim_value}", end=" ")
                elif dim.HasField("dim_param"):
                    print(f"{dim.dim_param}", end=" ")
            print()


def _set_batch_dim(value_info, target_batch_size):
    """Set the first (batch) dimension of *value_info*.

    Args:
        value_info: A graph input/output ValueInfoProto to modify in place.
        target_batch_size: -1 makes the dimension dynamic (symbolic name
            "batch_size"); any other value fixes it to that integer.
    """
    if not value_info.type.tensor_type.HasField("shape"):
        return
    shape = value_info.type.tensor_type.shape
    if not shape.dim:
        return
    dim = shape.dim[0]
    if target_batch_size == -1:
        # Dynamic batch: replace a concrete (or unset) dim with a symbolic
        # parameter; an already-dynamic dim is left untouched.
        if not dim.HasField("dim_param"):
            dim.dim_param = "batch_size"
    else:
        # Fixed batch: dim_value/dim_param live in a protobuf oneof, so
        # clearing the symbolic name first and assigning the value covers
        # every starting state (value, param, or unset) uniformly.
        dim.ClearField("dim_param")
        dim.dim_value = target_batch_size


def modify_onnx_batch_size(model_path, output_path, target_batch_size=24):
    """Rewrite the batch dimension of an exported ONNX model.

    Args:
        model_path: Path of the input .onnx file.
        output_path: Path where the modified model is written.
        target_batch_size: Target batch size; -1 selects a dynamic
            (symbolic) batch dimension, any other value fixes it.
    """
    model = onnx.load(model_path)

    # Report the shapes before modification.
    print("原始模型输入信息:")
    _print_io_shapes(model.graph.input, "原始形状")

    # In older exports graph.input may also list initializers (weights);
    # their leading dimension is NOT a batch dimension and must be skipped.
    initializer_names = {init.name for init in model.graph.initializer}
    for input_info in model.graph.input:
        if input_info.name not in initializer_names:
            _set_batch_dim(input_info, target_batch_size)

    for output_info in model.graph.output:
        _set_batch_dim(output_info, target_batch_size)

    # Drop stale intermediate shapes so re-inference cannot conflict with
    # the new batch dimension, then re-run shape inference.
    del model.graph.value_info[:]
    model = onnx.shape_inference.infer_shapes(model)

    onnx.save(model, output_path)

    print(f"模型已保存到: {output_path}")
    print(f"目标batch_size: {'动态' if target_batch_size == -1 else target_batch_size}")

    # Reload from disk so the report reflects what was actually saved.
    print("修改后模型输入信息:")
    model_modified = onnx.load(output_path)
    _print_io_shapes(model_modified.graph.input, "修改后形状")

def batch_convert_models(input_dir, output_dir, target_batch_size=24):
    """Convert every ONNX model in *input_dir*, writing results to *output_dir*.

    Args:
        input_dir: Directory scanned (non-recursively) for .onnx files.
        output_dir: Destination directory; created if it does not exist.
        target_batch_size: Forwarded to modify_onnx_batch_size
            (-1 means dynamic batch).
    """
    # exist_ok avoids the race between an existence check and makedirs.
    os.makedirs(output_dir, exist_ok=True)

    # Sort for a deterministic processing (and report) order.
    onnx_files = sorted(f for f in os.listdir(input_dir) if f.endswith('.onnx'))

    print(f"找到 {len(onnx_files)} 个ONNX文件:")
    for file in onnx_files:
        print(f"  - {file}")

    for file in onnx_files:
        input_path = os.path.join(input_dir, file)
        output_path = os.path.join(output_dir, file)

        print(f"\n正在转换: {file}")
        try:
            modify_onnx_batch_size(input_path, output_path, target_batch_size)
            print(f"✓ {file} 转换成功")
        except Exception as e:
            # Deliberate best-effort: one failing model must not abort the batch.
            print(f"✗ {file} 转换失败: {e}")

if __name__ == "__main__":
    # CLI: convert one file, or with --batch_mode, every .onnx in a directory.
    parser = argparse.ArgumentParser(description='修改ONNX模型的batch_size')
    parser.add_argument('--input', type=str, required=True, help='输入ONNX文件或目录路径')
    parser.add_argument('--output', type=str, required=True, help='输出路径')
    parser.add_argument('--batch_size', type=int, default=24, help='目标batch_size(-1表示动态batch)')
    parser.add_argument('--batch_mode', action='store_true', help='批量模式,处理目录中的所有ONNX文件')
    args = parser.parse_args()

    # Both entry points share the same (input, output, batch_size) signature,
    # so dispatch by picking the callable instead of branching the call.
    convert = batch_convert_models if args.batch_mode else modify_onnx_batch_size
    convert(args.input, args.output, args.batch_size)