import onnx
from onnxsim import simplify
import os


def convert_onnx_dynamic_to_static(onnx_model_path, output_path, batch_size,
                                   *, batch_axis=None):
    """Pin the dynamic input dimensions of an ONNX model and save a simplified copy.

    The model at ``onnx_model_path`` is loaded. Every graph-input dimension
    whose ``dim_value`` is not a positive integer (symbolic or unknown) is set
    to ``batch_size``. The model is then run through ``onnxsim.simplify`` and
    written to ``output_path``.

    Args:
        onnx_model_path: Path of the source ``.onnx`` file.
        output_path: Destination path for the simplified static-shape model.
        batch_size: Positive integer written into the dynamic dimensions.
        batch_axis: Controls which dynamic dimensions are pinned.
            - ``None`` (default): every dynamic dimension of every input is
              pinned to ``batch_size``. This is only correct when the batch
              axis is the sole dynamic axis.
            - An axis index (e.g. ``0``): only that axis is pinned, and other
              dynamic axes are left untouched.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
        RuntimeError: If onnxsim reports that the simplified model could not
            be validated.
    """
    # Reject non-positive sizes up front: writing 0 would leave the dimension
    # looking dynamic and silently produce a model that is not static.
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}")

    model = onnx.load(onnx_model_path)
    for graph_input in model.graph.input:  # renamed: don't shadow builtin input
        tensor_type = graph_input.type.tensor_type
        if not tensor_type.HasField('shape'):
            continue  # no shape recorded, nothing to pin
        for axis, dim in enumerate(tensor_type.shape.dim):
            if batch_axis is not None and axis != batch_axis:
                continue
            # A symbolic (dim_param) or unset dimension reads back as
            # dim_value == 0; negative placeholders are treated as dynamic too.
            # Assigning dim_value replaces any dim_param, since both share
            # a protobuf oneof in onnx.proto.
            if dim.dim_value <= 0:
                dim.dim_value = batch_size

    model_simp, check = simplify(model)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")

    # Save the simplified model.
    onnx.save(model_simp, output_path)
    print(f"Simplified and static shape model saved to {output_path}")
    

if __name__ == '__main__':
    # Export a family of static-batch models from one dynamic-batch model.
    # Each variant is written next to the source model.
    model_name = "model_1"
    models_root = "./models"
    subdir = f'{model_name}/onnx-1'
    source_path = os.path.join(models_root, f'{subdir}/model.onnx')

    # Powers of two from 1 up to and including 2048.
    batch_sizes = [2 ** exponent for exponent in range(12)]

    for size in batch_sizes:
        target_path = os.path.join(
            models_root, f'{subdir}/model-static-batch-size-{size}.onnx')
        convert_onnx_dynamic_to_static(source_path, target_path, size)