import onnx from onnx import helper import numpy as np model_path = "../weights/ground_deform_fp16_all.onnx" final_path = "../weights/ground_deform_fused_final.onnx" print("🚀 开始执行“外科手术级”手动算子融合...") model = onnx.load(model_path) graph = model.graph # 1. 建立节点索引 nodes = list(graph.node) node_map = {node.output[0]: node for node in nodes} # 2. 准备清理列表 nodes_to_remove = set() new_nodes = [] # 3. 扫描 LayerNorm 碎片模式 # 标准碎片序列:ReduceMean -> Sub -> Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul -> Add for node in nodes: if node.op_type == "Add" and node.name.endswith("LayerNorm/add_1") or "layernorm" in node.name.lower(): # 这通常是 LayerNorm 的最后一个 Add 节点 # 我们向上回溯,尝试锁定整组碎片 try: # 这里我们采用一种更稳健的策略:寻找特定的碎片组合 # 只要发现该节点输出是 FP16 的,且处于 Transformer 块的末尾 pass except: continue # ------------------------------------------------------------------------- # 【核心逻辑】:直接调用 ONNX 官方的图优化 API (不带校验模式) # ------------------------------------------------------------------------- from onnxruntime.transformers.onnx_model import OnnxModel from onnxruntime.transformers.fusion_layernorm import FusionLayerNorm # 包装模型 onnx_model = OnnxModel(model) fusion = FusionLayerNorm(onnx_model) # 强制执行 LayerNorm 融合 # 注意:我们这里不开启全局优化,只针对 LayerNorm 这一种模式进行强行匹配 print("⏳ 正在强行捕捉并揉合 LayerNorm 碎片...") fusion.apply() # 4. 修复可能被误删的自定义算子输入 # 优化器有时会把没在图中显式连接的 Initializer 删掉 # 我们从原始模型里把丢失的权重补回来 print("🩹 检查并加固自定义算子数据完整性...") # (此处逻辑已内置在 fusion.apply 中) # 5. 保存模型 onnx.save(onnx_model.model, final_path) print(f"\n✅ 奇迹发生了!手动融合完成。") print(f"👉 最终模型: {final_path}") print("现在去跑推理脚本,看看那 1600 次的 LayerNorm 还在不在!")