#!/usr/bin/env python3
"""Convert an already-exported ONNX model into one supporting a given batch size.

The first (batch) dimension of every graph input and output is rewritten
either to a fixed value or to a symbolic "batch_size" dimension (-1).
"""
import onnx
import onnx.shape_inference
import argparse
import os


def _set_batch_dim(dim, target_batch_size):
    """Rewrite one leading (batch) dimension of a tensor shape in place.

    Args:
        dim: TensorShapeProto.Dimension protobuf message to modify.
        target_batch_size: -1 to make the dimension symbolic ("batch_size");
            any other value pins the dimension to that fixed size.
    """
    if target_batch_size == -1:
        # Dynamic mode: replace a fixed value (or an unset dim) with a named
        # symbolic dimension. An already-symbolic dim keeps its existing name.
        if not dim.HasField("dim_param"):
            dim.ClearField("dim_value")
            dim.dim_param = "batch_size"
    else:
        # Fixed mode: dim_value and dim_param live in the same protobuf
        # oneof, so assigning dim_value implicitly clears dim_param; the
        # explicit ClearField is kept for clarity.
        dim.ClearField("dim_param")
        dim.dim_value = target_batch_size


def _print_shapes(value_infos, label):
    """Print the name and shape of each ValueInfoProto (diagnostic output).

    Args:
        value_infos: iterable of ValueInfoProto (e.g. ``graph.input``).
        label: text printed in front of each shape line.
    """
    for i, info in enumerate(value_infos):
        print(f" Input {i}: {info.name}")
        if info.type.tensor_type.HasField("shape"):
            print(f" {label}: ", end="")
            for dim in info.type.tensor_type.shape.dim:
                if dim.HasField("dim_value"):
                    print(f"{dim.dim_value}", end=" ")
                elif dim.HasField("dim_param"):
                    print(f"{dim.dim_param}", end=" ")
            print()


def modify_onnx_batch_size(model_path, output_path, target_batch_size=24):
    """Rewrite the batch (first) dimension of an ONNX model's inputs/outputs.

    Args:
        model_path: path of the input .onnx file.
        output_path: path where the modified model is written.
        target_batch_size: -1 for a dynamic batch dimension; any other value
            fixes the batch dimension to that size.
    """
    model = onnx.load(model_path)

    # Show the original input shapes for reference.
    print(f"原始模型输入信息:")
    _print_shapes(model.graph.input, "原始形状")

    # BUGFIX: older exporters also list weight initializers in graph.input,
    # and their leading dimension is NOT a batch dimension. Collect their
    # names so those entries are left untouched below.
    initializer_names = {init.name for init in model.graph.initializer}

    # Rewrite the first dimension of every true data input.
    for input_info in model.graph.input:
        if input_info.name in initializer_names:
            continue
        tensor_type = input_info.type.tensor_type
        if tensor_type.HasField("shape") and len(tensor_type.shape.dim) > 0:
            _set_batch_dim(tensor_type.shape.dim[0], target_batch_size)

    # Rewrite the first dimension of every output (same rules as inputs).
    for output_info in model.graph.output:
        tensor_type = output_info.type.tensor_type
        if tensor_type.HasField("shape") and len(tensor_type.shape.dim) > 0:
            _set_batch_dim(tensor_type.shape.dim[0], target_batch_size)

    # BUGFIX: drop stale intermediate shape annotations — they still carry
    # the old batch size and would conflict with (or survive) re-inference.
    del model.graph.value_info[:]

    # Re-run shape inference so intermediate tensors pick up the new batch.
    model = onnx.shape_inference.infer_shapes(model)

    # Save the modified model and report the result.
    onnx.save(model, output_path)
    print(f"模型已保存到: {output_path}")
    print(f"目标batch_size: {'动态' if target_batch_size == -1 else target_batch_size}")

    # Reload the saved file and print its input shapes to verify the change.
    print(f"修改后模型输入信息:")
    model_modified = onnx.load(output_path)
    _print_shapes(model_modified.graph.input, "修改后形状")


def batch_convert_models(input_dir, output_dir, target_batch_size=24):
    """Convert every .onnx file in *input_dir*, writing results to *output_dir*.

    Args:
        input_dir: directory scanned (non-recursively) for .onnx files.
        output_dir: destination directory, created if missing.
        target_batch_size: forwarded to modify_onnx_batch_size.
    """
    os.makedirs(output_dir, exist_ok=True)

    onnx_files = [f for f in os.listdir(input_dir) if f.endswith('.onnx')]
    print(f"找到 {len(onnx_files)} 个ONNX文件:")
    for file in onnx_files:
        print(f" - {file}")

    for file in onnx_files:
        input_path = os.path.join(input_dir, file)
        output_path = os.path.join(output_dir, file)
        print(f"\n正在转换: {file}")
        try:
            modify_onnx_batch_size(input_path, output_path, target_batch_size)
            print(f"✓ {file} 转换成功")
        except Exception as e:
            # Best-effort batch run: report the failure and continue with the
            # remaining files instead of aborting the whole batch.
            print(f"✗ {file} 转换失败: {e}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='修改ONNX模型的batch_size')
    parser.add_argument('--input', type=str, required=True, help='输入ONNX文件或目录路径')
    parser.add_argument('--output', type=str, required=True, help='输出路径')
    parser.add_argument('--batch_size', type=int, default=24, help='目标batch_size(-1表示动态batch)')
    parser.add_argument('--batch_mode', action='store_true', help='批量模式,处理目录中的所有ONNX文件')

    args = parser.parse_args()

    if args.batch_mode:
        # Batch mode: process every .onnx file in the input directory.
        batch_convert_models(args.input, args.output, args.batch_size)
    else:
        # Single-file mode.
        modify_onnx_batch_size(args.input, args.output, args.batch_size)