onnx_optimize.py 2.28 KB
Newer Older
zk's avatar
zk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import onnx
from onnxsim import simplify
from onnxconverter_common import float16

onnx_model_path = "../weights_400x600/ground_deform.onnx"
sim_model_path = "../weights_400x600/ground_deform_sim.onnx"
fp16_model_path = "../weights_400x600/ground_deform_fp16.onnx"
fp16_all_model_path = "../weights_400x600/ground_deform_fp16_all.onnx"
custom_op_lib_path = "../ort_plugin_fp16/build/libms_deform_attn_ort.so" 

# ==========================================
# 第一步:ONNX Simplify (附带自定义算子库)
# ==========================================
print("1️⃣ 正在进行 ONNX Simplify...")
model = onnx.load(onnx_model_path)
model_simp, check = simplify(model, custom_lib=custom_op_lib_path)

if check:
    onnx.save(model_simp, sim_model_path)
    print(f"✅ Simplify 完成!已保存至 {sim_model_path}")
else:
    print("❌ Simplify 验证失败!")
    exit()



# ==========================================
# 第二步:FP16 精度转换 (1.避开自定义算子 2.不避开)
# ==========================================

# 重新加载 sim 后的模型
model_to_fp16 = onnx.load(sim_model_path)
print("\n2️⃣ 正在进行 FP16 混合精度转换...")

original_cast_nodes = [node.name for node in model_to_fp16.graph.node if node.op_type == "Cast"]
print(f"🔍 查找到 {len(original_cast_nodes)} 个原生 Cast 节点,已全部加入保护名单。")


model_fp16 = float16.convert_float_to_float16(
    model_to_fp16,
    op_block_list=["ms_deform_attn"],     # 屏蔽自定义的注意力算子, 如果是fp32版本自定义算子
    node_block_list=original_cast_nodes,  # 保护所有原生的 Cast 节点
    keep_io_types=True                # 保持整个模型的总输入/输出还是 FP32
)

onnx.save(model_fp16, fp16_model_path)
print(f"✅ FP16 转换完成(避开自定义算子)!已保存至 {fp16_model_path}")

print("\n2️⃣ 正在进行纯 FP16 精度转换...")

model_fp16_all = float16.convert_float_to_float16(
    model_to_fp16,
    node_block_list=original_cast_nodes,  # 保护所有原生的 Cast 节点
    keep_io_types=True                # 保持整个模型的总输入/输出还是 FP32
)

onnx.save(model_fp16_all, fp16_all_model_path)
print(f"✅ FP16 转换完成!已保存至 {fp16_all_model_path}")