import onnx from onnxsim import simplify from onnxconverter_common import float16 onnx_model_path = "../weights_400x600/ground_deform.onnx" sim_model_path = "../weights_400x600/ground_deform_sim.onnx" fp16_model_path = "../weights_400x600/ground_deform_fp16.onnx" fp16_all_model_path = "../weights_400x600/ground_deform_fp16_all.onnx" custom_op_lib_path = "../ort_plugin_fp16/build/libms_deform_attn_ort.so" # ========================================== # 第一步:ONNX Simplify (附带自定义算子库) # ========================================== print("1️⃣ 正在进行 ONNX Simplify...") model = onnx.load(onnx_model_path) model_simp, check = simplify(model, custom_lib=custom_op_lib_path) if check: onnx.save(model_simp, sim_model_path) print(f"✅ Simplify 完成!已保存至 {sim_model_path}") else: print("❌ Simplify 验证失败!") exit() # ========================================== # 第二步:FP16 精度转换 (1.避开自定义算子 2.不避开) # ========================================== # 重新加载 sim 后的模型 model_to_fp16 = onnx.load(sim_model_path) print("\n2️⃣ 正在进行 FP16 混合精度转换...") original_cast_nodes = [node.name for node in model_to_fp16.graph.node if node.op_type == "Cast"] print(f"🔍 查找到 {len(original_cast_nodes)} 个原生 Cast 节点,已全部加入保护名单。") model_fp16 = float16.convert_float_to_float16( model_to_fp16, op_block_list=["ms_deform_attn"], # 屏蔽自定义的注意力算子, 如果是fp32版本自定义算子 node_block_list=original_cast_nodes, # 保护所有原生的 Cast 节点 keep_io_types=True # 保持整个模型的总输入/输出还是 FP32 ) onnx.save(model_fp16, fp16_model_path) print(f"✅ FP16 转换完成(避开自定义算子)!已保存至 {fp16_model_path}") print("\n2️⃣ 正在进行纯 FP16 精度转换...") model_fp16_all = float16.convert_float_to_float16( model_to_fp16, node_block_list=original_cast_nodes, # 保护所有原生的 Cast 节点 keep_io_types=True # 保持整个模型的总输入/输出还是 FP32 ) onnx.save(model_fp16_all, fp16_all_model_path) print(f"✅ FP16 转换完成!已保存至 {fp16_all_model_path}")