"""Convert a dynamic-shape ONNX model into static-batch variants.

For every requested batch size, the dynamic dimensions of the model's graph
inputs are pinned to concrete values, the model is simplified with onnxsim,
and the result is saved next to the source model.
"""

import os
from typing import Mapping, Optional, Union

import onnx
from onnxsim import simplify


def _load_model_copy(source: Union[str, "os.PathLike[str]", onnx.ModelProto]) -> onnx.ModelProto:
    """Return a ModelProto the caller may freely mutate.

    A path is loaded from disk; an already-loaded ModelProto is deep-copied so
    that pinning dimensions never modifies the caller's model.
    """
    if isinstance(source, onnx.ModelProto):
        model = onnx.ModelProto()
        model.CopyFrom(source)
        return model
    return onnx.load(source)


def _pin_dynamic_input_dims(
    model: onnx.ModelProto,
    batch_size: int,
    dim_values: Mapping[str, int],
) -> None:
    """Replace every dynamic dimension of the graph inputs with a fixed value, in place.

    A dimension is treated as dynamic when its ``dim_value`` is 0. This covers
    symbolic dims (``dim_param``) and dims with no value set.

    A symbolic dim whose name appears in ``dim_values`` receives the mapped
    value. Every other dynamic dim receives ``batch_size``.
    """
    for graph_input in model.graph.input:
        tensor_type = graph_input.type.tensor_type
        if not tensor_type.HasField("shape"):
            continue
        for dim in tensor_type.shape.dim:
            if dim.dim_value != 0:
                continue  # already static
            if dim.dim_param and dim.dim_param in dim_values:
                new_value = dim_values[dim.dim_param]
            else:
                new_value = batch_size
            # dim_value and dim_param share a protobuf oneof, so setting the
            # value also clears the symbolic name.
            dim.dim_value = new_value


def convert_onnx_dynamic_to_static(
    onnx_model_path: Union[str, "os.PathLike[str]", onnx.ModelProto],
    output_path: Union[str, "os.PathLike[str]"],
    batch_size: int,
    *,
    dim_values: Optional[Mapping[str, int]] = None,
) -> None:
    """Pin the dynamic input dims of an ONNX model, simplify it, and save it.

    Args:
        onnx_model_path: Path of the source model. An already-loaded
            ``onnx.ModelProto`` is also accepted; it is copied, not mutated.
        output_path: Destination file for the simplified static model.
        batch_size: Value assigned to every dynamic input dimension that is
            not named in ``dim_values``.
        dim_values: Optional mapping from symbolic dimension name
            (``dim_param``) to a fixed size. Use it for non-batch dynamic dims
            such as sequence length. Without it, those dims are also set to
            ``batch_size``.

    Raises:
        AssertionError: If onnxsim cannot validate the simplified model. The
            check is an explicit raise, so it still runs under ``python -O``.
    """
    model = _load_model_copy(onnx_model_path)
    _pin_dynamic_input_dims(model, batch_size, dim_values or {})

    model_simp, check = simplify(model)
    if not check:
        raise AssertionError("Simplified ONNX model could not be validated")

    onnx.save(model_simp, output_path)
    print(f"Simplified and static shape model saved to {output_path}")


def main() -> None:
    """Write one static-batch copy of ./models/model_1/onnx-1/model.onnx per batch size."""
    model_name = "model_1"
    model_dir = "./models"
    onnx_dir = os.path.join(model_dir, model_name, "onnx-1")
    onnx_model_path = os.path.join(onnx_dir, "model.onnx")
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]

    # Parse the source model once; each conversion works on its own copy.
    source_model = onnx.load(onnx_model_path)
    for bs in batch_sizes:
        static_output_path = os.path.join(onnx_dir, f"model-static-batch-size-{bs}.onnx")
        convert_onnx_dynamic_to_static(source_model, static_output_path, bs)


if __name__ == "__main__":
    main()