Commit a1865640 authored by zk's avatar zk
Browse files

新增migraphx部分

parent 0896d47e
......@@ -157,7 +157,7 @@ python onnx_inference_deform_optim.py
如需使用更低分辨率的图像输入(如 400x800)以进一步加速推理,可按以下步骤操作:
### 6.1 修改导出脚本
1. 修改导出脚本
编辑 `deform_ort/export_onnx_deform.py`,修改图像尺寸与导出路径:
......@@ -169,7 +169,7 @@ img = torch.randn(1, 3, 400, 800).to(device)
onnx_output_path = "../weights_400x800/ground_deform.onnx"
```
### 6.2 正常导出并量化
2. 正常导出并量化
```bash
cd deform_ort
......@@ -177,7 +177,7 @@ python export_onnx_deform.py
python onnx_optimize.py
```
### 6.3 修改推理预处理分辨率
3. 修改推理预处理分辨率
编辑 `groundingdino/util/inference.py` 中的 `load_image` 函数,将 `RandomResize` 的参数从 800 改为 400:
......@@ -186,7 +186,7 @@ python onnx_optimize.py
T.RandomResize([400], max_size=1333),
```
### 6.4 执行 ORT 推理
4. 执行 ORT 推理
运行推理脚本,并确保代码中的 ONNX 模型路径指向 `weights_400x800/` 下对应的模型文件:
......@@ -198,7 +198,26 @@ python onnx_inference_deform_optim.py
-----
## 7\. 测试结果对比
## 7\. migraphx推理
1. 进入migraphx_infer文件夹
```bash
cd migraphx_infer
```
2. 运行转换onnx脚本
将简化后的onnx转换为要用migraphx推理的onnx
```bash
bash migraphx_export.bash
```
3. 如果已经得到了mxr文件,直接测试
```bash
bash migraphx_perf.bash
```
-----
## 8\. 测试结果对比
*以下测试均包含 5 轮预热(Warmup)和 10 轮正式测试。*
......@@ -208,7 +227,7 @@ python onnx_inference_deform_optim.py
> * **模型文件**:默认存放于 `../weights/` 目录下。
> * **自定义算子目录**:对应的完整动态库路径均为 `../[目录名]/build/libms_deform_attn_ort.so`。
### 7.1 BW150 测试结果
### 8.1 ORT BW150 测试结果
单张 BW150 卡,图像输入 800x1200,Batch Size = 1
......@@ -221,7 +240,7 @@ python onnx_inference_deform_optim.py
| **ORT + Plugin** | +自定义算子<br>+FP16 纯量化方案 B | `ground_deform_fp16_all.onnx` | `ort_plugin_fp16_B` | 87.34 | 11.44 |
| **ORT + Plugin** | +自定义算子<br>+FP16 极致优化方案 C | `ground_deform_fp16_all.onnx` | `ort_plugin_fp16_C` | 84.52 | 11.82 |
### 7.2 BW100 测试结果
### 8.2 ORT BW100 测试结果
单张 BW100 卡,图像输入 800x1200,Batch Size = 1
......@@ -233,11 +252,22 @@ python onnx_inference_deform_optim.py
| **ORT + Plugin** | +自定义算子<br>+FP16 纯量化方案 B | `ground_deform_fp16_all.onnx` | `ort_plugin_fp16_B` | 105.35 | 9.49 |
| **ORT + Plugin** | +自定义算子<br>+FP16 极致优化方案 C | `ground_deform_fp16_all.onnx` | `ort_plugin_fp16_C` | 100.91 | 9.90 |
### 8.3 migraphx BW100 测试结果
```
Batch size: 1
Rate: 6.05197 inferences/sec
Total time: 165.235ms (Min: 165.115ms, Max: 165.535ms,
Mean: 165.258ms, Median: 165.225ms)
Percentiles (90%, 95%, 99%): (165.358ms, 165.358ms, 165.358ms)
Total instructions time: 205.275ms
Overhead time: 2.32812ms, -40.0399ms
Overhead: 1%, -24%
```
-----
## 参考项目
本项目在开发过程中参考了以下优秀开源项目,在此表示感谢
本项目在开发过程中参考了以下开源项目:
- [**GroundingDINO**](https://github.com/IDEA-Research/GroundingDINO) - GroundingDINO 官方仓库,提供基础模型与算法实现。
- [**GroundingDINO-TensorRT-and-ONNX-Inference**](https://github.com/wingdzero/GroundingDINO-TensorRT-and-ONNX-Inference) - 提供了 GroundingDINO 的 TensorRT 及 ONNX 推理部署参考实现。
\ No newline at end of file
......@@ -7,20 +7,35 @@ import onnxruntime as ort
import bisect
import time
import os
from typing import Tuple
import groundingdino.datasets.transforms as T
from PIL import Image
"""
针对模型前后处理和代码结构进行优化
1.预测结果获取优化prediction_logits = sigmoid(outputs[0][0])
2.输入数据提前获取直接传入,移除了对tokenizer的依赖
"""
from groundingdino.util.inference import load_image
so_options = ort.SessionOptions()
custom_op_lib_path = "../ort_plugin/build/libms_deform_attn_ort.so"
custom_op_lib_path = "../ort_plugin_fp16_C/build/libms_deform_attn_ort.so"
so_options.register_custom_ops_library(custom_op_lib_path)
# 开启ort优化
so_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
def load_image(image_path: str) -> Tuple[np.ndarray, torch.Tensor]:
    """Load an image and produce both its raw array and the normalized model input.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        A tuple of:
        - the original RGB image as an ``np.ndarray``, for visualization;
        - the resized/normalized image tensor used as the model input.
    """
    # Short side resized to 800 (long side capped at 1333), then ImageNet
    # mean/std normalization.
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            # T.RandomResize([400], max_size=1333),  # lower-resolution variant
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_source = Image.open(image_path).convert("RGB")
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed
def sigmoid(x):
    """Map raw logits to (0, 1) with the logistic function."""
    return np.reciprocal(np.exp(-x) + 1.0)
......@@ -180,7 +195,7 @@ def benchmark_performance(
if __name__ == '__main__':
# 配置参数
model_path = '../weights_400x600/ground_deform.onnx'
model_path = '../weights/ground_deform_fp16_all.onnx'
"""
../weights/ground_deform.onnx 普通版本
../weights/ground_deform_sim.onnx 简化版本
......@@ -264,6 +279,6 @@ if __name__ == '__main__':
)
# 保存结果
cv2.imwrite('./result.jpg', ori_img)
print(f"\n✅ 结果已保存至: ./result.jpg")
cv2.imwrite('../weights/result.jpg', ori_img)
print(f"\n✅ 结果已保存至: ../weights/result.jpg")
print(f"✅ 检测到目标: {phrases} (共 {len(boxes)} 个)")
......@@ -7,6 +7,9 @@ import onnxruntime as ort
import bisect
import time
import os
from typing import Tuple
import groundingdino.datasets.transforms as T
from PIL import Image
"""
针对模型前后处理和代码结构进行优化
1.预测结果获取优化prediction_logits = sigmoid(outputs[0][0])
......@@ -14,14 +17,28 @@ import os
3.IO binding优化
"""
from groundingdino.util.inference import load_image
so_options = ort.SessionOptions()
custom_op_lib_path = "../ort_plugin/build/libms_deform_attn_ort.so"
# 如何想要查看ORT的详细日志,可以取消下面这行的注释,并设置合适的日志级别
# so_options.enable_profiling = True
custom_op_lib_path = "../ort_plugin_fp16_C/build/libms_deform_attn_ort.so"
so_options.register_custom_ops_library(custom_op_lib_path)
# 开启ort优化
so_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
def load_image(image_path: str) -> Tuple[np.ndarray, torch.Tensor]:
    """Load an image and return (raw RGB array, normalized input tensor).

    Args:
        image_path: Path to the image file on disk.

    Returns:
        A tuple of:
        - the original RGB image as an ``np.ndarray``, used for drawing results;
        - the resized/normalized tensor fed to the network.
    """
    # Short side resized to 800 (long side capped at 1333), then ImageNet
    # mean/std normalization.
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            # T.RandomResize([400], max_size=1333),  # lower-resolution variant
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_source = Image.open(image_path).convert("RGB")
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed
def sigmoid(x):
    """Logistic function: squash logits into the open interval (0, 1)."""
    denominator = 1.0 + np.exp(-x)
    return np.divide(1.0, denominator)
......@@ -67,20 +84,17 @@ def predict(
t0 = time.time()
# 1. 仅仅绑定当前这帧发生变化的图片其他文本输入早就在显存里躺好了
# 1. 仅仅绑定当前这帧发生变化的图片其他文本输入绑定好了
img_tensor = np.expand_dims(np.asarray(image), axis=0)
# 尝试输入进行fp16转换,导出onnx时输入转换为fp16,但是推理性能下降了
# img_tensor = np.expand_dims(np.asarray(image), axis=0).astype(np.float16)
io_binding.bind_cpu_input('img', img_tensor)
# 2. 绑定需要获取的输出
io_binding.bind_output('logits')
io_binding.bind_output('boxes')
# 3. 极速执行推理
# 2. 执行推理
ort_session.run_with_iobinding(io_binding)
ort_outputs = io_binding.copy_outputs_to_cpu()
# 清空输出绑定,否则下一次循环会内存泄漏报错
io_binding.clear_binding_outputs()
# 3. 结果从GPU 复制回 CPU
ort_outputs = io_binding.copy_outputs_to_cpu()
infer_time = time.time() - t0
if not is_benchmark:
......@@ -204,7 +218,7 @@ def benchmark_performance(
if __name__ == '__main__':
# 配置参数
model_path = '../weights/ground_deform_fp16.onnx'
model_path = '../weights_opt/ground_deform_opt_fp16_all.onnx'
img_path = '../images/in/car_1.jpg'
TEXT_PROMPT = "car ."
BOX_TRESHOLD = 0.35
......@@ -241,6 +255,9 @@ if __name__ == '__main__':
for key in static_keys:
io_binding.bind_cpu_input(key, TEXT_CACHE[key])
io_binding.bind_output('logits')
io_binding.bind_output('boxes')
# 第一步:运行完整的性能测试(预热+实际推理)
performance_result = benchmark_performance(
ort_session, io_binding, image, TEXT_CACHE,
......@@ -281,6 +298,6 @@ if __name__ == '__main__':
)
# 保存结果
cv2.imwrite('./images/out/result.jpg', ori_img)
print(f"\n✅ 结果已保存至: ./images/out/result.jpg")
cv2.imwrite('../images/out/result.jpg', ori_img)
print(f"\n✅ 结果已保存至: ../images/out/result.jpg")
print(f"✅ 检测到目标: {phrases} (共 {len(boxes)} 个)")
......@@ -2,25 +2,26 @@ import onnx
from onnxsim import simplify
from onnxconverter_common import float16
onnx_model_path = "../weights_400x600/ground_deform.onnx"
sim_model_path = "../weights_400x600/ground_deform_sim.onnx"
fp16_model_path = "../weights_400x600/ground_deform_fp16.onnx"
fp16_all_model_path = "../weights_400x600/ground_deform_fp16_all.onnx"
onnx_model_path = "../weights/ground_deform.onnx"
sim_model_path = "../weights_opt/ground_deform_opt.onnx"
fp16_model_path = "../weights_opt/ground_deform_opt_fp16.onnx"
fp16_all_model_path = "../weights_opt/ground_deform_opt_fp16_all.onnx"
custom_op_lib_path = "../ort_plugin_fp16/build/libms_deform_attn_ort.so"
# ==========================================
# 第一步:ONNX Simplify (附带自定义算子库)
# ==========================================
print("1️⃣ 正在进行 ONNX Simplify...")
model = onnx.load(onnx_model_path)
model_simp, check = simplify(model, custom_lib=custom_op_lib_path)
# # ==========================================
# # 第一步:ONNX Simplify (附带自定义算子库)
# # ==========================================
# print("1️⃣ 正在进行 ONNX Simplify...")
# model = onnx.load(onnx_model_path)
# model_simp, check = simplify(model, custom_lib=custom_op_lib_path)
if check:
onnx.save(model_simp, sim_model_path)
print(f"✅ Simplify 完成!已保存至 {sim_model_path}")
else:
print("❌ Simplify 验证失败!")
exit()
# if check:
# onnx.save(model_simp, sim_model_path)
# print(f"✅ Simplify 完成!已保存至 {sim_model_path}")
# else:
# print("❌ Simplify 验证失败!")
# exit()
......@@ -30,30 +31,28 @@ else:
# 重新加载 sim 后的模型
model_to_fp16 = onnx.load(sim_model_path)
print("\n2️⃣ 正在进行 FP16 混合精度转换...")
original_cast_nodes = [node.name for node in model_to_fp16.graph.node if node.op_type == "Cast"]
print(f"🔍 查找到 {len(original_cast_nodes)} 个原生 Cast 节点,已全部加入保护名单。")
print("\n2️⃣ 正在进行 FP16 混合精度转换...")
model_fp16 = float16.convert_float_to_float16(
model_to_fp16,
op_block_list=["ms_deform_attn"], # 屏蔽自定义的注意力算子, 如果是fp32版本自定义算子
node_block_list=original_cast_nodes, # 保护所有原生的 Cast 节点
keep_io_types=True # 保持整个模型的总输入/输出还是 FP32
)
onnx.save(model_fp16, fp16_model_path)
print(f"✅ FP16 转换完成(避开自定义算子)!已保存至 {fp16_model_path}")
print("\n2️⃣ 正在进行纯 FP16 精度转换...")
print("\n2️⃣ 正在进行纯 FP16 精度转换...")
model_fp16_all = float16.convert_float_to_float16(
model_to_fp16,
node_block_list=original_cast_nodes, # 保护所有原生的 Cast 节点
keep_io_types=True # 保持整个模型的总输入/输出还是 FP32
)
onnx.save(model_fp16_all, fp16_all_model_path)
print(f"✅ FP16 转换完成!已保存至 {fp16_all_model_path}")
print(f"✅ FP16 转换完成!已保存至 {fp16_all_model_path}")
import json
import sys
from collections import defaultdict
def analyze_profile(json_path):
    """Summarize an ONNX Runtime profiling trace.

    Reads the Chrome-trace style JSON at *json_path*, accumulates the `dur`
    of every event carrying an `op_name`, then prints the top-10 operator
    types and the top-15 individual nodes ranked by total elapsed time.
    """
    print(f"🔍 正在解析性能文件: {json_path}\n")
    with open(json_path, 'r') as f:
        trace = json.load(f)

    # Root may be a bare event list or a dict wrapping 'traceEvents'.
    events = trace if isinstance(trace, list) else trace.get('traceEvents', [])

    per_op_type = defaultdict(float)   # operator type (MatMul, Conv, ...) -> total ms
    per_node = defaultdict(float)      # full node path -> total ms
    grand_total_ms = 0.0

    for ev in events:
        # Only timed events with an args payload are relevant.
        if 'dur' not in ev or 'args' not in ev:
            continue
        ev_args = ev['args']
        # ORT records the operator type under args['op_name'].
        if 'op_name' not in ev_args:
            continue
        # 'dur' is in microseconds; convert to milliseconds.
        dur_ms = ev['dur'] / 1000.0
        per_op_type[ev_args['op_name']] += dur_ms
        per_node[ev.get('name', 'Unknown_Node')] += dur_ms
        grand_total_ms += dur_ms

    # Rank both tables by accumulated time, descending.
    ranked_ops = sorted(per_op_type.items(), key=lambda item: item[1], reverse=True)
    ranked_nodes = sorted(per_node.items(), key=lambda item: item[1], reverse=True)

    divider = "=" * 50
    print(divider)
    print("🏆 按【算子类型 (OpType)】耗时总和排名 Top 10")
    print(divider)
    for rank, (op, time_ms) in enumerate(ranked_ops[:10]):
        percentage = (time_ms / grand_total_ms) * 100 if grand_total_ms > 0 else 0
        print(f"{rank+1:2d}. {op:<20} | 耗时: {time_ms:>8.3f} ms | 占比: {percentage:>5.2f}%")

    print("\n" + divider)
    print("🎯 按【单个具体节点 (Node)】耗时排名 Top 15")
    print(divider)
    for rank, (node, time_ms) in enumerate(ranked_nodes[:15]):
        percentage = (time_ms / grand_total_ms) * 100 if grand_total_ms > 0 else 0
        print(f"{rank+1:2d}. 耗时: {time_ms:>8.3f} ms ({percentage:>5.2f}%) | 节点: {node}")
if __name__ == "__main__":
    # Default to the most recently generated profile file; a CLI argument overrides it.
    default_profile = "./onnxruntime_profile__2026-04-27_13-58-17.json"
    target = sys.argv[1] if len(sys.argv) > 1 else default_profile
    analyze_profile(target)
\ No newline at end of file
deform_ort/result.jpg

1.35 MB | W: | H:

deform_ort/result.jpg

1.35 MB | W: | H:

deform_ort/result.jpg
deform_ort/result.jpg
deform_ort/result.jpg
deform_ort/result.jpg
  • 2-up
  • Swipe
  • Onion skin
......@@ -39,8 +39,8 @@ def load_model(model_config_path: str, model_checkpoint_path: str, device: str =
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
transform = T.Compose(
[
# T.RandomResize([800], max_size=1333),
T.RandomResize([400], max_size=1333),
T.RandomResize([800], max_size=1333),
# T.RandomResize([400], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
......
images/out/result.jpg

1.35 MB | W: | H:

images/out/result.jpg

1.35 MB | W: | H:

images/out/result.jpg
images/out/result.jpg
images/out/result.jpg
images/out/result.jpg
  • 2-up
  • Swipe
  • Onion skin
# Enable the MIOpen concat path in MIGraphX for this run.
export MIGRAPHX_ENABLE_MIOPEN_CONCAT=1

# Compile (and benchmark) the simplified ONNX model with FP16 enabled,
# writing the compiled MIGraphX program to ../weights/ground_opt.mxr.
migraphx-driver perf --onnx \
../weights/ground_opt.onnx \
--fp16 \
--output \
../weights/ground_opt.mxr
# Compile the ONNX model into a serialized MIGraphX program (.mgx) on GPU,
# with debug logging and only the dead_code_elimination pass enabled.
MIGRAPHX_LOG=debug migraphx-driver compile \
--onnx weights/ground_external.onnx \
--gpu \
-p dead_code_elimination \
--output weights/ground.mgx
# Extra optimization passes that can be re-enabled if needed:
# -p eliminate_contiguous \
# -p simplify_reshapes \
# -p simplify_algebra \
# -p eliminate_identity \
# -p common_subexpression_elimination \
......@@ -57,7 +57,7 @@ def _mgx_shape_to_numpy(shape):
# 🚀 MIGraphX 推理类(带缓存)
# =========================
class MIGraphXModel:
def __init__(self, onnx_path, cache_path="weights/ground.mxr", force_recompile=False):
def __init__(self, onnx_path, cache_path="weights/ground_opt.mxr", force_recompile=False):
self.cache_path = cache_path
# ====== 优先加载缓存 ======
......@@ -228,10 +228,10 @@ def benchmark(model, tokenizer, image, caption, box_th, text_th, warmup=5, runs=
# =========================
if __name__ == "__main__":
model_path = "weights/ground_simplified.onnx"
cache_path = "weights/ground_simplified.mxr" # ⭐ 缓存文件
model_path = "../weights/ground_opt.onnx"
cache_path = "../weights/ground_opt.mxr" # ⭐ 缓存文件
img_path = "images/in/car_1.jpg"
img_path = "../images/in/car_1.jpg"
TEXT_PROMPT = "car ."
BOX_TRESHOLD = 0.35
......
import cv2
import numpy as np
import time
import os
import migraphx
from typing import Tuple
import torch
import groundingdino.datasets.transforms as T
from PIL import Image
def load_image(image_path: str) -> Tuple[np.ndarray, torch.Tensor]:
    """Load an image and return (raw RGB array, normalized input tensor).

    Args:
        image_path: Path to the image file on disk.

    Returns:
        A tuple of:
        - the original RGB image as an ``np.ndarray``, for visualization;
        - the resized/normalized tensor used as the network input.
    """
    # Short side resized to 800 (long side capped at 1333), then ImageNet
    # mean/std normalization.
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            # T.RandomResize([400], max_size=1333),  # lower-resolution variant
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_source = Image.open(image_path).convert("RGB")
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed
def sigmoid(x):
    """Convert raw logits to probabilities via the logistic function."""
    exp_neg = np.exp(-x)
    return 1.0 / (exp_neg + 1.0)
def _mgx_shape_to_numpy(shape):
shape_str = str(shape)
if "int64_type" in shape_str:
dtype = np.int64
elif "bool_type" in shape_str:
dtype = np.bool_
elif "half_type" in shape_str:
dtype = np.float16
else:
dtype = np.float32
try:
dims = list(shape.dims())
except Exception:
dims = []
try:
lens = list(shape.lens())
except Exception:
lens = []
return dtype, (dims if len(dims) > 0 else lens)
# =========================
# 🚀 MIGraphX 推理类(带缓存与生命周期管理)
# =========================
class MIGraphXModel:
    """MIGraphX inference wrapper with an on-disk compile cache.

    Loads a precompiled .mxr program when available; otherwise parses the
    ONNX file, compiles it for the GPU target and saves the result so
    subsequent runs skip the (slow) compilation step.
    """

    def __init__(self, onnx_path, cache_path="weights/ground_opt.mxr", force_recompile=False, device_id=0):
        # Where the compiled MIGraphX program is cached on disk.
        self.cache_path = cache_path
        if os.path.exists(cache_path) and not force_recompile:
            print(f"⚡ 直接加载已编译模型: {cache_path}")
            self.model = migraphx.load(cache_path)
        else:
            print("🔍 从 ONNX 构建 MIGraphX")
            self.model = migraphx.parse_onnx(onnx_path)
            print(f"⚙️ 编译 MIGraphX(GPU {device_id})")
            self.model.compile(t=migraphx.get_target("gpu"), device_id=device_id)
            print(f"💾 保存编译模型到: {cache_path}")
            migraphx.save(self.model, cache_path)
        # presumably a name -> shape mapping of the program's parameters
        # (used below via .keys() and indexing) — TODO confirm against the
        # installed migraphx Python API.
        self.input_shapes = self.model.get_inputs()

    def infer(self, input_dict):
        """Run one forward pass.

        Args:
            input_dict: mapping of input name -> array-like. Inputs the
                program requires but that are missing here are zero-filled.

        Returns:
            (outputs, infer_time): list of np.ndarray outputs and the
            wall-clock seconds spent inside ``model.run``.
        """
        mgx_inputs = {}
        # Keep every NumPy buffer referenced for the duration of the run:
        # migraphx.argument holds a raw pointer, so letting Python GC free
        # the array would invalidate it mid-inference.
        self._keep_alive_cache = {}
        provided_names = set(input_dict.keys())
        # Skip pseudo-parameters that actually name program outputs.
        required_names = {
            k for k in self.input_shapes.keys()
            if not str(k).startswith("main:#output")
        }
        for name in required_names:
            shape = self.input_shapes[name]
            target_dtype, lens = _mgx_shape_to_numpy(shape)
            if name in provided_names:
                # 1. Force contiguous memory — arrays converted from PyTorch
                #    may carry incompatible strides.
                arr = np.ascontiguousarray(input_dict[name])
                # 2. Coerce to the dtype the compiled program expects.
                if arr.dtype != target_dtype:
                    arr = arr.astype(target_dtype)
            else:
                # Missing inputs are padded with zeros of the declared shape.
                arr = np.zeros(lens, dtype=target_dtype)
            # 3. Park the array in the keep-alive dict...
            self._keep_alive_cache[name] = arr
            # 4. ...then hand its buffer pointer over to migraphx.
            mgx_inputs[name] = migraphx.argument(arr)
        start = time.time()
        result = self.model.run(mgx_inputs)
        infer_time = time.time() - start
        outputs = [np.array(r) for r in result]
        # Inference done; the buffers may now be released.
        self._keep_alive_cache.clear()
        return outputs, infer_time
# =========================
# 推理函数 (硬编码输入,无 Tokenizer)
# =========================
def predict(model, image, box_threshold, is_benchmark=False):
    """Run one detection pass with hard-coded text inputs (prompt "car .").

    Args:
        model: object exposing ``infer(input_dict) -> (outputs, seconds)``.
        image: preprocessed image (C, H, W); batched here before inference.
        box_threshold: minimum per-box score to keep a detection.
        is_benchmark: when True, suppress the per-call timing print.

    Returns:
        (boxes, confidences, phrases) for the detections above threshold.
    """
    # Tokenized prompt tensors are fixed; only the image varies per call.
    feeds = {
        "img": np.expand_dims(np.asarray(image), axis=0),
        "position_ids": np.array([[0, 0, 1, 0]]),
        "input_ids": np.array([[101, 2482, 1012, 102]]),
        "token_type_ids": np.array([[0, 0, 0, 0]]),
        "text_token_mask": np.array([[
            [True, False, False, False],
            [False, True, True, False],
            [False, True, True, False],
            [False, False, False, True]
        ]]),
        "attention_mask": np.array([[True, True, True, True]])
    }

    outputs, infer_time = model.infer(feeds)
    if not is_benchmark:
        print(f"Inference time: {infer_time*1000:.2f} ms")

    # outputs[0]: raw logits, outputs[1]: boxes — strip the batch dimension.
    scores = sigmoid(outputs[0][0])
    candidate_boxes = outputs[1][0]

    # Keep only boxes whose best class score clears the threshold.
    keep = np.max(scores, axis=1) > box_threshold
    scores = scores[keep]
    candidate_boxes = candidate_boxes[keep]

    labels = ["car"] * len(candidate_boxes)
    return candidate_boxes, np.max(scores, axis=1), labels
# =========================
# Benchmark
# =========================
def benchmark(model, image, box_th, warmup=5, runs=10):
    """Time `predict` over several runs after a warmup phase; print mean latency and FPS."""
    print("\n🔥 预热")
    for _ in range(warmup):
        predict(model, image, box_th, True)

    print("\n🚀 测试")
    samples = []
    for _ in range(runs):
        tic = time.time()
        predict(model, image, box_th, True)
        samples.append(time.time() - tic)

    mean_seconds = np.mean(samples)
    print(f"\n平均耗时: {mean_seconds*1000:.2f} ms")
    print(f"FPS: {1/mean_seconds:.2f}")
# =========================
# 主函数
# =========================
# if __name__ == "__main__":
# model_path = "../weights/ground_opt.onnx"
# cache_path = "../weights/ground_opt.mxr"
# img_path = "../images/in/car_1.jpg"
# BOX_TRESHOLD = 0.35
# DEVICE_ID = 5 # 匹配你之前报错堆栈里的 device: 5 / 0 的情况,按需修改
# model = MIGraphXModel(
# model_path,
# cache_path=cache_path,
# force_recompile=False,
# device_id=DEVICE_ID
# )
# image_source, image = load_image(img_path)
# benchmark(model, image, BOX_TRESHOLD)
# boxes, confs, phrases = predict(model, image, BOX_TRESHOLD)
# print("检测结果:", phrases)
def test_like_perf(model):
    """Mimic migraphx-driver perf: feed exactly-shaped all-zero tensors.

    Builds one zero-filled buffer per required program input (matching the
    declared dtype/shape), runs the program once, and reports success or
    the raised exception — used to isolate VMFault-style crashes from
    real-input preprocessing issues.
    """
    print("\n" + "="*60)
    print("🛠️ 模拟 perf 工具:生成完美对齐的 Dummy 数据测试")
    print("="*60)

    mgx_inputs = {}
    keep_alive_cache = []  # holds the arrays so GC can't free buffers migraphx points into

    # 1. Fabricate dummy data strictly matching the model's declared shapes.
    for name, shape in model.get_inputs().items():
        if str(name).startswith("main:#output"):
            continue
        # Resolve the dtype and dims the program actually requires.
        target_dtype, lens = _mgx_shape_to_numpy(shape)
        print(f" 📦 分配 {name}: shape={lens}, dtype={target_dtype.__name__}")
        # All-zeros tensor (what migraphx-driver itself feeds during perf).
        dummy_data = np.zeros(lens, dtype=target_dtype)
        keep_alive_cache.append(dummy_data)
        # Hand the buffer pointer over to migraphx.
        mgx_inputs[name] = migraphx.argument(dummy_data)

    print("\n🚀 开始 Dummy 推理测试...")
    try:
        start = time.time()
        model.run(mgx_inputs)
        print(f"✅ Python 端 Dummy 推理成功!没有任何 VMFault!耗时: {(time.time()-start)*1000:.2f}ms")
    except Exception as e:
        print(f"❌ 依然报错: {e}")
# ------------------
# Entry point: load a known-good precompiled program and run the
# perf-style dummy-input test against it.
# ------------------
if __name__ == "__main__":
    model_path = "../weights/ground_opt.onnx"  # NOTE(review): unused here — kept for reference
    cache_path = "../weights/ground_opt.mxr"

    # Load the precompiled .mxr directly (skips MIGraphXModel's compile path).
    model = migraphx.load(cache_path)

    # Run the simulated perf test.
    test_like_perf(model)
# Benchmark the precompiled MIGraphX program: batch 1, 10 iterations, FP16.
migraphx-driver perf --batch 1 \
-n 10 \
--fp16 \
--migraphx ../weights/ground_opt.mxr
import sys
import numpy as np
from onnx_modifier import ONNXModifier
def change_inf_to_value(om: ONNXModifier):
    """Replace scalar +/-inf initializers feeding Where nodes with finite sentinels.

    Each scalar +inf / -inf constant used as a Where branch value is
    replaced by float16 max/min (stored as float32, so declared tensor
    types stay unchanged) — presumably so the constants stay representable
    after a later FP16 conversion; confirm against the fp16 pipeline.
    """
    records = set()  # initializer names already rewritten (avoid double updates)
    for where_node in om.get_nodes("Where"):
        # Only the X/Y branch inputs (inputs[1:]) can carry the inf constants.
        for input_name in where_node.inputs[1:]:
            init = om.get_initializer(input_name)
            if init is None:
                continue  # runtime input, not a constant
            assert input_name == init.name
            init_name = input_name
            if init_name in records:
                continue
            # float16 limits are chosen deliberately (see docstring).
            # info = np.finfo(np.float32)
            info = np.finfo(np.float16)
            data = om.get_initializer_value(init.name)
            if data.size > 1:
                continue  # only scalar sentinels are rewritten
            if data == np.inf:
                om.set_initializer_value(init_name, np.array(info.max, dtype=np.float32))
            elif data == -np.inf:
                om.set_initializer_value(init_name, np.array(info.min, dtype=np.float32))
            else:
                continue
            # print("Changed value:", init_name)
            records.add(init_name)
def optimize_where_ndoes(om: ONNXModifier):
    """Replace Where nodes with equivalent subgraphs.

    (1) condition is an initializer, X is 0, Y is the runtime input:
        Where(cond, X, Y) ==> Mul(Y, ~cond)
    (2) condition is an initializer, X is -inf, Y is the runtime input:
        Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
    (3) condition is a real graph input, X is -inf, Y is the runtime input:
        Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))

    Cases present in this model:
        1. Where(cond, -inf, input)
            a. /transformer/encoder/fusion_layers.*/attn/Where
            b. /transformer/encoder/fusion_layers.*/attn/Where_1
            c. /class_embed.0_*/Where: Where(cond, -inf, input)
        2. Where(cond, 0, input):
            a. /transformer/encoder/layers.*/self_attn/Where
            b. /transformer/decoder/layers.*/cross_attn/Where

    The -inf branch values become float16 max (not inf), matching the
    sentinel substitution done in change_inf_to_value.
    """
    for where_node in om.get_nodes("Where"):
        where_name = where_node.name
        # print("Process where node:", where_name)
        # X (inputs[1]) must be a scalar initializer equal to 0 or -inf.
        x_value = om.get_initializer_value(where_node.inputs[1])
        assert x_value.size == 1
        assert x_value == np.array(0.0, dtype=np.float32) or \
            x_value == np.array(-np.inf, dtype=np.float32)
        cond_init = om.get_initializer(where_node.inputs[0])
        if cond_init is not None:
            # Condition is a constant: fold it directly into a new initializer.
            cond_value = om.get_initializer_value(where_node.inputs[0])
            if x_value == np.array(0.0, dtype=np.float32):
                # Case (1): Where(cond, X, Y) ==> Mul(Y, ~cond)
                mul_name = where_name.replace("Where", "NewMul")
                mul_b_init = om.create_initializer(mul_name + "_B",
                                                   (~cond_value).astype(np.float32))
                mul_node = om.create_node("Mul",
                                          mul_name,
                                          [where_node.inputs[2], mul_b_init.name],
                                          [mul_name+"_output_0"],
                                          index=where_node.index)
                # Reconnect every consumer of the Where output to the new Mul.
                next_nodes = where_node.next_nodes
                for next_node in next_nodes:
                    next_node.replace_input(where_node.outputs[0], mul_node.outputs[0])
            elif x_value == np.array(-np.inf, dtype=np.float32):
                # Case (2): Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
                sub_name = where_name.replace("Where", "NewSub")
                # Precompute Where(cond, fp16_max, 0) as a constant subtrahend.
                sub_b_init = om.create_initializer(
                    sub_name + "_B",
                    np.where(cond_value.astype(np.float32),
                             np.finfo(np.float16).max, 0.0).astype(np.float32)
                )
                sub_node = om.create_node("Sub",
                                          sub_name,
                                          [where_node.inputs[2], sub_b_init.name],
                                          [sub_name+"_output_0"],
                                          index=where_node.index)
                next_nodes = where_node.next_nodes
                for next_node in next_nodes:
                    next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
        else:
            # Case (3): runtime condition —
            # Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
            assert x_value == np.array(-np.inf, dtype=np.float32)
            cast_name = where_name.replace("Where", "NewCast")
            mul_name = where_name.replace("Where", "NewMul")
            sub_name = where_name.replace("Where", "NewSub")
            # Cast the boolean condition to float32 (to=1).
            cast_node = om.create_node("Cast",
                                       cast_name,
                                       [where_node.inputs[0]],
                                       [cast_name+"_output_0"],
                                       to=1,
                                       index=where_node.index)
            # Scale by the fp16-safe "infinity" sentinel.
            mul_b_init = om.create_initializer(mul_name + "_B",
                                               np.array([np.finfo(np.float16).max], np.float32))
            mul_node = om.create_node("Mul",
                                      mul_name,
                                      [cast_node.outputs[0], mul_b_init.name],
                                      [mul_name+"_output_0"],
                                      index=cast_node.index+1)
            sub_node = om.create_node("Sub",
                                      sub_name,
                                      [where_node.inputs[2], mul_node.outputs[0]],
                                      [sub_name+"_output_0"],
                                      index=mul_node.index+1)
            next_nodes = where_node.next_nodes
            for next_node in next_nodes:
                next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
    # Rebuild the modifier's internal node/edge maps after the rewrites.
    om.update_map()
def optimize_transpose_nodes(om: ONNXModifier):
    """Drop [1, 0, 2(, 3)] Transposes and adjust the shapes that depended on them.

    Every listed Transpose only swaps its first two axes; consumers are
    reconnected to the transpose input directly, then the hard-coded
    Reshape target shapes, one remaining Transpose perm and one broadcast
    Add initializer are rewritten to match the new axis ordering.
    NOTE(review): this is only valid while the swapped leading axis is 1
    (batch size 1, as the [1, ...] target shapes below assume) — confirm.
    """
    # All Transposes that merely swap the leading two axes.
    transpose_list = [
        "/transformer/encoder/Transpose",
        "/transformer/encoder/Transpose_1",
        "/transformer/encoder/Transpose_2",
        "/transformer/encoder/Transpose_3",
        "/transformer/encoder/Transpose_4",
        "/transformer/encoder/Transpose_5",
        "/transformer/encoder/Transpose_6",
        "/transformer/encoder/Transpose_7",
        "/transformer/encoder/Transpose_8",
        "/transformer/encoder/Transpose_9",
        "/transformer/encoder/Transpose_10",
        "/transformer/encoder/Transpose_11",
        "/transformer/decoder/layers.0/Transpose",
        "/transformer/decoder/layers.0/Transpose_1",
        "/transformer/decoder/layers.0/Transpose_2",
        "/transformer/decoder/layers.1/Transpose",
        "/transformer/decoder/layers.1/Transpose_1",
        "/transformer/decoder/layers.1/Transpose_2",
        "/transformer/decoder/layers.2/Transpose",
        "/transformer/decoder/layers.2/Transpose_1",
        "/transformer/decoder/layers.2/Transpose_2",
        "/transformer/decoder/layers.3/Transpose",
        "/transformer/decoder/layers.3/Transpose_1",
        "/transformer/decoder/layers.3/Transpose_2",
        "/transformer/decoder/layers.4/Transpose",
        "/transformer/decoder/layers.4/Transpose_1",
        "/transformer/decoder/layers.4/Transpose_2",
        "/transformer/decoder/layers.5/Transpose",
        "/transformer/decoder/layers.5/Transpose_1",
        "/transformer/decoder/layers.5/Transpose_2",
        "/transformer/Transpose_8",
        "/transformer/decoder/Transpose",
        "/transformer/decoder/Transpose_1",
        "/transformer/decoder/Transpose_2",
        "/transformer/decoder/Transpose_3",
        "/transformer/decoder/Transpose_4",
        "/transformer/decoder/Transpose_5",
        "/transformer/decoder/Transpose_6",
        "/transformer/decoder/Transpose_7",
        "/transformer/decoder/Transpose_8",
        "/transformer/decoder/Transpose_9",
        "/transformer/decoder/Transpose_10",
        "/transformer/decoder/Transpose_11"
    ]
    for name in transpose_list:
        node = om.get_node(name)
        # Only pure leading-axis swaps are safe to remove.
        assert node.attrs['perm'] == [1, 0, 2] or node.attrs['perm'] == [1, 0, 2, 3], \
            f"perm={node.attrs['perm']}"
        next_nodes = om.get_next_nodes(node)
        for node_ in next_nodes:
            # Bypass the transpose: consumers read its input directly.
            node_.replace_input(node.outputs[0], node.inputs[0])
    # modify /transformer/encoder/text_layers.*/self_attn/Reshape_4
    # om.set_initializer_value("_v_8735", np.array([-1, 4, 256], np.int64))
    shape_init1 = om.create_initializer(
        "/transformer/encoder/text_layers.x/self_attn/des_shape",
        np.array([1, 4, 256], np.int64)
    )
    for i in range(6):
        reshape_node = om.get_node(f"/transformer/encoder/text_layers.{i}/self_attn/Reshape_4")
        reshape_node.set_input(1, shape_init1.name)
    # modify /transformer/enc_out_class_embed/Transpose: new axis order after removal.
    om.get_node("/transformer/enc_out_class_embed/Transpose").set_attribute("perm", [0, 2, 1])
    # modify /transformer/decoder/Reshape_* target shape.
    om.set_initializer_value("_v_5525", np.array([1, 900, -1], np.int64))
    # modify /transformer/decoder/layers.*/self_attn/Reshape_4
    # modify /transformer/decoder/layers.*/ca_text/Reshape_6
    # om.set_initializer_value("_v_6230", np.array([-1, 900, 256], np.int64))
    shape_init3 = om.create_initializer(
        "/transformer/decoder/layers.x/self_attn_ca_text/des_shape",
        np.array([1, 900, 256], np.int64)
    )
    for i in range(6):
        reshape_node1 = om.get_node(f"/transformer/decoder/layers.{i}/self_attn/Reshape_4")
        reshape_node1.set_input(1, shape_init3.name)
        reshape_node2 = om.get_node(f"/transformer/decoder/layers.{i}/ca_text/Reshape_6")
        reshape_node2.set_input(1, shape_init3.name)
    # modify /transformer/decoder/layers.0/Add and Add_1: transpose the
    # broadcast initializer's first two axes to match the new layout.
    init_name = "/transformer/Tile_1_output_0"
    add_value = om.get_initializer_value(init_name)
    om.set_initializer_value(init_name, np.ascontiguousarray(add_value.transpose(1, 0, 2)))
    # Rebuild internal maps and rerun shape inference on the edited graph.
    om.update_map()
    om.infer_shape()
def optmize_sin_cos_block(om: ONNXModifier):
    """Rewrite the decoder's sine/cosine positional-embedding chains.

    For each Gather -> Mul -> Unsqueeze -> Div -> (2x Slice) -> Sin/Cos ->
    Unsqueeze -> Concat -> Reshape -> MatMul chain: the scalar Mul scale is
    folded into the Div divisor (reshaped to (1, 1, 1, 64, 2)), the two
    slices are re-pointed at the new trailing axis 4 ([0:1] and [1:2]),
    sin/cos outputs are concatenated directly on axis 4, and the result is
    reshaped to (1, 900, -1) before the ref_point_head MatMul.
    NOTE(review): the first MatMul's weight rows are permuted to match the
    new sin/cos ordering — verify against the original embedding layout.
    """
    # (Gather that starts the chain, MatMul that consumes the embedding)
    node_pairs = [
        ("/transformer/decoder/Gather_1", "/transformer/decoder/ref_point_head/layers.0/MatMul"),
        ("/transformer/decoder/Gather_6", "/transformer/decoder/ref_point_head/layers.0_1/MatMul"),
        ("/transformer/decoder/Gather_11", "/transformer/decoder/ref_point_head/layers.0_2/MatMul"),
        ("/transformer/decoder/Gather_16", "/transformer/decoder/ref_point_head/layers.0_3/MatMul"),
        ("/transformer/decoder/Gather_21", "/transformer/decoder/ref_point_head/layers.0_4/MatMul"),
        ("/transformer/decoder/Gather_26", "/transformer/decoder/ref_point_head/layers.0_5/MatMul"),
    ]
    # Shared initializers reused across all six rewritten chains.
    unsqueeze_axes_init1 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/unsqueeze_axes1",
        np.array([3, 4], np.int64)
    )
    slice_axes_init = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_axes",
        np.array([4], np.int64)
    )
    slice_steps_init = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_steps",
        np.array([1], np.int64)
    )
    slice_starts_init1 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_starts1",
        np.array([0], np.int64)
    )
    slice_ends_init1 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_ends1",
        np.array([1], np.int64)
    )
    slice_starts_init2 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_steps2",
        np.array([1], np.int64)
    )
    slice_ends_init2 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_ends2",
        np.array([2], np.int64)
    )
    reshape_init = om.create_initializer(
        "/transformer/decoder/sin_cos_block/reshape_dst_shape",
        np.array([1, 900, -1], np.int64)
    )
    for i, (gather_name, matmul_name) in enumerate(node_pairs):
        gather_node = om.get_node(gather_name)
        # Gather -> Mul (scalar scale factor).
        next_node = om.get_next_nodes(gather_node)[0]
        assert next_node.op_type == "Mul", f"{next_node.op_type} {next_node.name}"
        mul_init_value = om.get_initializer_value(next_node.inputs[1])
        assert mul_init_value.size == 1
        # Mul -> Unsqueeze: rewire to take the Gather input directly with axes [3, 4].
        next_node = om.get_next_nodes(next_node)[0]
        assert next_node.op_type == "Unsqueeze"
        next_node.set_inputs([gather_node.inputs[0], unsqueeze_axes_init1.name])
        # Unsqueeze -> Div: fold the dropped Mul scale into the divisor and
        # reshape it to (1, 1, 1, 64, 2) so the axis-4 slices below line up.
        next_node = om.get_next_nodes(next_node)[0]
        assert next_node.op_type == "Div"
        div_init_value = om.get_initializer_value(next_node.inputs[1])
        new_value = (div_init_value / mul_init_value).reshape(1, 1, 1, 64, 2)
        new_init = om.create_initializer(next_node.name + "_B", new_value)
        next_node.set_input(1, new_init.name)
        # The Div fans out into exactly two Slice nodes.
        next_nodes = om.get_next_nodes(next_node)
        assert len(next_nodes) == 2 and all(x.op_type == 'Slice' for x in next_nodes)
        sin_node, cos_node = None, None
        for j, slice_node in enumerate(next_nodes):
            # Re-point both slices at axis 4: [0:1] for the first, [1:2] for the second.
            slice_node.set_inputs([slice_node.inputs[0],
                                   slice_starts_init1.name if j == 0 else slice_starts_init2.name,
                                   slice_ends_init1.name if j == 0 else slice_ends_init2.name,
                                   slice_axes_init.name,
                                   slice_steps_init.name])
            next_node = om.get_next_nodes(slice_node)[0]
            if next_node.op_type == "Sin":
                sin_node = next_node
            elif next_node.op_type == "Cos":
                cos_node = next_node
            else:
                raise RuntimeError("match fail!")
        # Follow the chain tail: (Sin|Cos) -> Unsqueeze -> Concat -> Reshape.
        next_node = om.get_next_nodes(next_node)[0]
        assert next_node.op_type == "Unsqueeze"
        next_node = om.get_next_nodes(next_node)[0]
        assert next_node.op_type == "Concat"
        # Concatenate sin/cos outputs directly on the new trailing axis,
        # bypassing the intermediate Unsqueezes.
        next_node.set_inputs([sin_node.outputs[0], cos_node.outputs[0]])
        next_node.set_attribute("axis", 4)
        next_node = om.get_next_nodes(next_node)[0]
        assert next_node.op_type == "Reshape"
        next_node.set_input(1, reshape_init.name)
        # Feed the rebuilt embedding into the corresponding MatMul.
        matmul_node = om.get_node(matmul_name)
        matmul_node.set_input(0, next_node.outputs[0])
        if i == 0:
            # Swap the first two 128-row groups of the weight so the MatMul
            # matches the new sin/cos interleaving (only needed once: the
            # weight initializer is shared — TODO confirm).
            mm_b_value = om.get_initializer_value(matmul_node.inputs[1])
            mm_b_value = np.concatenate([mm_b_value[128:256, ...],
                                         mm_b_value[0:128, ...],
                                         mm_b_value[256:, ...]],
                                        axis=0)
            om.set_initializer_value(matmul_node.inputs[1], mm_b_value)
    # Rebuild internal maps and rerun shape inference on the edited graph.
    om.update_map()
    om.infer_shape()
def fuse_one_attention(om: ONNXModifier, softmax_name: str, new_mask: "str | None" = None, num_heads: int = 12):
    """Fuse one exported attention subgraph into a com.microsoft MultiHeadAttention node.

    Walks the graph around the given Softmax to locate the Q/K/V source
    tensors, replaces the MatMul -> [Div] -> [Add mask] -> Softmax ->
    MatMul -> Transpose -> Reshape pattern with a single MultiHeadAttention
    op, and rewires the downstream projection (Gemm or MatMul).

    Args:
        om: graph modifier wrapping the ONNX model; mutated in place.
        softmax_name: name of the attention Softmax node to fuse around.
        new_mask: tensor name to use as the MHA mask input when the matched
            pattern carries an additive mask; required in that case.
        num_heads: value stored in the MHA ``num_heads`` attribute.
    """
    softmax_node = om.get_node(softmax_name)
    # The node feeding Softmax is either the raw QK MatMul or an Add that
    # applies the additive attention mask.
    tmp_node = om.get_prev_nodes(softmax_node)[0]
    assert tmp_node.op_type in ["MatMul", "Add"]
    mask = None
    if tmp_node.op_type == "Add":
        mask_node = tmp_node
        tmp_node = om.get_from_node(mask_node.inputs[0])
        if tmp_node.op_type == "Div":  # optional scaling Div between MatMul and Add
            tmp_node = om.get_from_node(tmp_node.inputs[0])
        assert tmp_node.op_type == "MatMul"
        mask = mask_node.inputs[1]
        # A replacement mask tensor must be supplied for masked patterns.
        assert new_mask is not None
    # Trace Q (MatMul input 0, possibly behind a scaling Mul) and K (input 1)
    # back through Transpose <- Reshape to the pre-reshape hidden states.
    tmp_node1 = om.get_from_node(tmp_node.inputs[0])
    if tmp_node1.op_type == "Mul":
        tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
    tmp_node2 = om.get_from_node(tmp_node.inputs[1])
    assert tmp_node1.op_type == tmp_node2.op_type == "Transpose"
    tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
    tmp_node2 = om.get_prev_nodes(tmp_node2)[0]
    assert tmp_node1.op_type == tmp_node2.op_type == "Reshape"
    q, k = tmp_node1.inputs[0], tmp_node2.inputs[0]
    # V is the second input of the MatMul after Softmax: either traced through
    # Transpose <- Reshape, or a constant initializer which is re-laid-out
    # from (H, S, D) to (1, S, H*D) — presumably the layout MultiHeadAttention
    # expects for its value input (TODO confirm against the contrib-op spec).
    tmp_node = om.get_next_nodes(softmax_node)[0]
    assert tmp_node.op_type == "MatMul"
    tmp_node3 = om.get_from_node(tmp_node.inputs[1])
    if tmp_node3 is not None:
        assert tmp_node3.op_type == "Transpose"
        tmp_node3 = om.get_prev_nodes(tmp_node3)[0]
        assert tmp_node3.op_type == "Reshape"
        v = tmp_node3.inputs[0]
    else:
        v_init = om.get_initializer(tmp_node.inputs[1])
        v_init_value = om.get_initializer_value(tmp_node.inputs[1])
        v_init_value = v_init_value[None, ...].transpose(0, 2, 1, 3)
        B, S, H, D = v_init_value.shape
        v_init_value = np.ascontiguousarray(v_init_value.reshape(B, S, H*D))
        om.set_initializer_value(tmp_node.inputs[1], v_init_value)
        v = v_init.name
    # Skip past the output Transpose -> Reshape to reach the consumer of the
    # attention output (the projection: Gemm or MatMul).
    tmp_node = om.get_next_nodes(tmp_node)[0]
    assert tmp_node.op_type == "Transpose"
    tmp_node = om.get_next_nodes(tmp_node)[0]
    assert tmp_node.op_type == "Reshape"
    mha_next_node = om.get_next_nodes(tmp_node)[0]
    if mha_next_node.op_type == "Gemm":
        # Gemm path has a trailing Reshape -> Add; remember both so the Gemm
        # can be rewritten as MatMul+Add below and the pair bypassed.
        gemm_next_node = om.get_next_nodes(mha_next_node)[0]
        assert gemm_next_node.op_type == "Reshape"
        reshape_next_node = om.get_next_nodes(gemm_next_node)[0]
        assert reshape_next_node.op_type == "Add"
    else:
        assert mha_next_node.op_type == "MatMul"
    name_prefix = '/'.join(softmax_name.split('/')[:-1])
    mha_name = f"{name_prefix}/MultiHeadAttention"
    mha_node = om.create_node("MultiHeadAttention",
                              mha_name,
                              [q, k, v] if mask is None else [q, k, v, new_mask],
                              [mha_name+'_output_0'],
                              num_heads=num_heads,
                              domain="com.microsoft",
                              index=mha_next_node.index-1)
    mha_next_node.replace_input(mha_next_node.inputs[0], mha_node.outputs[0])
    if mha_next_node.op_type == "Gemm":
        # Replace the Gemm with an equivalent MatMul (weights transposed since
        # transB == 1) plus a bias Add — NOTE(review): presumably because the
        # fused MHA output rank no longer matches what Gemm accepted; confirm.
        weights = om.get_initializer_value(mha_next_node.inputs[1])
        transB = mha_next_node.attrs["transB"]
        assert transB == 1
        weights = np.ascontiguousarray(weights.transpose(1, 0))
        om.set_initializer_value(mha_next_node.inputs[1], weights)
        new_matmul_name = mha_next_node.name.replace("Gemm", "MatMul(Gemm)")
        new_matmul_node = om.create_node("MatMul",
                                         new_matmul_name,
                                         [mha_node.outputs[0], mha_next_node.inputs[1]],
                                         [new_matmul_name + "_output_0"],
                                         index=mha_next_node.index)
        new_bias_name = mha_next_node.name.replace("Gemm", "Add(Gemm)")
        new_add_node = om.create_node("Add",
                                      new_bias_name,
                                      [new_matmul_node.outputs[0], mha_next_node.inputs[2]],
                                      [new_bias_name + "_output_0"],
                                      index=new_matmul_node.index+1)
        reshape_next_node.replace_input(gemm_next_node.outputs[0], new_add_node.outputs[0])
def optimize_normal_attention(om: ONNXModifier):
    """Fuse the decoder self-attention and text cross-attention blocks into
    com.microsoft MultiHeadAttention nodes.

    Builds a derived mask tensor from the model's ``attention_mask`` input
    (Cast to float32, ReduceSum over axis 1) and passes it to
    fuse_one_attention for the six decoder layers (8 heads each).
    """
    def create_new_attention_mask():
        # Insert Cast(to=1, i.e. float32) + ReduceSum(axes=1, keepdims=0)
        # right before the first consumer of `attention_mask`; presumably the
        # sum is the number of valid tokens per sequence — confirm the mask
        # is 0/1-valued at this point.
        mask_next_node = om.get_to_nodes("attention_mask")[0]
        cast_node = om.create_node("Cast",
                                   "Cast_for_attention_mask",
                                   ["attention_mask"],
                                   ["Cast_for_attention_mask_output_0"],
                                   to=1,
                                   index=mask_next_node.index)
        reducesum_node = om.create_node("ReduceSum",
                                        "ReduceSum_for_mask",
                                        [cast_node.outputs[0]],
                                        ["ReduceSum_for_mask_output_0"],
                                        axes=1,
                                        keepdims=0,
                                        index=cast_node.index+1)
        return reducesum_node.outputs[0]
    # bert (kept disabled)
    # for i in range(12):
    #     fuse_one_attention(om, f"/bert/encoder/layer.{i}/attention/self/Softmax", "text_token_mask", num_heads=12)
    new_mask = create_new_attention_mask()
    for i in range(6):
        # /transformer/encoder (kept disabled)
        # fuse_one_attention(om, f"/transformer/encoder/text_layers.{i}/self_attn/Softmax", "text_token_mask", num_heads=4)
        # /transformer/decoder
        fuse_one_attention(om, f"/transformer/decoder/layers.{i}/self_attn/Softmax", new_mask, num_heads=8)
        fuse_one_attention(om, f"/transformer/decoder/layers.{i}/ca_text/Softmax", new_mask, num_heads=8)
    om.update_map()
def optimize_backbone_attention(om: ONNXModifier):
    """Fuse the backbone's window-attention subgraphs into com.microsoft
    MultiHeadAttention nodes (4 stages; stage 2 has 18 blocks, others 2).
    """
    def get_original_mask(mask_name, name_prefix):
        # Turn the additive mask constant into a boolean mask initializer:
        # True where the additive value is 0 (i.e. attention allowed).
        mask_value = om.get_initializer_value(mask_name)
        orig_mask = np.where(mask_value==0, 1, 0).astype(np.bool_)
        orig_mask_init = om.create_initializer(f"{name_prefix}/mask", orig_mask)
        return orig_mask_init.name
    def _fuse_one_attention(softmax_name: str):
        # Match the pattern around one window-attention Softmax and replace it
        # with a single MultiHeadAttention node.
        name_prefix = '/'.join(softmax_name.split('/')[:-1])
        softmax_node = om.get_node(softmax_name)
        tmp_node = om.get_prev_nodes(softmax_node)[0]
        pos_bias_init = None
        if tmp_node.op_type == "Reshape":
            # Variant with an extra Add of a constant bias behind a Reshape
            # pair. NOTE(review): the variable names assume the inner Add's
            # inputs[1] is the relative-position bias and the outer Add (below)
            # carries the window mask — confirm against the exported graph.
            tmp_node = om.get_prev_nodes(tmp_node)[0]
            assert tmp_node.op_type == "Add"
            pos_bias_init = om.get_initializer(tmp_node.inputs[1])
            tmp_node = om.get_prev_nodes(tmp_node)[0]
            assert tmp_node.op_type == "Reshape"
            tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Add"
        mask = get_original_mask(tmp_node.inputs[1], name_prefix)
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "MatMul"
        qk_matmul = tmp_node
        # Q path: Mul(scale) <- Gather <- Transpose <- Reshape; keep the
        # Reshape so its target shape can be rewritten below.
        tmp_node = om.get_from_node(qk_matmul.inputs[0])
        assert tmp_node.op_type == "Mul"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Gather"
        q_gather_node = tmp_node
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Reshape"
        reshape_node = tmp_node
        # K path: Transpose <- Gather.
        tmp_node = om.get_from_node(qk_matmul.inputs[1])
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Gather"
        k_gather_node = tmp_node
        # V path and subgraph output: Softmax -> MatMul -> Transpose -> Reshape.
        tmp_node = om.get_next_nodes(softmax_node)[0]
        assert tmp_node.op_type == "MatMul"
        v_gather_node = om.get_from_node(tmp_node.inputs[1])
        assert v_gather_node.op_type == "Gather"
        tmp_node = om.get_next_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_next_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Reshape"
        mha_out = tmp_node.outputs[0]
        # Collapse the (b, s, _, h, d) QKV reshape to (b, s, _, h*d) —
        # `_` is presumably 3 for Q/K/V (confirm) — and make the three
        # Gathers select along the new axis 2.
        old_dst_shape = om.get_initializer_value(reshape_node.inputs[1])
        b, s, _, h, d = old_dst_shape
        new_dst_shape = [b, s, _, h*d]
        new_dst_shape_init = om.create_initializer(f"{name_prefix}/qkv_hidden_states_shape",
                                                   np.array(new_dst_shape, np.int64))
        reshape_node.set_input(1, new_dst_shape_init.name)
        for node in [q_gather_node, k_gather_node, v_gather_node]:
            node.set_input(0, reshape_node.outputs[0])
            node.set_attribute("axis", 2)
        mha_name = f"{name_prefix}/MultiHeadAttention"
        inputs = [q_gather_node.outputs[0],
                  k_gather_node.outputs[0],
                  v_gather_node.outputs[0],
                  mask]
        if pos_bias_init is not None:
            # Optional 5th MHA input: additive attention bias.
            inputs.append(pos_bias_init.name)
        mha_node = om.create_node("MultiHeadAttention",
                                  mha_name,
                                  inputs,
                                  [mha_name+'_output_0'],
                                  num_heads=h,
                                  domain="com.microsoft",
                                  index=softmax_node.index)
        mha_next_node = om.get_to_nodes(mha_out)[0]
        mha_next_node.replace_input(mha_out, mha_node.outputs[0])
    num_layers = 4
    for l in range(num_layers):
        num_blocks = 18 if l == 2 else 2
        for b in range(num_blocks):
            _fuse_one_attention(f"/backbone/backbone.0/layers.{l}/blocks.{b}/attn/softmax/Softmax")
def optimize_ms_deform_attn(om: ONNXModifier):
    """Collapse each multi-scale deformable-attention subgraph into one custom
    ``MSDeformAttn`` node (6 encoder self-attn + 6 decoder cross-attn layers).

    The feature-pyramid constants are hard-coded for 4 levels of
    (100x150, 50x75, 25x38, 13x19); the start offsets are the running sum of
    H*W per level (0, 15000, 18750, 19700). NOTE(review): these presumably
    match the 800x1200 input resolution — confirm before reusing the pass
    with other input sizes.
    """
    def fuse_ms_deform_attn(value, spatial_shapes, level_start_index, sampling_locations,
                            attention_weights, output):
        # Insert the fused node where the value tensor is first consumed, then
        # reroute every consumer of the old subgraph output onto it. The old
        # nodes are left dangling for a later cleanup/DCE pass.
        value_next_node = om.get_to_nodes(value)[0]
        index = value_next_node.index
        name_prefix = '/'.join(value.split('/')[:-1])
        node_name = f"{name_prefix}/MSDeformAttn"
        fusion_node = om.create_node("MSDeformAttn",
                                     node_name,
                                     [value, spatial_shapes, level_start_index,
                                      sampling_locations, attention_weights],
                                     [f"{node_name}_output_0"],
                                     index=index)
        next_nodes = om.get_to_nodes(output)
        for node in next_nodes:
            node.replace_input(output, fusion_node.outputs[0])
    spatial_shapes_int = om.create_initializer(
        "/transformer/spatial_shapes",
        np.array([(100, 150), (50, 75), (25, 38), (13, 19)], dtype=np.int64)
    )
    level_start_index_init = om.create_initializer(
        "/transformer/level_start_index",
        np.array([0, 15000, 18750, 19700], dtype=np.int64)
    )
    for i in range(6):
        # Encoder self-attention, layer i.
        fuse_ms_deform_attn(
            f"/transformer/encoder/layers.{i}/self_attn/Reshape_output_0",
            spatial_shapes_int.name,
            level_start_index_init.name,
            f"/transformer/encoder/layers.{i}/self_attn/Add_output_0",
            f"/transformer/encoder/layers.{i}/self_attn/Reshape_3_output_0",
            f"/transformer/encoder/layers.{i}/self_attn/Transpose_9_output_0"
        )
        # Decoder cross-attention, layer i.
        fuse_ms_deform_attn(
            f"/transformer/decoder/layers.{i}/cross_attn/Reshape_output_0",
            spatial_shapes_int.name,
            level_start_index_init.name,
            f"/transformer/decoder/layers.{i}/cross_attn/Add_output_0",
            f"/transformer/decoder/layers.{i}/cross_attn/Reshape_3_output_0",
            f"/transformer/decoder/layers.{i}/cross_attn/Transpose_9_output_0"
        )
    om.update_map()
def optimize_bidirect_attention(om: "ONNXModifier"):
    """Insert a 1x1 identity MatMul into each fusion layer's ReduceMax path.

    For each of the six encoder fusion layers, places a numerically no-op
    MatMul (1x1 identity weight) between ``ReduceMax_1`` and the ``Sub`` that
    consumes it — presumably to keep a backend from fusing/folding the
    ReduceMax->Sub pattern (TODO confirm motivation).

    Robustness fix: layers whose ``ReduceMax_1`` node is missing, has no
    consumers, or is not followed by a ``Sub`` are now skipped instead of
    crashing with AttributeError/IndexError/AssertionError (matching the
    guarded variant of this pass elsewhere in the file).

    Args:
        om: graph modifier wrapping the ONNX model; mutated in place.
    """
    for i in range(6):
        reduce_max_name = f"/transformer/encoder/fusion_layers.{i}/attn/ReduceMax_1"
        reduce_max_node = om.get_node(reduce_max_name)
        if reduce_max_node is None:
            # Pattern absent in this export — nothing to do for this layer.
            continue
        next_nodes = om.get_next_nodes(reduce_max_node)
        if not next_nodes or next_nodes[0].op_type != "Sub":
            continue
        next_node = next_nodes[0]
        name_prefix = '/'.join(reduce_max_name.split('/')[:-1])
        matmul_name = f"{name_prefix}/identity_MatMul"
        # 1x1 identity weight: the inserted MatMul does not change values.
        matmul_init = om.create_initializer(matmul_name + "_B",
                                            np.eye(1, dtype=np.float32))
        matmul_node = om.create_node(
            "MatMul",
            matmul_name,
            [reduce_max_node.outputs[0], matmul_init.name],
            [f"{matmul_name}_output_0"],
            index=reduce_max_node.index + 1
        )
        next_node.set_input(1, matmul_node.outputs[0])
def main():
    """CLI entry point: load the ONNX model given as argv[1], apply the graph
    rewrites, and save the result to argv[2]."""
    input_onnx_path = sys.argv[1]
    output_onnx_path = sys.argv[2]
    # input_onnx_path = "ground_sim.onnx"
    # output_onnx_path = "ground_sim_0424_2nd.onnx"
    om = ONNXModifier(input_onnx_path)
    optimize_where_ndoes(om)  # 1. replace Where nodes (note: "ndoes" typo kept)
    optimize_transpose_nodes(om)  # 2. optimize Transpose nodes
    optmize_sin_cos_block(om)  # 3. optimize the sin/cos positional encoding
    # om.add_opset_import("com.microsoft", 1)
    # optimize_normal_attention(om)  # 4. fuse MHA in bert / transformer
    # optimize_ms_deform_attn(om)  # 5. fuse multi-scale deformable attention
    # optimize_backbone_attention(om)  # 6. fuse backbone attention
    optimize_bidirect_attention(om)  # 7. optimize bi-directional attention
    om.save(output_onnx_path, save_as_external_data=False)
# NOTE(review): this file contains two concatenated copies of the script; at
# this point the names still resolve to the first copy's definitions, so this
# guard runs the argv-based main() BEFORE the re-definitions below take effect.
if __name__ == "__main__":
    main()
import sys
import numpy as np
from onnx_modifier import ONNXModifier
def change_inf_to_value(om: ONNXModifier):
    """Replace scalar +/-inf constants feeding Where nodes with finite values.

    Scans the X/Y inputs of every Where node; any single-element initializer
    holding +inf or -inf is overwritten with the float16 max/min (stored as
    float32), so a later FP16 conversion cannot overflow to inf/NaN.
    """
    fp16_info = np.finfo(np.float16)
    visited = set()
    for where_node in om.get_nodes("Where"):
        # inputs[0] is the condition; only X and Y may carry the constants.
        for tensor_name in where_node.inputs[1:]:
            init = om.get_initializer(tensor_name)
            if init is None:
                continue
            assert tensor_name == init.name
            if tensor_name in visited:
                continue
            data = om.get_initializer_value(init.name)
            if data.size > 1:
                continue
            if data == np.inf:
                replacement = np.array(fp16_info.max, dtype=np.float32)
            elif data == -np.inf:
                replacement = np.array(fp16_info.min, dtype=np.float32)
            else:
                continue
            om.set_initializer_value(tensor_name, replacement)
            visited.add(tensor_name)
# def optimize_where_ndoes(om: ONNXModifier):
# """Where节点等价替换
# (1) condition为initializer, X为0, Y为输入数据:
# Where(cond, X, Y) ==> Mul(Y, ~cond)
# (2) condition为initializer, X为负无穷, Y为输入数据
# Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
# (3) condition为真实输入, X为负无穷, Y为输入数据
# Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
# cases:
# 1. Where(cond, -inf, input)
# a. /transformer/encoder/fusion_layers.*/attn/Where
# b. /transformer/encoder/fusion_layers.*/attn/Where_1
# c. /class_embed.0_*/Where: Where(cond, -inf, input)
# 2. Where(cond, 0, input):
# a. /transformer/encoder/layers.*/self_attn/Where
# b. /transformer/decoder/layers.*/cross_attn/Where
# """
# for where_node in om.get_nodes("Where"):
# where_name = where_node.name
# # print("Process where node:", where_name)
# x_value = om.get_initializer_value(where_node.inputs[1])
# assert x_value.size == 1
# assert x_value == np.array(0.0, dtype=np.float32) or \
# x_value == np.array(-np.inf, dtype=np.float32)
# cond_init = om.get_initializer(where_node.inputs[0])
# if cond_init is not None:
# cond_value = om.get_initializer_value(where_node.inputs[0])
# if x_value == np.array(0.0, dtype=np.float32):
# # Where(cond, X, Y) ==> Mul(Y, ~cond)
# mul_name = where_name.replace("Where", "NewMul")
# mul_b_init = om.create_initializer(mul_name + "_B",
# (~cond_value).astype(np.float32))
# mul_node = om.create_node("Mul",
# mul_name,
# [where_node.inputs[2], mul_b_init.name],
# [mul_name+"_output_0"],
# index=where_node.index)
# next_nodes = where_node.next_nodes
# for next_node in next_nodes:
# next_node.replace_input(where_node.outputs[0], mul_node.outputs[0])
# elif x_value == np.array(-np.inf, dtype=np.float32):
# # Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
# sub_name = where_name.replace("Where", "NewSub")
# sub_b_init = om.create_initializer(
# sub_name + "_B",
# np.where(cond_value.astype(np.float32),
# np.finfo(np.float16).max, 0.0).astype(np.float32)
# )
# sub_node = om.create_node("Sub",
# sub_name,
# [where_node.inputs[2], sub_b_init.name],
# [sub_name+"_output_0"],
# index=where_node.index)
# next_nodes = where_node.next_nodes
# for next_node in next_nodes:
# next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
# else:
# # Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
# assert x_value == np.array(-np.inf, dtype=np.float32)
# cast_name = where_name.replace("Where", "NewCast")
# mul_name = where_name.replace("Where", "NewMul")
# sub_name = where_name.replace("Where", "NewSub")
# cast_node = om.create_node("Cast",
# cast_name,
# [where_node.inputs[0]],
# [cast_name+"_output_0"],
# to=1,
# index=where_node.index)
# mul_b_init = om.create_initializer(mul_name + "_B",
# np.array([np.finfo(np.float16).max], np.float32))
# mul_node = om.create_node("Mul",
# mul_name,
# [cast_node.outputs[0], mul_b_init.name],
# [mul_name+"_output_0"],
# index=cast_node.index+1)
# sub_node = om.create_node("Sub",
# sub_name,
# [where_node.inputs[2], mul_node.outputs[0]],
# [sub_name+"_output_0"],
# index=mul_node.index+1)
# next_nodes = where_node.next_nodes
# for next_node in next_nodes:
# next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
# om.update_map()
def optimize_where_ndoes(om: ONNXModifier):
    """Equivalently rewrite Where nodes (version with safety checks).

    Handles three attention-mask shapes (note: function name typo "ndoes" is
    kept because callers use it):
      (1) constant condition, X == 0:    Where(cond, 0, Y)    ==> Mul(Y, ~cond)
      (2) constant condition, X == -inf: Where(cond, -inf, Y) ==> Sub(Y, where(cond, fp16.max, 0))
      (3) runtime condition,  X == -inf: Where(cond, -inf, Y) ==> Sub(Y, Mul(Cast(cond), fp16.max))
    -inf is substituted with float16 max so the graph survives FP16 conversion.
    Nodes that do not match any case are left untouched.
    """
    for where_node in om.get_nodes("Where"):
        where_name = where_node.name
        # 1. Safely fetch X; if X is not a constant (initializer), skip.
        x_init = om.get_initializer(where_node.inputs[1])
        if x_init is None:
            continue
        x_value = om.get_initializer_value(where_node.inputs[1])
        # 2. Avoid assertion crashes: a non-scalar X means this is not one of
        #    the attention-mask Where nodes we target — skip it.
        if x_value.size != 1:
            continue
        # 3. Only 0.0 and -inf qualify for the rewrite; otherwise skip.
        is_zero = (x_value == np.array(0.0, dtype=np.float32))
        is_neg_inf = (x_value == np.array(-np.inf, dtype=np.float32))
        if not (is_zero or is_neg_inf):
            continue
        cond_init = om.get_initializer(where_node.inputs[0])
        if cond_init is not None:
            cond_value = om.get_initializer_value(where_node.inputs[0])
            if is_zero:
                # Case (1): Where(cond, X, Y) ==> Mul(Y, ~cond)
                mul_name = where_name.replace("Where", "NewMul")
                mul_b_init = om.create_initializer(mul_name + "_B",
                                                   (~cond_value).astype(np.float32))
                mul_node = om.create_node("Mul",
                                          mul_name,
                                          [where_node.inputs[2], mul_b_init.name],
                                          [mul_name+"_output_0"],
                                          index=where_node.index)
                next_nodes = where_node.next_nodes
                for next_node in next_nodes:
                    next_node.replace_input(where_node.outputs[0], mul_node.outputs[0])
            elif is_neg_inf:
                # Case (2): Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
                # (the inner Where is pre-folded into a constant here).
                sub_name = where_name.replace("Where", "NewSub")
                sub_b_init = om.create_initializer(
                    sub_name + "_B",
                    np.where(cond_value.astype(np.float32),
                             np.finfo(np.float16).max, 0.0).astype(np.float32)
                )
                sub_node = om.create_node("Sub",
                                          sub_name,
                                          [where_node.inputs[2], sub_b_init.name],
                                          [sub_name+"_output_0"],
                                          index=where_node.index)
                next_nodes = where_node.next_nodes
                for next_node in next_nodes:
                    next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
        else:
            # Case (3): Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), fp16.max))
            # With a runtime condition only the -inf case is handled.
            if not is_neg_inf:
                continue
            cast_name = where_name.replace("Where", "NewCast")
            mul_name = where_name.replace("Where", "NewMul")
            sub_name = where_name.replace("Where", "NewSub")
            cast_node = om.create_node("Cast",
                                       cast_name,
                                       [where_node.inputs[0]],
                                       [cast_name+"_output_0"],
                                       to=1,
                                       index=where_node.index)
            mul_b_init = om.create_initializer(mul_name + "_B",
                                               np.array([np.finfo(np.float16).max], np.float32))
            mul_node = om.create_node("Mul",
                                      mul_name,
                                      [cast_node.outputs[0], mul_b_init.name],
                                      [mul_name+"_output_0"],
                                      index=cast_node.index+1)
            sub_node = om.create_node("Sub",
                                      sub_name,
                                      [where_node.inputs[2], mul_node.outputs[0]],
                                      [sub_name+"_output_0"],
                                      index=mul_node.index+1)
            next_nodes = where_node.next_nodes
            for next_node in next_nodes:
                next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
    om.update_map()
def optimize_transpose_nodes(om: ONNXModifier):
    """Bypass Transposes that only swap the two leading axes and pin several
    dynamic Reshape targets to static shapes.

    The leading-axes Transposes are presumably no-ops for the deployed
    batch-1 shapes — each rewrite checks that its target node/initializer
    exists and skips it otherwise, so the pass is safe on other exports.
    """
    # Build the candidate Transpose names: 12 encoder, 3 per decoder layer,
    # /transformer/Transpose_8, then 12 decoder-level ones. The first node of
    # each numbered family has no "_<n>" suffix.
    candidate_names = [
        "/transformer/encoder/Transpose" if i == 0 else f"/transformer/encoder/Transpose_{i}"
        for i in range(12)
    ]
    for layer in range(6):
        candidate_names.extend(
            f"/transformer/decoder/layers.{layer}/Transpose{suffix}"
            for suffix in ("", "_1", "_2")
        )
    candidate_names.append("/transformer/Transpose_8")
    candidate_names.extend(
        "/transformer/decoder/Transpose" if i == 0 else f"/transformer/decoder/Transpose_{i}"
        for i in range(12)
    )
    for name in candidate_names:
        node = om.get_node(name)
        # Safety check: missing node means this model has nothing to optimize here.
        if node is None:
            continue
        if 'perm' in node.attrs and node.attrs['perm'] in ([1, 0, 2], [1, 0, 2, 3]):
            # Bypass the Transpose: feed its input straight to every consumer.
            for consumer in om.get_next_nodes(node):
                consumer.replace_input(node.outputs[0], node.inputs[0])
    # Pin /transformer/encoder/text_layers.*/self_attn/Reshape_4 to (1, 4, 256).
    text_shape_init = om.create_initializer(
        "/transformer/encoder/text_layers.x/self_attn/des_shape",
        np.array([1, 4, 256], np.int64)
    )
    for i in range(6):
        node = om.get_node(f"/transformer/encoder/text_layers.{i}/self_attn/Reshape_4")
        if node is not None:
            node.set_input(1, text_shape_init.name)
    # Fix /transformer/enc_out_class_embed/Transpose for the new layout.
    node = om.get_node("/transformer/enc_out_class_embed/Transpose")
    if node is not None:
        node.set_attribute("perm", [0, 2, 1])
    # Pin /transformer/decoder/Reshape_* — guard against the auto-generated
    # initializer name "_v_5525" not existing in every export.
    if om.get_initializer("_v_5525") is not None:
        om.set_initializer_value("_v_5525", np.array([1, 900, -1], np.int64))
    # Pin decoder self_attn/Reshape_4 and ca_text/Reshape_6 to (1, 900, 256).
    query_shape_init = om.create_initializer(
        "/transformer/decoder/layers.x/self_attn_ca_text/des_shape",
        np.array([1, 900, 256], np.int64)
    )
    for i in range(6):
        for reshape_name in (f"/transformer/decoder/layers.{i}/self_attn/Reshape_4",
                             f"/transformer/decoder/layers.{i}/ca_text/Reshape_6"):
            node = om.get_node(reshape_name)
            if node is not None:
                node.set_input(1, query_shape_init.name)
    # Pre-transpose the constant consumed by /transformer/decoder/layers.0/Add
    # and Add_1 so it matches the layout after the Transpose removals above.
    tile_name = "/transformer/Tile_1_output_0"
    if om.get_initializer(tile_name) is not None:
        tile_value = om.get_initializer_value(tile_name)
        om.set_initializer_value(tile_name, np.ascontiguousarray(tile_value.transpose(1, 0, 2)))
    om.update_map()
    # Shape inference may fail on graphs with custom ops (e.g. MSDeformAttn);
    # warn instead of crashing.
    try:
        om.infer_shape(strict_mode=False)
    except Exception as e:
        print(f"[Warning] infer_shape 跳过 (可能由于自定义算子引起). 详细信息: {e}")
def optmize_sin_cos_block(om: ONNXModifier):
    """Rewrite the decoder's six sin/cos positional-encoding blocks that feed
    the ref_point_head MatMuls into a shorter Unsqueeze/Div/Slice/Concat form.

    Each block is pattern-matched starting from a Gather node; any mismatch
    or exception skips that block. (Note: the name typo "optmize" is kept
    because callers use it.)
    """
    node_pairs = [
        ("/transformer/decoder/Gather_1", "/transformer/decoder/ref_point_head/layers.0/MatMul"),
        ("/transformer/decoder/Gather_6", "/transformer/decoder/ref_point_head/layers.0_1/MatMul"),
        ("/transformer/decoder/Gather_11", "/transformer/decoder/ref_point_head/layers.0_2/MatMul"),
        ("/transformer/decoder/Gather_16", "/transformer/decoder/ref_point_head/layers.0_3/MatMul"),
        ("/transformer/decoder/Gather_21", "/transformer/decoder/ref_point_head/layers.0_4/MatMul"),
        ("/transformer/decoder/Gather_26", "/transformer/decoder/ref_point_head/layers.0_5/MatMul"),
    ]
    # Shared initializers reused by every rewritten block.
    unsqueeze_axes_init1 = om.create_initializer("/transformer/decoder/sin_cos_block/unsqueeze_axes1", np.array([3, 4], np.int64))
    slice_axes_init = om.create_initializer("/transformer/decoder/sin_cos_block/slice_axes", np.array([4], np.int64))
    slice_steps_init = om.create_initializer("/transformer/decoder/sin_cos_block/slice_steps", np.array([1], np.int64))
    slice_starts_init1 = om.create_initializer("/transformer/decoder/sin_cos_block/slice_starts1", np.array([0], np.int64))
    slice_ends_init1 = om.create_initializer("/transformer/decoder/sin_cos_block/slice_ends1", np.array([1], np.int64))
    # NOTE(review): this initializer is named ".../slice_steps2" but holds the
    # STARTS of the second Slice — misleading name, kept for compatibility.
    slice_starts_init2 = om.create_initializer("/transformer/decoder/sin_cos_block/slice_steps2", np.array([1], np.int64))
    slice_ends_init2 = om.create_initializer("/transformer/decoder/sin_cos_block/slice_ends2", np.array([2], np.int64))
    reshape_init = om.create_initializer("/transformer/decoder/sin_cos_block/reshape_dst_shape", np.array([1, 900, -1], np.int64))
    for i, (gather_name, matmul_name) in enumerate(node_pairs):
        gather_node = om.get_node(gather_name)
        matmul_node = om.get_node(matmul_name)
        # Safety check: if this pair is absent, the block cannot (or need not)
        # be optimized — skip it.
        if gather_node is None or matmul_node is None:
            continue
        try:
            # Expected chain: Gather -> Mul(scalar) -> Unsqueeze -> Div -> 2x Slice.
            next_node = om.get_next_nodes(gather_node)[0]
            if next_node.op_type != "Mul": continue
            mul_init_value = om.get_initializer_value(next_node.inputs[1])
            if mul_init_value.size != 1: continue
            next_node = om.get_next_nodes(next_node)[0]
            if next_node.op_type != "Unsqueeze": continue
            # Re-source the Unsqueeze from the Gather's input (bypassing the
            # Mul) and expand two trailing axes at once.
            next_node.set_inputs([gather_node.inputs[0], unsqueeze_axes_init1.name])
            next_node = om.get_next_nodes(next_node)[0]
            if next_node.op_type != "Div": continue
            # Fold the bypassed scalar Mul into the Div constant and shape it
            # for the broadcast against the unsqueezed tensor.
            div_init_value = om.get_initializer_value(next_node.inputs[1])
            new_value = (div_init_value / mul_init_value).reshape(1, 1, 1, 64, 2)
            new_init = om.create_initializer(next_node.name + "_B", new_value)
            next_node.set_input(1, new_init.name)
            next_nodes = om.get_next_nodes(next_node)
            if len(next_nodes) != 2 or not all(x.op_type == 'Slice' for x in next_nodes): continue
            sin_node, cos_node = None, None
            for j, slice_node in enumerate(next_nodes):
                # Re-point both Slices to the new last axis: [0:1] and [1:2].
                slice_node.set_inputs([slice_node.inputs[0],
                                       slice_starts_init1.name if j == 0 else slice_starts_init2.name,
                                       slice_ends_init1.name if j == 0 else slice_ends_init2.name,
                                       slice_axes_init.name,
                                       slice_steps_init.name])
                n_node = om.get_next_nodes(slice_node)[0]
                if n_node.op_type == "Sin":
                    sin_node = n_node
                elif n_node.op_type == "Cos":
                    cos_node = n_node
                else:
                    raise RuntimeError("match fail!")
                # Step past the old Unsqueeze to the shared Concat.
                n_node = om.get_next_nodes(n_node)[0]
                n_node = om.get_next_nodes(n_node)[0]
            next_node = n_node # Concat node
            if next_node.op_type != "Concat": continue
            # Feed Sin/Cos outputs directly into the Concat along axis 4.
            next_node.set_inputs([sin_node.outputs[0], cos_node.outputs[0]])
            next_node.set_attribute("axis", 4)
            next_node = om.get_next_nodes(next_node)[0]
            if next_node.op_type != "Reshape": continue
            next_node.set_input(1, reshape_init.name)
            matmul_node.set_input(0, next_node.outputs[0])
            if i == 0:
                # The first block's MatMul weights must be row-permuted
                # (swap rows 0..127 with 128..255) to match the new
                # sin/cos-interleaved feature order.
                mm_b_value = om.get_initializer_value(matmul_node.inputs[1])
                mm_b_value = np.concatenate([mm_b_value[128:256, ...],
                                             mm_b_value[0:128, ...],
                                             mm_b_value[256:, ...]],
                                            axis=0)
                om.set_initializer_value(matmul_node.inputs[1], mm_b_value)
        except Exception as e:
            # Silently skip this block on any unexpected shape/graph mismatch.
            # NOTE(review): nodes rewritten before the failure are left
            # modified — confirm partial edits are acceptable on this path.
            continue
    om.update_map()
    # Shape inference may fail on graphs with custom ops; ignore failures.
    try:
        om.infer_shape(strict_mode=False)
    except:
        pass
def fuse_one_attention(om: ONNXModifier, softmax_name: str, new_mask: "str | None" = None, num_heads: int = 12):
    """Fuse one exported attention subgraph into a com.microsoft MultiHeadAttention node.

    Walks the graph around the given Softmax to locate the Q/K/V source
    tensors, replaces the MatMul -> [Div] -> [Add mask] -> Softmax ->
    MatMul -> Transpose -> Reshape pattern with a single MultiHeadAttention
    op, and rewires the downstream projection (Gemm or MatMul).

    Args:
        om: graph modifier wrapping the ONNX model; mutated in place.
        softmax_name: name of the attention Softmax node to fuse around.
        new_mask: tensor name to use as the MHA mask input when the matched
            pattern carries an additive mask; required in that case.
        num_heads: value stored in the MHA ``num_heads`` attribute.
    """
    softmax_node = om.get_node(softmax_name)
    # The node feeding Softmax is either the raw QK MatMul or an Add that
    # applies the additive attention mask.
    tmp_node = om.get_prev_nodes(softmax_node)[0]
    assert tmp_node.op_type in ["MatMul", "Add"]
    mask = None
    if tmp_node.op_type == "Add":
        mask_node = tmp_node
        tmp_node = om.get_from_node(mask_node.inputs[0])
        if tmp_node.op_type == "Div":  # optional scaling Div between MatMul and Add
            tmp_node = om.get_from_node(tmp_node.inputs[0])
        assert tmp_node.op_type == "MatMul"
        mask = mask_node.inputs[1]
        # A replacement mask tensor must be supplied for masked patterns.
        assert new_mask is not None
    # Trace Q (MatMul input 0, possibly behind a scaling Mul) and K (input 1)
    # back through Transpose <- Reshape to the pre-reshape hidden states.
    tmp_node1 = om.get_from_node(tmp_node.inputs[0])
    if tmp_node1.op_type == "Mul":
        tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
    tmp_node2 = om.get_from_node(tmp_node.inputs[1])
    assert tmp_node1.op_type == tmp_node2.op_type == "Transpose"
    tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
    tmp_node2 = om.get_prev_nodes(tmp_node2)[0]
    assert tmp_node1.op_type == tmp_node2.op_type == "Reshape"
    q, k = tmp_node1.inputs[0], tmp_node2.inputs[0]
    # V is the second input of the MatMul after Softmax: either traced through
    # Transpose <- Reshape, or a constant initializer which is re-laid-out
    # from (H, S, D) to (1, S, H*D) — presumably the layout MultiHeadAttention
    # expects for its value input (TODO confirm against the contrib-op spec).
    tmp_node = om.get_next_nodes(softmax_node)[0]
    assert tmp_node.op_type == "MatMul"
    tmp_node3 = om.get_from_node(tmp_node.inputs[1])
    if tmp_node3 is not None:
        assert tmp_node3.op_type == "Transpose"
        tmp_node3 = om.get_prev_nodes(tmp_node3)[0]
        assert tmp_node3.op_type == "Reshape"
        v = tmp_node3.inputs[0]
    else:
        v_init = om.get_initializer(tmp_node.inputs[1])
        v_init_value = om.get_initializer_value(tmp_node.inputs[1])
        v_init_value = v_init_value[None, ...].transpose(0, 2, 1, 3)
        B, S, H, D = v_init_value.shape
        v_init_value = np.ascontiguousarray(v_init_value.reshape(B, S, H*D))
        om.set_initializer_value(tmp_node.inputs[1], v_init_value)
        v = v_init.name
    # Skip past the output Transpose -> Reshape to reach the consumer of the
    # attention output (the projection: Gemm or MatMul).
    tmp_node = om.get_next_nodes(tmp_node)[0]
    assert tmp_node.op_type == "Transpose"
    tmp_node = om.get_next_nodes(tmp_node)[0]
    assert tmp_node.op_type == "Reshape"
    mha_next_node = om.get_next_nodes(tmp_node)[0]
    if mha_next_node.op_type == "Gemm":
        # Gemm path has a trailing Reshape -> Add; remember both so the Gemm
        # can be rewritten as MatMul+Add below and the pair bypassed.
        gemm_next_node = om.get_next_nodes(mha_next_node)[0]
        assert gemm_next_node.op_type == "Reshape"
        reshape_next_node = om.get_next_nodes(gemm_next_node)[0]
        assert reshape_next_node.op_type == "Add"
    else:
        assert mha_next_node.op_type == "MatMul"
    name_prefix = '/'.join(softmax_name.split('/')[:-1])
    mha_name = f"{name_prefix}/MultiHeadAttention"
    mha_node = om.create_node("MultiHeadAttention",
                              mha_name,
                              [q, k, v] if mask is None else [q, k, v, new_mask],
                              [mha_name+'_output_0'],
                              num_heads=num_heads,
                              domain="com.microsoft",
                              index=mha_next_node.index-1)
    mha_next_node.replace_input(mha_next_node.inputs[0], mha_node.outputs[0])
    if mha_next_node.op_type == "Gemm":
        # Replace the Gemm with an equivalent MatMul (weights transposed since
        # transB == 1) plus a bias Add — NOTE(review): presumably because the
        # fused MHA output rank no longer matches what Gemm accepted; confirm.
        weights = om.get_initializer_value(mha_next_node.inputs[1])
        transB = mha_next_node.attrs["transB"]
        assert transB == 1
        weights = np.ascontiguousarray(weights.transpose(1, 0))
        om.set_initializer_value(mha_next_node.inputs[1], weights)
        new_matmul_name = mha_next_node.name.replace("Gemm", "MatMul(Gemm)")
        new_matmul_node = om.create_node("MatMul",
                                         new_matmul_name,
                                         [mha_node.outputs[0], mha_next_node.inputs[1]],
                                         [new_matmul_name + "_output_0"],
                                         index=mha_next_node.index)
        new_bias_name = mha_next_node.name.replace("Gemm", "Add(Gemm)")
        new_add_node = om.create_node("Add",
                                      new_bias_name,
                                      [new_matmul_node.outputs[0], mha_next_node.inputs[2]],
                                      [new_bias_name + "_output_0"],
                                      index=new_matmul_node.index+1)
        reshape_next_node.replace_input(gemm_next_node.outputs[0], new_add_node.outputs[0])
def optimize_normal_attention(om: ONNXModifier):
    """Fuse the decoder self-attention and text cross-attention blocks into
    com.microsoft MultiHeadAttention nodes.

    Builds a derived mask tensor from the model's ``attention_mask`` input
    (Cast to float32, ReduceSum over axis 1) and passes it to
    fuse_one_attention for the six decoder layers (8 heads each).
    """
    def create_new_attention_mask():
        # Insert Cast(to=1, i.e. float32) + ReduceSum(axes=1, keepdims=0)
        # right before the first consumer of `attention_mask`; presumably the
        # sum is the number of valid tokens per sequence — confirm the mask
        # is 0/1-valued at this point.
        mask_next_node = om.get_to_nodes("attention_mask")[0]
        cast_node = om.create_node("Cast",
                                   "Cast_for_attention_mask",
                                   ["attention_mask"],
                                   ["Cast_for_attention_mask_output_0"],
                                   to=1,
                                   index=mask_next_node.index)
        reducesum_node = om.create_node("ReduceSum",
                                        "ReduceSum_for_mask",
                                        [cast_node.outputs[0]],
                                        ["ReduceSum_for_mask_output_0"],
                                        axes=1,
                                        keepdims=0,
                                        index=cast_node.index+1)
        return reducesum_node.outputs[0]
    # bert (kept disabled)
    # for i in range(12):
    #     fuse_one_attention(om, f"/bert/encoder/layer.{i}/attention/self/Softmax", "text_token_mask", num_heads=12)
    new_mask = create_new_attention_mask()
    for i in range(6):
        # /transformer/encoder (kept disabled)
        # fuse_one_attention(om, f"/transformer/encoder/text_layers.{i}/self_attn/Softmax", "text_token_mask", num_heads=4)
        # /transformer/decoder
        fuse_one_attention(om, f"/transformer/decoder/layers.{i}/self_attn/Softmax", new_mask, num_heads=8)
        fuse_one_attention(om, f"/transformer/decoder/layers.{i}/ca_text/Softmax", new_mask, num_heads=8)
    om.update_map()
def optimize_backbone_attention(om: ONNXModifier):
    """Fuse the backbone's window-attention subgraphs into com.microsoft
    MultiHeadAttention nodes (4 stages; stage 2 has 18 blocks, others 2).
    """
    def get_original_mask(mask_name, name_prefix):
        # Turn the additive mask constant into a boolean mask initializer:
        # True where the additive value is 0 (i.e. attention allowed).
        mask_value = om.get_initializer_value(mask_name)
        orig_mask = np.where(mask_value==0, 1, 0).astype(np.bool_)
        orig_mask_init = om.create_initializer(f"{name_prefix}/mask", orig_mask)
        return orig_mask_init.name
    def _fuse_one_attention(softmax_name: str):
        # Match the pattern around one window-attention Softmax and replace it
        # with a single MultiHeadAttention node.
        name_prefix = '/'.join(softmax_name.split('/')[:-1])
        softmax_node = om.get_node(softmax_name)
        tmp_node = om.get_prev_nodes(softmax_node)[0]
        pos_bias_init = None
        if tmp_node.op_type == "Reshape":
            # Variant with an extra Add of a constant bias behind a Reshape
            # pair. NOTE(review): the variable names assume the inner Add's
            # inputs[1] is the relative-position bias and the outer Add (below)
            # carries the window mask — confirm against the exported graph.
            tmp_node = om.get_prev_nodes(tmp_node)[0]
            assert tmp_node.op_type == "Add"
            pos_bias_init = om.get_initializer(tmp_node.inputs[1])
            tmp_node = om.get_prev_nodes(tmp_node)[0]
            assert tmp_node.op_type == "Reshape"
            tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Add"
        mask = get_original_mask(tmp_node.inputs[1], name_prefix)
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "MatMul"
        qk_matmul = tmp_node
        # Q path: Mul(scale) <- Gather <- Transpose <- Reshape; keep the
        # Reshape so its target shape can be rewritten below.
        tmp_node = om.get_from_node(qk_matmul.inputs[0])
        assert tmp_node.op_type == "Mul"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Gather"
        q_gather_node = tmp_node
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Reshape"
        reshape_node = tmp_node
        # K path: Transpose <- Gather.
        tmp_node = om.get_from_node(qk_matmul.inputs[1])
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Gather"
        k_gather_node = tmp_node
        # V path and subgraph output: Softmax -> MatMul -> Transpose -> Reshape.
        tmp_node = om.get_next_nodes(softmax_node)[0]
        assert tmp_node.op_type == "MatMul"
        v_gather_node = om.get_from_node(tmp_node.inputs[1])
        assert v_gather_node.op_type == "Gather"
        tmp_node = om.get_next_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_next_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Reshape"
        mha_out = tmp_node.outputs[0]
        # Collapse the (b, s, _, h, d) QKV reshape to (b, s, _, h*d) —
        # `_` is presumably 3 for Q/K/V (confirm) — and make the three
        # Gathers select along the new axis 2.
        old_dst_shape = om.get_initializer_value(reshape_node.inputs[1])
        b, s, _, h, d = old_dst_shape
        new_dst_shape = [b, s, _, h*d]
        new_dst_shape_init = om.create_initializer(f"{name_prefix}/qkv_hidden_states_shape",
                                                   np.array(new_dst_shape, np.int64))
        reshape_node.set_input(1, new_dst_shape_init.name)
        for node in [q_gather_node, k_gather_node, v_gather_node]:
            node.set_input(0, reshape_node.outputs[0])
            node.set_attribute("axis", 2)
        mha_name = f"{name_prefix}/MultiHeadAttention"
        inputs = [q_gather_node.outputs[0],
                  k_gather_node.outputs[0],
                  v_gather_node.outputs[0],
                  mask]
        if pos_bias_init is not None:
            # Optional 5th MHA input: additive attention bias.
            inputs.append(pos_bias_init.name)
        mha_node = om.create_node("MultiHeadAttention",
                                  mha_name,
                                  inputs,
                                  [mha_name+'_output_0'],
                                  num_heads=h,
                                  domain="com.microsoft",
                                  index=softmax_node.index)
        mha_next_node = om.get_to_nodes(mha_out)[0]
        mha_next_node.replace_input(mha_out, mha_node.outputs[0])
    num_layers = 4
    for l in range(num_layers):
        num_blocks = 18 if l == 2 else 2
        for b in range(num_blocks):
            _fuse_one_attention(f"/backbone/backbone.0/layers.{l}/blocks.{b}/attn/softmax/Softmax")
def optimize_bidirect_attention(om: ONNXModifier):
    """Insert a 1x1 identity MatMul behind each fusion-layer ReduceMax.

    For every bi-directional fusion layer, the output of ``ReduceMax_1``
    feeds a ``Sub`` node.  A MatMul against a 1x1 identity matrix is
    spliced between the two, so downstream tooling sees a MatMul
    producer instead of the raw ReduceMax.  Layers where the expected
    ReduceMax -> Sub pattern is not found are skipped silently.
    """
    for layer_idx in range(6):
        reduce_max_name = f"/transformer/encoder/fusion_layers.{layer_idx}/attn/ReduceMax_1"
        reduce_max_node = om.get_node(reduce_max_name)
        # Safety check: the layer may be absent from this graph.
        if reduce_max_node is None:
            continue
        consumers = om.get_next_nodes(reduce_max_node)
        if not consumers or consumers[0].op_type != "Sub":
            continue
        sub_node = consumers[0]
        name_prefix = reduce_max_name.rsplit('/', 1)[0]
        matmul_name = f"{name_prefix}/identity_MatMul"
        # 1x1 identity: multiplying by it is a numeric no-op.
        identity_init = om.create_initializer(matmul_name + "_B",
                                              np.eye(1, dtype=np.float32))
        identity_node = om.create_node(
            "MatMul",
            matmul_name,
            [reduce_max_node.outputs[0], identity_init.name],
            [f"{matmul_name}_output_0"],
            index=reduce_max_node.index + 1,
        )
        sub_node.set_input(1, identity_node.outputs[0])
# def main():
# input_onnx_path = sys.argv[1]
# output_onnx_path = sys.argv[2]
# # input_onnx_path = "ground_sim.onnx"
# # output_onnx_path = "ground_sim_0424_2nd.onnx"
# om = ONNXModifier(input_onnx_path)
# optimize_where_ndoes(om) # 1. 替换where节点
# optimize_transpose_nodes(om) # 2. 优化transpose节点
# optmize_sin_cos_block(om) # 3. 优化位置编码
# # om.add_opset_import("com.microsoft", 1)
# # optimize_normal_attention(om) # 4. 融合bert、transformer中的mha
# # optimize_ms_deform_attn(om) # 5. 融合多尺度可变形注意力
# # optimize_backbone_attention(om) # 6. 融合backbone中的注意力
# optimize_bidirect_attention(om) # 7. 优化双向注意力
# om.save(output_onnx_path, save_as_external_data=False)
def main():
    """Load the simplified model, apply the graph optimizations that are
    safe for FP16/MIGraphX deployment, and write the optimized ONNX out.
    """
    # Source: the onnx-simplifier output; destination: the optimized graph.
    input_onnx_path = "../weights/ground_deform_sim.onnx"
    output_onnx_path = "../weights_opt/ground_deform_opt.onnx"
    print(f"Loading ONNX model from {input_onnx_path}...")
    modifier = ONNXModifier(input_onnx_path)
    print("1. Optimizing Where nodes (Crucial for FP16 & MIGraphX)...")
    optimize_where_ndoes(modifier)
    print("2. Optimizing Transpose nodes...")
    optimize_transpose_nodes(modifier)
    # Passes 3 (sin/cos positional encoding) and 4 (bidirectional
    # attention) are intentionally disabled for this export path.
    print(f"Saving optimized model to {output_onnx_path}...")
    modifier.save(output_onnx_path, save_as_external_data=False)
    print("Optimization Done!")
# Script entry point: run the optimization pipeline.
if __name__ == "__main__":
    main()
"""
onnx modifier: provides a convenient way to modify an ONNX model
1. add node
2. remove node
3. modify node
4. query node
"""
from collections import defaultdict, deque
import os
import os.path as osp
import shutil
import tempfile
from typing import List, Dict, Set, Tuple, Optional, Union
import uuid
import warnings
import numpy as np
import onnx
from onnx import AttributeProto, numpy_helper
from onnx import shape_inference
from onnx.helper import make_attribute, make_node, make_opsetid, make_tensor, \
tensor_dtype_to_np_dtype
from onnxconverter_common import float16
import tqdm
# from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
# Tensor element type names accepted by create_value_info()/add_input(),
# listed in both upper- and lower-case spellings (the upper-case forms
# mirror the onnx.TensorProto enum names).
SUPPORT_DTYPES = [
    'BOOL', 'STRING', 'BFLOAT16', 'DOUBLE', 'FLOAT', 'FLOAT16',
    'INT16', 'INT32', 'INT4', 'INT64', 'INT8', 'UINT16', 'UINT32', 'UINT4', 'UINT64', 'UINT8',
]
SUPPORT_DTYPES += [dtype_name.lower() for dtype_name in SUPPORT_DTYPES]
class Node:
    """Thin wrapper around an ``onnx.NodeProto``.

    Carries a back-reference to the owning ONNXModifier (for graph
    queries) plus the node's positional index inside ``graph.node``, and
    keeps the modifier's connection_map in sync when inputs/outputs are
    renamed through set_input/set_output.
    """
    def __init__(self, onnx_modifier=None, obj=None, index=None):
        # onnx_modifier: owning ONNXModifier (None for a detached node)
        # obj: the wrapped onnx NodeProto
        # index: position of ``obj`` within graph.node
        self.onnx_modifier = onnx_modifier
        self.obj = obj
        self.index = index
    @property
    def name(self):
        """Node name as stored in the proto."""
        return self.obj.name
    @property
    def op_type(self):
        """Operator type, e.g. "MatMul"."""
        return self.obj.op_type
    @property
    def inputs(self):
        """Repeated field of input edge names (mutable proto view)."""
        return self.obj.input
    @property
    def outputs(self):
        """Repeated field of output edge names (mutable proto view)."""
        return self.obj.output
    @property
    def input_names(self):
        """Alias of :attr:`inputs`."""
        return self.inputs
    @property
    def output_names(self):
        """Alias of :attr:`outputs`."""
        return self.outputs
    def check_modifier(self):
        """Raise if this node is not attached to an ONNXModifier."""
        if self.onnx_modifier is None:
            raise RuntimeError("onnx_modifier is not initialized")
    @property
    def prev_nodes(self):
        """Producer nodes of this node's inputs."""
        self.check_modifier()
        return self.onnx_modifier.get_prev_nodes(self)
    @property
    def next_nodes(self):
        """Consumer nodes of this node's outputs."""
        self.check_modifier()
        return self.onnx_modifier.get_next_nodes(self)
    def replace_input(self, old_name, new_name):
        """Replace every occurrence of ``old_name`` in this node's inputs."""
        assert old_name in self.obj.input, \
            f'"{old_name}" not in input name list of node named "{self.name}"'
        for i, in_name in enumerate(self.obj.input):
            if in_name == old_name:
                self.set_input(i, new_name)
    def set_input(self, index, name):
        """Rename input ``index`` to ``name`` and update connection_map.

        Works through ``graph.node[self.index]`` rather than ``self.obj``
        — presumably so the rename hits the live proto even if
        ``self.obj`` went stale after graph edits; confirm.
        """
        # assert index < len(self.obj.input), "index out of range"
        # orig_name = self.obj.input[index]
        # self.obj.input[index] = name
        assert index < len(self.onnx_modifier.graph.node[self.index].input), "index out of range"
        orig_name = self.onnx_modifier.graph.node[self.index].input[index]
        self.onnx_modifier.graph.node[self.index].input[index] = name
        self.check_modifier()
        # Can not execute connection.pop_to_node() method directly.
        # When node inputs contain multiple orig_name, need to remain the node in to_nodes.
        if list(self.onnx_modifier.graph.node[self.index].input).count(orig_name) == 0:
            self.onnx_modifier.connection_map[orig_name].pop_to_node(self)
        if name not in self.onnx_modifier.connection_map:
            self.onnx_modifier.connection_map[name] = Connection(name, self.onnx_modifier)
        self.onnx_modifier.connection_map[name].add_to_node(self)
    def set_inputs(self, names):
        """Replace the whole input name list (same length required).

        NOTE(review): unlike set_input, this does not update
        connection_map — callers must refresh it themselves.
        """
        assert len(names) == len(self.obj.input), "number of inputs does not match"
        assert all(isinstance(name, str) for name in names), "input names must be strings"
        self.obj.input[:] = names
    def set_output(self, index, name):
        """Rename output ``index`` to ``name`` and update connection_map."""
        assert index < len(self.obj.output), "index out of range"
        orig_name = self.obj.output[index]
        self.obj.output[index] = name
        self.check_modifier()
        self.onnx_modifier.connection_map[orig_name].clear_from_node()
        if name not in self.onnx_modifier.connection_map:
            self.onnx_modifier.connection_map[name] = Connection(name, self.onnx_modifier)
        self.onnx_modifier.connection_map[name].set_from_node(self)
    def set_outputs(self, names):
        """Replace the whole output name list (same length required).

        NOTE(review): does not update connection_map, mirroring set_inputs.
        """
        assert len(names) == len(self.obj.output), "number of outputs does not match"
        assert all(isinstance(name, str) for name in names), "output names must be strings"
        self.obj.output[:] = names
    @property
    def attrs(self):
        """Node attributes decoded into plain Python values."""
        attrs = {}
        for attr in self.obj.attribute:
            if attr.type == AttributeProto.FLOAT: # 1
                value = attr.f
            elif attr.type == AttributeProto.INT: # 2
                value = attr.i
            elif attr.type == AttributeProto.STRING: # 3
                value = attr.s.decode('utf-8')
            elif attr.type == AttributeProto.TENSOR: # 4
                value = numpy_helper.to_array(attr.t)
            elif attr.type == AttributeProto.FLOATS: # 6
                value = list(attr.floats)
            elif attr.type == AttributeProto.INTS: # 7
                value = list(attr.ints)
            else:
                # Unhandled kinds (e.g. GRAPH) surface as a marker string.
                value = f"Unsupported type: {attr.type}"
            attrs[attr.name] = value
        return attrs
    def set_attribute(self, name, value, name2attr=None):
        """Set one attribute, inferring the proto type from ``value``.

        ``name2attr`` is an optional prebuilt name -> AttributeProto map
        so batch callers (set_attributes) avoid rescanning per attribute.
        """
        if not name2attr:
            name2attr = {}
            for attr in self.obj.attribute:
                name2attr[attr.name] = attr
        if name in name2attr:
            # Overwrite the existing attribute in place.
            if isinstance(value, float):
                name2attr[name].f = value
                name2attr[name].type = AttributeProto.FLOAT
            elif isinstance(value, int):
                name2attr[name].i = value
                name2attr[name].type = AttributeProto.INT
            elif isinstance(value, str):
                name2attr[name].s = value.encode('utf-8')
                name2attr[name].type = AttributeProto.STRING
            elif isinstance(value, np.ndarray):
                name2attr[name].ClearField("t")
                name2attr[name].t.CopyFrom(numpy_helper.from_array(value))
            elif isinstance(value, list):
                # Lists must be homogeneous floats or ints.
                is_all_float = all(isinstance(x, float) for x in value)
                is_all_int = all(isinstance(x, int) for x in value)
                assert is_all_float or is_all_int
                if is_all_float:
                    name2attr[name].ClearField("floats")
                    name2attr[name].floats.extend(value)
                    name2attr[name].type = AttributeProto.FLOATS
                else:
                    name2attr[name].ClearField("ints")
                    name2attr[name].ints.extend(value)
                    name2attr[name].type = AttributeProto.INTS
        else:
            # New attribute: make_attribute infers the proto type.
            if isinstance(value, np.ndarray):
                value = numpy_helper.from_array(value)
            self.obj.attribute.append(make_attribute(name, value))
    def set_attributes(self, attr_dict):
        """Set several attributes, sharing one name -> attribute scan."""
        name2attr = {}
        for attr in self.obj.attribute:
            name2attr[attr.name] = attr
        for name, value in attr_dict.items():
            self.set_attribute(name, value, name2attr)
class Connection:
    """An edge (tensor name) in the graph: one producer, many consumers.

    Bug fix applied to set_from_node/add_to_node/pop_to_node: the string
    branches resolved ``get_node(Node)`` — the class object — instead of
    ``get_node(node)``, and then asserted on the wrong variable, so
    passing a node *name* never worked.
    """
    def __init__(self, conn_name, onnx_modifier=None):
        self.name = conn_name
        self.onnx_modifier = onnx_modifier
        self.from_node = None        # producer Node (None for graph inputs / initializers)
        self.to_nodes = []           # consumer Nodes
        self.to_node_names = set()   # fast membership mirror of to_nodes
    def check_modifier(self):
        """Raise if this connection is not attached to an ONNXModifier."""
        if self.onnx_modifier is None:
            raise RuntimeError("onnx_modifier is not initialized")
    def set_from_node(self, node: str | Node):
        """Set the producer of this edge (accepts a Node or a node name)."""
        if isinstance(node, str):
            self.check_modifier()
            _node = self.onnx_modifier.get_node(node)
            assert _node is not None, f'No node named "{node}" in onnx graph!'
        elif isinstance(node, Node):
            _node = node
        else:
            raise TypeError(f"Connection.set_from_node expects input argument type"
                            f" is str or Node, but received: {type(node)}")
        self.from_node = _node
    def clear_from_node(self):
        """Detach the producer of this edge."""
        self.from_node = None
    def add_to_node(self, node: str | Node):
        """Register a consumer of this edge (idempotent per node name)."""
        if isinstance(node, str):
            _name = node
            self.check_modifier()
            _node = self.onnx_modifier.get_node(node)
            assert _node is not None, f'No node named "{node}" in onnx graph!'
        elif isinstance(node, Node):
            _name = node.name
            _node = node
        else:
            raise TypeError(f"Connection.add_to_node expects input argument type"
                            f" is str or Node, but received: {type(node)}")
        if _name not in self.to_node_names:
            self.to_node_names.add(_name)
            self.to_nodes.append(_node)
    def pop_to_node(self, node: str | Node):
        """Remove a consumer of this edge and return it."""
        if isinstance(node, str):
            _name = node
            self.check_modifier()
            _node = self.onnx_modifier.get_node(node)
            assert _node is not None, f'No node named "{node}" in onnx graph!'
        elif isinstance(node, Node):
            _name = node.name
            _node = node
        else:
            raise TypeError(f"Connection.pop_to_node expects input argument type"
                            f" is str or Node, but received: {type(node)}")
        if _name not in self.to_node_names:
            raise ValueError(f'Node "{_name}" not in target nodes of connection "{self.name}"!')
        self.to_node_names.remove(_name)
        for i in range(len(self.to_nodes)):
            if self.to_nodes[i].name == _name:
                return self.to_nodes.pop(i)
        else:
            # to_node_names said the node exists, but to_nodes disagrees.
            raise RuntimeError("to_nodes mismatch to_node_names!")
class ONNXModifier:
    """In-place editor for an ONNX model.

    Wraps an ``onnx.ModelProto`` and maintains name-indexed lookup maps
    for nodes, initializers, edges (connections) and value_info so the
    graph can be queried, rewired, extended and pruned by name.
    """
    def __init__(self, onnx_path):
        self.onnx_path = onnx_path
        self.node_map = {}                # node name -> Node
        self.initializer_map = {}         # initializer name -> TensorProto
        self.sparse_initializer_map = {}  # name -> [SparseTensorProto, index]
        self.connection_map = {}          # edge name -> Connection
        self.value_info_map = {}          # edge name -> ValueInfoProto
        self.parse_onnx(self.onnx_path)
    def parse_onnx(self, onnx_path):
        """Load the model from disk and rebuild all lookup maps."""
        model = onnx.load(onnx_path)
        self.model = model
        self.domain = model.domain
        self.graph = model.graph
        self.ir_version = model.ir_version
        # Expose the correctly spelled attribute, keeping the historical
        # misspelling for backward compatibility with existing callers.
        self.model_version = model.model_version
        self.mdoel_version = model.model_version
        self.opset_import = model.opset_import
        self.update_map()
    def add_node_name_if_nameless(self, node: Node):
        """Give a nameless node a unique "<op_type>_<8-hex>" name.

        Bug fix: the original tested ``hasattr(self, "node_suffixes")``
        but stored ``self.name_suffixes`` and never recorded a chosen
        suffix, so the uniqueness bookkeeping was a no-op.
        """
        if not hasattr(self, "name_suffixes"):
            self.name_suffixes = set()
        if not node.name:
            while True:
                suffix = uuid.uuid4().hex[:8]
                if suffix not in self.name_suffixes:
                    break
            self.name_suffixes.add(suffix)
            node.obj.name = node.op_type + "_" + suffix
    def add_opset_import(self, domain: str, version: int):
        """Append an opset import (e.g. "com.microsoft") to the model."""
        self.model.opset_import.append(make_opsetid(domain, version))
    @property
    def input_names(self):
        """Names of the graph inputs."""
        return [i.name for i in self.graph.input]
    @property
    def output_names(self):
        """Names of the graph outputs."""
        return [o.name for o in self.graph.output]
    def add_input(self, name, dtype='float', shape=None):
        """Register ``name`` as a new graph input.

        Bug fix: the default dtype was ``'float32'``, which is neither in
        SUPPORT_DTYPES nor an ``onnx.TensorProto`` enum name, so calling
        this without an explicit dtype always failed the assert; ONNX
        spells 32-bit float ``FLOAT``.
        """
        assert dtype in set(SUPPORT_DTYPES)
        self.create_value_info(name, dtype=dtype, shape=shape)
        # create_value_info appended the proto to graph.value_info; move
        # that same object into graph.input instead.
        new_input = self.value_info_map.pop(name)
        _new_input = self.graph.value_info.pop()
        assert id(new_input) == id(_new_input)
        assert name == new_input.name
        self.graph.input.append(new_input)
        return new_input
    def add_output(self, name, new_name=None, shape=None):
        """Promote the intermediate tensor ``name`` to a graph output.

        If ``new_name`` is given, the edge is first renamed on its
        producer and all consumers before being exported.
        """
        if name not in self.value_info_map:
            raise ValueError(f"{name} not in onnx_modifier.value_info_map")
        index = None
        for i, v in enumerate(self.graph.value_info):
            if v.name == name:
                index = i
                break
        else:
            raise ValueError(f"{name} not in model.graph.value_info")
        value_info = self.value_info_map.pop(name)
        assert value_info.name == name
        assert id(value_info) == id(self.graph.value_info[index])
        self.graph.value_info.pop(index)
        if shape is not None:
            # Keep the element type, override only the shape.
            tensor_type = onnx.helper.make_tensor_type_proto(
                elem_type=value_info.type.tensor_type.elem_type,
                shape=shape
            )
            value_info.type.CopyFrom(tensor_type)
        if new_name is None:
            self.graph.output.append(value_info)
        else:
            from_node = self.get_from_node(name)
            to_nodes = self.get_to_nodes(name)
            for i, output_name in enumerate(from_node.output_names):
                if output_name == name:
                    from_node.set_output(i, new_name)
            for node in to_nodes:
                node.replace_input(name, new_name)
            value_info.name = new_name
            self.graph.output.append(value_info)
    def remove_output(self, name):
        """Remove the graph output named ``name``."""
        assert name in self.output_names
        index = None
        for i, out in enumerate(self.graph.output):
            if out.name == name:
                index = i
                break
        else:
            raise RuntimeError(f"ONNX graph does not have an output named '{name}'.")
        self.graph.output.pop(index)
    def get_node(self, name_or_index: Union[str, int]):
        """Look up a Node by name or positional index.

        Returns None for an unknown name (or unsupported argument type);
        raises ValueError for an out-of-range index.
        """
        if isinstance(name_or_index, str):
            if name_or_index in self.node_map:
                return self.node_map.get(name_or_index, None)
        elif isinstance(name_or_index, int):
            if name_or_index < len(self.graph.node):
                return self.node_map.get(
                    self.graph.node[name_or_index].name, None)
            else:
                raise ValueError(f"Node index {name_or_index} out of range")
    def get_nodes(self, *op_types: str):
        """Return all nodes whose op_type is one of ``op_types``."""
        assert len(op_types) >= 1
        op_types_set = set(op_types)
        node_names = [node.name for node in self.graph.node if node.op_type in op_types_set]
        nodes = [self.node_map[name] for name in node_names]
        return nodes
    def get_initializer(self, name: str):
        """Return the initializer TensorProto named ``name`` (or None)."""
        return self.initializer_map.get(name)
    def get_connection(self, name: str):
        """Return the Connection (edge) named ``name`` (or None)."""
        return self.connection_map.get(name)
    def get_from_node(self, conn: Union[str, Connection]):
        """Return the producer node of an edge."""
        if isinstance(conn, str):
            assert conn in self.connection_map, f"Connection {conn} not in connection_map!"
            return self.connection_map[conn].from_node
        elif isinstance(conn, Connection):
            return conn.from_node
        else:
            raise TypeError(f"Invalid connection type {type(conn)}")
    def get_to_nodes(self, conn: Union[str, Connection]):
        """Return the consumer nodes of an edge."""
        if isinstance(conn, str):
            assert conn in self.connection_map, f"Connection {conn} not in connection_map!"
            return self.connection_map[conn].to_nodes
        elif isinstance(conn, Connection):
            return conn.to_nodes
        else:
            raise TypeError(f"Invalid connection type {type(conn)}")
    def get_prev_nodes(self, node: Union[str, Node]):
        """Return the upstream producer nodes of ``node``'s inputs."""
        if isinstance(node, str):
            node = self.node_map[node]
        elif isinstance(node, Node):
            pass
        else:
            raise TypeError(f"Invalid node type {type(node)}")
        nodes = []
        for conn_name in node.inputs:
            from_node = self.get_from_node(conn_name)
            if from_node:
                nodes.append(from_node)
        return nodes
    def get_next_nodes(self, node: Union[str, Node]):
        """Return the downstream consumer nodes of ``node``'s outputs."""
        if isinstance(node, str):
            node = self.node_map[node]
        elif isinstance(node, Node):
            pass
        else:
            raise TypeError(f"Invalid node type {type(node)}")
        nodes = []
        for conn_name in node.outputs:
            to_nodes = self.get_to_nodes(conn_name)
            nodes.extend(to_nodes)
        return nodes
    def create_node(self, op_type, op_name, inputs, outputs, doc_string=None,
                    domain=None, index=None, **attrs):
        """Create a node and insert it at ``index`` (append when None).

        Inserting shifts the recorded index of every following node; the
        assertion guards against the maps getting out of sync.
        """
        onnx_node = make_node(op_type, inputs, outputs, name=op_name,
                              doc_string=doc_string, domain=domain, **attrs)
        if index is None:
            self.graph.node.append(onnx_node)
            index = len(self.graph.node) - 1
        else:
            assert index <= len(self.graph.node), "index out of range"
            self.graph.node.insert(index, onnx_node)
            for i in range(index + 1, len(self.graph.node)):
                node_name = self.graph.node[i].name
                old_idx = self.node_map[node_name].index
                assert old_idx == i - 1, \
                    f"Node {node_name} index conflict: {old_idx} != {i - 1}"
                self.node_map[node_name].index = i
        new_node = Node(self, self.graph.node[index], index)
        self.node_map[op_name] = new_node
        # Wire the new node into the connection map, creating missing
        # edges (with a placeholder float value_info) on the fly.
        for in_name in new_node.input_names:
            if in_name not in self.value_info_map:
                self.create_value_info(in_name, dtype="float")
            if in_name not in self.connection_map:
                self.connection_map[in_name] = Connection(in_name, self)
            self.connection_map[in_name].add_to_node(new_node)
        for out_name in new_node.output_names:
            if out_name not in self.value_info_map:
                self.create_value_info(out_name, dtype="float")
            if out_name not in self.connection_map:
                self.connection_map[out_name] = Connection(out_name, self)
            self.connection_map[out_name].set_from_node(new_node)
        return new_node
    def create_initializer(self, name, value: np.ndarray):
        """Create an initializer from a numpy array.

        Tensors larger than 2 GB cannot live inline in a protobuf, so
        they take a round-trip through external-data storage before
        being re-attached to the in-memory graph.
        """
        assert name not in self.initializer_map, f"initializer {name} already exists!"
        init_node = numpy_helper.from_array(value, name=name)
        use_external_data = value.nbytes / 1024 / 1024 / 1024 > 2
        if use_external_data:
            print("use external data:", name)
            init_node.data_location = onnx.TensorProto.EXTERNAL
            location = name.replace('/', '+') + '.data'
            onnx.external_data_helper.set_external_data(init_node, location)
            with tempfile.TemporaryDirectory() as tmp_dir:
                onnx.external_data_helper.save_external_data(init_node, tmp_dir)
                init_node.ClearField("raw_data")
                self.graph.initializer.append(init_node)
                # Reload while tmp_dir still exists, then inline the data.
                onnx.external_data_helper.load_external_data_for_tensor(
                    self.graph.initializer[-1], tmp_dir)
            del self.graph.initializer[-1].external_data[:]
            self.graph.initializer[-1].ClearField("data_location")
        else:
            self.graph.initializer.append(init_node)
        self.initializer_map[name] = self.graph.initializer[-1]
        return self.graph.initializer[-1]
    def create_value_info(self, name, dtype=None, shape=None):
        """Create a ValueInfoProto for the edge ``name`` and register it."""
        if dtype is None:
            elem_type = None
        else:
            assert isinstance(dtype, str)
            assert dtype in set(SUPPORT_DTYPES)
            # Map the dtype string onto the onnx.TensorProto enum value.
            elem_type = getattr(onnx.TensorProto, dtype.upper())
        value_info = onnx.helper.make_tensor_value_info(name=name,
                                                        elem_type=elem_type,
                                                        shape=shape)
        self.graph.value_info.append(value_info)
        self.value_info_map[name] = self.graph.value_info[-1]
        return self.graph.value_info[-1]
    def get_initializer_value(self, name):
        """Return the initializer's payload as a numpy array."""
        init = self.get_initializer(name)
        return numpy_helper.to_array(init)
    def set_initializer_value(self, name, value: np.ndarray):
        """Overwrite an initializer's payload, warning on shape/dtype change."""
        init = self.get_initializer(name)
        old_shape = list(init.dims)
        new_shape = list(value.shape)
        old_dtype = tensor_dtype_to_np_dtype(init.data_type)
        new_dtype = value.dtype
        if old_shape != new_shape:
            warn_message = f"Initializer {name} shape changed: {old_shape} -> {new_shape}"
            warnings.warn(warn_message, RuntimeWarning)
        if old_dtype is not None and old_dtype != new_dtype:
            warn_message = f"Initializer {name} dtype changed: {old_dtype} -> {new_dtype}"
            warnings.warn(warn_message, RuntimeWarning)
        new_tensor_proto = numpy_helper.from_array(value, name=name)
        init.CopyFrom(new_tensor_proto)
    def connect_node(self, node, inputs_map, outputs_map):
        """Wire ``node`` to its upstream and downstream neighbours.

        Args:
            node: Node
            inputs_map: [(node0, out_idx0), (node1, out_idx1), ...]
            outputs_map: [(node0, in_idx0), (node1, in_idx1), ...]
        """
        # When connecting A -> B, if A's output name clashes with B's
        # input name, A's output name wins: B.input[i] = A.output[j].
        for i, (n, j) in enumerate(inputs_map):
            if isinstance(n, str):
                n = self.node_map[n]
            assert j < len(n.outputs), \
                f"output index {j} out of node {n.name} outputs range"
            node.set_input(i, n.outputs[j])
        # NOTE(review): outputs_map is documented as (node, in_idx) pairs,
        # yet this calls n.set_output(i, name) on the downstream node. It
        # looks like it should rewire the downstream node's *input*;
        # confirm against callers before changing the behavior.
        for name, (n, i) in zip(node.outputs, outputs_map):
            if isinstance(n, str):
                n = self.node_map[n]
            assert i < len(n.outputs), \
                f"output index {i} out of node {n.name} outputs range"
            n.set_output(i, name)
        # TODO: update self.connection_map
    def pop_node(self, node: Union[str, Node, int], auto_connect=True):
        """Remove one node (by name, Node, or index) and update the maps.

        When ``auto_connect`` and the node has exactly one input and one
        output, its consumers are rewired to the node's input edge so
        the graph stays connected.
        """
        if isinstance(node, str):
            node = self.node_map.get(node, None)
            if node is None:
                return None
            index = node.index
            assert node.name == self.graph.node[index].name
        elif isinstance(node, int):
            if node >= len(self.graph.node):
                raise ValueError(f"Node index {node} out of range")
            index = node
        elif isinstance(node, Node):
            index = node.index
        else:
            raise ValueError(f"Invalid node name or index: {node}")
        # Every node after the removed one shifts down by one slot.
        for i in range(index + 1, len(self.graph.node)):
            later_node = self.graph.node[i]
            self.node_map[later_node.name].index -= 1
        _node_obj = self.graph.node[index]
        _node = self.get_node(_node_obj.name)
        next_nodes = self.get_next_nodes(_node)
        self.graph.node.pop(index)
        self.node_map.pop(_node_obj.name)
        # Automatically bypass a single-input/single-output node.
        if auto_connect and len(_node.inputs) == 1 and len(_node.outputs) == 1:
            for next_node in next_nodes:
                next_node.replace_input(_node.outputs[0], _node.inputs[0])
        # Update connection_map bookkeeping.
        for in_name in _node.input_names:
            if _node.name in self.connection_map[in_name].to_node_names:
                self.connection_map[in_name].pop_to_node(_node)
        for out_name in _node.output_names:
            self.connection_map[out_name].clear_from_node()
        return _node
    def remove_nodes(self, nodes: List[str | Node], auto_connect=False):
        """Remove several nodes at once (deduplicated, back-to-front)."""
        indices = set()
        _nodes = []
        invalid_nodes = set()
        for node in nodes:
            if isinstance(node, str):
                if node in self.node_map:
                    node = self.node_map[node]
                    if node.index not in indices:
                        _nodes.append(node)
                        indices.add(node.index)
                else:
                    invalid_nodes.add(node)
            elif isinstance(node, Node):
                if node.index not in indices:
                    _nodes.append(node)
                    indices.add(node.index)
                else:
                    invalid_nodes.add(node)
        # Pop from the back so earlier indices stay valid.
        _nodes.sort(key=lambda x: x.index, reverse=True)
        use_progress_bar = len(_nodes) > 500
        if use_progress_bar:
            pbar = tqdm.tqdm(total=len(_nodes), desc="Removing nodes")
        for node in _nodes:
            self.pop_node(node, auto_connect=auto_connect)
            if use_progress_bar:
                pbar.update(1)
        if use_progress_bar:
            pbar.close()
    def pop_initializer(self, init_name: str, update_node_inputs: bool = True):
        """Remove the initializer named ``init_name`` and return it.

        Bug fix: the index search read ``self.graph.initializer.name``
        (an attribute the repeated container does not have) instead of
        ``self.graph.initializer[i].name``, so this always raised.
        NOTE(review): ``update_node_inputs`` is currently unused; the
        node-input cleanup it suggests was never implemented.
        """
        _init1 = self.initializer_map.pop(init_name)
        init_index = None
        for i in range(len(self.graph.initializer)):
            if self.graph.initializer[i].name == init_name:
                init_index = i
                break
        else:
            raise ValueError(f"Not existing a Initializer named {init_name}")
        _init2 = self.graph.initializer.pop(init_index)
        assert id(_init1) == id(_init2)
        return _init1
    def update_map(self):
        """Rebuild node/connection/initializer/value_info lookup maps."""
        self.node_map.clear()
        self.connection_map.clear()
        self.initializer_map.clear()
        self.sparse_initializer_map.clear()
        self.value_info_map.clear()
        for i, node in enumerate(self.graph.node):
            new_node = Node(self, node, i)
            self.add_node_name_if_nameless(new_node)
            self.node_map[node.name] = new_node
            for conn_name in node.input:
                if conn_name not in self.connection_map:
                    self.connection_map[conn_name] = Connection(conn_name, self)
                self.connection_map[conn_name].add_to_node(new_node)
            for conn_name in node.output:
                if conn_name not in self.connection_map:
                    self.connection_map[conn_name] = Connection(conn_name, self)
                self.connection_map[conn_name].set_from_node(new_node)
        for init in self.graph.initializer:
            self.initializer_map[init.name] = init
        for i, sp_init in enumerate(self.graph.sparse_initializer):
            self.sparse_initializer_map[sp_init.name] = [sp_init, i]
        for vi in self.graph.value_info:
            self.value_info_map[vi.name] = vi
    def find_unuseful_nodes(self):
        """Find nodes whose outputs can never reach a graph output."""
        end_names = set()
        for output_name in self.output_names:
            end_names.add(self.get_from_node(output_name).name)
        unuseful_names = set()
        # Seed: nodes with no consumers that do not produce a graph output.
        for node in self.node_map.values():
            if node.name in end_names:
                continue
            next_nodes = self.get_next_nodes(node)
            if len(next_nodes) == 0:
                unuseful_names.add(node.name)
        model_output_names = set(self.output_names)
        # Walk backwards: a producer is dead when all of its consumers
        # are dead and none of its outputs is a model output.
        q = deque([self.node_map[name] for name in unuseful_names])
        while len(q) != 0:
            node = q.popleft()
            prev_nodes = self.get_prev_nodes(node)
            for node1 in prev_nodes:
                next_nodes = self.get_next_nodes(node1)
                next_names = set([node2.name for node2 in next_nodes])
                if next_names.issubset(unuseful_names):
                    if node1.name not in unuseful_names and set(node1.output_names).isdisjoint(model_output_names):
                        q.append(node1)
                        unuseful_names.add(node1.name)
        unuseful_nodes = [self.node_map[name] for name in unuseful_names]
        return unuseful_nodes
    def remove_trash(self):
        """Garbage-collect the graph:

        1. remove unreachable nodes
        2. remove unused initializers (dense and sparse)
        3. remove graph inputs that no node consumes
        4. remove graph outputs that no node produces
        5. remove value_info entries for edges that no longer exist
        """
        self.update_map()
        unuseful_nodes = self.find_unuseful_nodes()
        print(f"Find unuseful {len(unuseful_nodes)} nodes!")
        for i, node in enumerate(unuseful_nodes):
            print(f"remove unuseful node {i}:", node.name)
        self.remove_nodes(unuseful_nodes)
        self.update_map()
        all_node_inputs = set()
        for node in self.node_map.values():
            all_node_inputs.update(node.input_names)
        # Remove initializers that no node consumes.
        cnt = 0
        for init_name in list(self.initializer_map.keys()):
            if init_name in all_node_inputs:
                continue
            index = None
            for i, init in enumerate(self.graph.initializer):
                if init.name == init_name:
                    index = i
                    break
            else:
                raise ValueError(
                    f"{init_name} not in model.graph.initializer")
            print(f"remove unuseful initializer {cnt}:", init_name)
            self.graph.initializer.pop(index)
            cnt += 1
        # Remove sparse initializers that no node consumes.
        cnt = 0
        for init_name in list(self.sparse_initializer_map.keys()):
            if init_name in all_node_inputs:
                continue
            index = None
            for i, init in enumerate(self.graph.sparse_initializer):
                if init.name == init_name:
                    index = i
                    break
            else:
                raise ValueError(
                    f"{init_name} not in model.graph.sparse_initializer")
            print(f"remove unuseful sparse initializer {cnt}:", init_name)
            self.graph.sparse_initializer.pop(index)
            cnt += 1
        self.update_map()
        # Remove dangling graph inputs and outputs.
        for in_name in self.input_names:
            if len(self.get_to_nodes(in_name)) != 0:
                continue
            for i, _in in enumerate(self.graph.input):
                if in_name == _in.name:
                    self.graph.input.pop(i)
                    break
        for out_name in self.output_names:
            if self.get_from_node(out_name) is not None:
                continue
            for i, _out in enumerate(self.graph.output):
                if out_name == _out.name:
                    self.graph.output.pop(i)
                    break
        self.update_map()
        # Remove value_info of vanished edges, iterating backwards so
        # pops do not shift unvisited indices.
        cnt = 0
        num_value_info = len(self.graph.value_info)
        for i in range(num_value_info - 1, -1, -1):
            v = self.graph.value_info[i]
            if v.name not in self.connection_map:
                self.graph.value_info.pop(i)
                print(f"remove unuseful value_info {cnt}:", v.name)
                cnt += 1
        self.update_map()
    def infer_shape(self, strict_mode=False):
        """Re-run ONNX shape inference and rebind all references."""
        # Clear stale shapes first so inference recomputes them.
        for vi in self.graph.value_info:
            if vi.type.HasField("tensor_type"):
                vi.type.tensor_type.ClearField("shape")
        model = shape_inference.infer_shapes(self.model, strict_mode=strict_mode)
        self.model = model
        self.domain = model.domain
        self.graph = model.graph
        self.ir_version = model.ir_version
        self.model_version = model.model_version
        self.mdoel_version = model.model_version
        self.opset_import = model.opset_import
        self.update_map()
    def infer_node_shpe(self, node):
        """Infer output shapes for a single node.

        NOTE(review): likely broken as written — ``TypeProto.Tensor`` has
        ``shape``/``elem_type`` fields, not ``dims``/``type``, so these
        attribute accesses look wrong; verify against the
        onnx.shape_inference.infer_node_outputs API before relying on it.
        (The misspelled method name is kept for backward compatibility.)
        """
        input_shapes = []
        input_dtypes = []
        for input_name in node.inputs:
            value_info = self.value_info_map[input_name]
            input_shapes.append(value_info.type.tensor_type.dims)
            input_dtypes.append(value_info.type.tensor_type.type)
        shape_inference.infer_node_outputs(node.obj, input_shapes, input_dtypes)
    def convert_float_to_float16(self):
        """Convert weights/ops to FP16 while keeping FP32 graph I/O."""
        self.model = float16.convert_float_to_float16(self.model, keep_io_types=True)
    def save(self, save_path, save_as_external_data=False,
             all_tensors_to_one_file=True):
        """Garbage-collect the graph, then serialize it to ``save_path``."""
        self.remove_trash()
        external_data_name = osp.basename(save_path) + '.data'
        external_data_path = osp.join(osp.dirname(save_path), external_data_name)
        # onnx.save appends to an existing external-data file, so drop it.
        if save_as_external_data and osp.isfile(external_data_path):
            os.remove(external_data_path)
        onnx.save(self.model,
                  save_path,
                  save_as_external_data=save_as_external_data,
                  all_tensors_to_one_file=all_tensors_to_one_file,
                  location=external_data_name,
                  size_threshold=1024,
                  convert_attribute=False)
import onnx
from onnxsim import simplify
from onnxconverter_common import float16
# Stage 1: run onnx-simplifier on the raw export and keep the result
# only if the simplifier's own validation check passes.
onnx_model_path = "./weights/ground.onnx"
sim_model_path = "./weights/ground_sim.onnx"
print("1️⃣ 正在进行 ONNX Simplify...")
raw_model = onnx.load(onnx_model_path)
simplified_model, passed_check = simplify(raw_model)
if passed_check:
    onnx.save(simplified_model, sim_model_path)
    print(f"✅ Simplify 完成!已保存至 {sim_model_path}")
else:
    print("❌ Simplify 验证失败!")
    exit()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment