onnx_optimize.py 2.3 KB
Newer Older
zk's avatar
zk committed
1
2
3
4
import copy

import onnx
from onnxconverter_common import float16
from onnxsim import simplify

zk's avatar
zk committed
5
6
7
8
9
# Model file paths: raw export under ../weights, optimized artifacts under ../weights_opt.
onnx_model_path = "../weights/ground_deform.onnx"
sim_model_path = "../weights_opt/ground_deform_opt.onnx"
fp16_model_path = "../weights_opt/ground_deform_opt_fp16.onnx"
fp16_all_model_path = "../weights_opt/ground_deform_opt_fp16_all.onnx"

zk's avatar
zk committed
10
11
custom_op_lib_path = "../ort_plugin_fp16/build/libms_deform_attn_ort.so" 

zk's avatar
zk committed
12
13
14
15
16
17
# # ==========================================
# # 第一步:ONNX Simplify (附带自定义算子库)
# # ==========================================
# print("1️⃣ 正在进行 ONNX Simplify...")
# model = onnx.load(onnx_model_path)
# model_simp, check = simplify(model, custom_lib=custom_op_lib_path)
zk's avatar
zk committed
18

zk's avatar
zk committed
19
20
21
22
23
24
# if check:
#     onnx.save(model_simp, sim_model_path)
#     print(f"✅ Simplify 完成!已保存至 {sim_model_path}")
# else:
#     print("❌ Simplify 验证失败!")
#     exit()
zk's avatar
zk committed
25
26
27
28
29
30
31
32
33
34
35
36
37



# ==========================================
# Step 2: FP16 precision conversion (a. skipping the custom op, b. not skipping it)
# ==========================================

# Reload the simplified model produced by step 1.
model_to_fp16 = onnx.load(sim_model_path)
# Collect the names of all pre-existing Cast nodes so they can be protected
# from the FP16 rewrite below.
original_cast_nodes = []
for graph_node in model_to_fp16.graph.node:
    if graph_node.op_type == "Cast":
        original_cast_nodes.append(graph_node.name)
print(f"🔍 查找到 {len(original_cast_nodes)} 个原生 Cast 节点,已全部加入保护名单。")


zk's avatar
zk committed
38
print("\n2️⃣ 正在进行 FP16 混合精度转换...")
zk's avatar
zk committed
39
40
41
42
43
44
45
46
47
48
# Convert the graph to FP16 while keeping the custom attention op, all
# pre-existing Cast nodes, and the model's overall inputs/outputs in FP32.
# NOTE(fix): convert_float_to_float16 mutates the ModelProto it receives, so
# we pass a deep copy — otherwise `model_to_fp16` would already be FP16 when
# the second ("all-FP16") conversion below runs on it.
model_fp16 = float16.convert_float_to_float16(
    copy.deepcopy(model_to_fp16),
    op_block_list=["ms_deform_attn"],     # block the custom attention op (FP32 custom-op build)
    node_block_list=original_cast_nodes,  # protect all pre-existing Cast nodes
    keep_io_types=True,                   # keep the model's top-level inputs/outputs in FP32
)
onnx.save(model_fp16, fp16_model_path)
print(f"✅ FP16 转换完成(避开自定义算子)!已保存至 {fp16_model_path}")


zk's avatar
zk committed
49
50

print("\n2️⃣ 正在进行纯 FP16 精度转换...")
zk's avatar
zk committed
51
52
53
54
55
56
# Convert everything to FP16 (the custom op is not blocked here); only the
# pre-existing Cast nodes and the top-level inputs/outputs stay FP32.
# NOTE(fix): deep-copy again so this conversion also starts from an
# unmutated source proto (convert_float_to_float16 modifies its argument).
model_fp16_all = float16.convert_float_to_float16(
    copy.deepcopy(model_to_fp16),
    node_block_list=original_cast_nodes,  # protect all pre-existing Cast nodes
    keep_io_types=True,                   # keep the model's top-level inputs/outputs in FP32
)
onnx.save(model_fp16_all, fp16_all_model_path)
print(f"✅ 纯 FP16 转换完成!已保存至 {fp16_all_model_path}")
zk's avatar
zk committed
58