import onnx from onnxsim import simplify from onnxconverter_common import float16 onnx_model_path = "../weights/ground_deform.onnx" sim_model_path = "../weights_opt/ground_deform_opt.onnx" fp16_model_path = "../weights_opt/ground_deform_opt_fp16.onnx" fp16_all_model_path = "../weights_opt/ground_deform_opt_fp16_all.onnx" custom_op_lib_path = "../ort_plugin_fp16/build/libms_deform_attn_ort.so" # # ========================================== # # 第一步:ONNX Simplify (附带自定义算子库) # # ========================================== # print("1️⃣ 正在进行 ONNX Simplify...") # model = onnx.load(onnx_model_path) # model_simp, check = simplify(model, custom_lib=custom_op_lib_path) # if check: # onnx.save(model_simp, sim_model_path) # print(f"✅ Simplify 完成!已保存至 {sim_model_path}") # else: # print("❌ Simplify 验证失败!") # exit() # ========================================== # 第二步:FP16 精度转换 (1.避开自定义算子 2.不避开) # ========================================== # 重新加载 sim 后的模型 model_to_fp16 = onnx.load(sim_model_path) original_cast_nodes = [node.name for node in model_to_fp16.graph.node if node.op_type == "Cast"] print(f"🔍 查找到 {len(original_cast_nodes)} 个原生 Cast 节点,已全部加入保护名单。") print("\n2️⃣ 正在进行 FP16 混合精度转换...") model_fp16 = float16.convert_float_to_float16( model_to_fp16, op_block_list=["ms_deform_attn"], # 屏蔽自定义的注意力算子, 如果是fp32版本自定义算子 node_block_list=original_cast_nodes, # 保护所有原生的 Cast 节点 keep_io_types=True # 保持整个模型的总输入/输出还是 FP32 ) onnx.save(model_fp16, fp16_model_path) print(f"✅ FP16 转换完成(避开自定义算子)!已保存至 {fp16_model_path}") print("\n2️⃣ 正在进行纯 FP16 精度转换...") model_fp16_all = float16.convert_float_to_float16( model_to_fp16, node_block_list=original_cast_nodes, # 保护所有原生的 Cast 节点 keep_io_types=True # 保持整个模型的总输入/输出还是 FP32 ) onnx.save(model_fp16_all, fp16_all_model_path) print(f"✅ 纯 FP16 转换完成!已保存至 {fp16_all_model_path}")