zk / GroundingDINO-DCU-Optimized · Commit a1865640
Authored Apr 27, 2026 by zk
Commit message: Add the MIGraphX part (新增migraphx部分)
Parent: 0896d47e
Showing 17 changed files with 2652 additions and 68 deletions (+2652, -68)
README.md                                            +38   -8
deform_ort/onnx_inference_deform_optim.py            +21   -6
deform_ort/onnx_inference_deform_optim_iobinding.py  +32   -15
deform_ort/onnx_optimize.py                          +21   -22
deform_ort/profile_analyzer.py                       +62   -0
deform_ort/result.jpg                                +0    -0
groundingdino/util/inference.py                      +2    -2
images/out/result.jpg                                +0    -0
migraphx_infer/migraphx_export.bash                  +6    -0
migraphx_infer/migraphx_infer.bash                   +0    -11
migraphx_infer/migraphx_infer.py                     +4    -4
migraphx_infer/migraphx_infer1.py                    +242  -0
migraphx_infer/migraphx_perf.bash                    +4    -0
migraphx_infer/modify_onnx.py                        +611  -0
migraphx_infer/modify_onnx1.py                       +693  -0
migraphx_infer/onnx_modifier.py                      +899  -0
migraphx_infer/onnx_sim.py                           +17   -0
README.md (view @ a1865640)
@@ -157,7 +157,7 @@ python onnx_inference_deform_optim.py
 To use lower-resolution image input (e.g. 400x800) for a further inference speedup, follow these steps:
-### 6.1 Modify the export script
+### 6.1. Modify the export script
 Edit `deform_ort/export_onnx_deform.py` and change the image size and export path:
@@ -169,7 +169,7 @@ img = torch.randn(1, 3, 400, 800).to(device)
 onnx_output_path = "../weights_400x800/ground_deform.onnx"
 ```
-### 6.2 Export and quantize as usual
+### 6.2. Export and quantize as usual
 ```bash
 cd deform_ort
@@ -177,7 +177,7 @@ python export_onnx_deform.py
 python onnx_optimize.py
 ```
-### 6.3 Change the inference preprocessing resolution
+### 6.3. Change the inference preprocessing resolution
 Edit the `load_image` function in `groundingdino/util/inference.py`, changing the `RandomResize` argument from 800 to 400:
@@ -186,7 +186,7 @@ python onnx_optimize.py
 T.RandomResize([400], max_size=1333),
 ```
-### 6.4 Run ORT inference
+### 6.4. Run ORT inference
 Run the inference script, and make sure the ONNX model paths in the code point to the corresponding model files under `weights_400x800/`:
@@ -198,7 +198,26 @@ python onnx_inference_deform_optim.py
 -----
-## 7\. Test result comparison
+## 7\. MIGraphX inference
+1. Enter the migraphx_infer folder
+```bash
+cd migraphx_infer
+```
+2. Run the ONNX conversion script to turn the simplified ONNX into the ONNX used for MIGraphX inference
+```bash
+bash migraphx_export.bash
+```
+3. If you already have the .mxr file, benchmark it directly
+```bash
+bash migraphx_perf.bash
+```
+-----
+## 8\. Test result comparison
 *All tests below include 5 warmup rounds and 10 timed rounds.*
@@ -208,7 +227,7 @@ python onnx_inference_deform_optim.py
 > * **Model files**: stored under `../weights/` by default.
 > * **Custom-op directories**: the corresponding full shared-library paths are `../[dir name]/build/libms_deform_attn_ort.so`.
-### 7.1 BW150 results
+### 8.1 ORT BW150 results
 Single BW150 card, 800x1200 image input, Batch Size = 1
@@ -221,7 +240,7 @@ python onnx_inference_deform_optim.py
 | **ORT + Plugin** | + custom op<br>+ pure FP16 quantization, scheme B | `ground_deform_fp16_all.onnx` | `ort_plugin_fp16_B` | 87.34 | 11.44 |
 | **ORT + Plugin** | + custom op<br>+ fully optimized FP16, scheme C | `ground_deform_fp16_all.onnx` | `ort_plugin_fp16_C` | 84.52 | 11.82 |
-### 7.2 BW100 results
+### 8.2 ORT BW100 results
 Single BW100 card, 800x1200 image input, Batch Size = 1
@@ -233,11 +252,22 @@ python onnx_inference_deform_optim.py
 | **ORT + Plugin** | + custom op<br>+ pure FP16 quantization, scheme B | `ground_deform_fp16_all.onnx` | `ort_plugin_fp16_B` | 105.35 | 9.49 |
 | **ORT + Plugin** | + custom op<br>+ fully optimized FP16, scheme C | `ground_deform_fp16_all.onnx` | `ort_plugin_fp16_C` | 100.91 | 9.90 |
+### 8.3 MIGraphX BW100 results
+```
+Batch size: 1
+Rate: 6.05197 inferences/sec
+Total time: 165.235ms (Min: 165.115ms, Max: 165.535ms, Mean: 165.258ms, Median: 165.225ms)
+Percentiles (90%, 95%, 99%): (165.358ms, 165.358ms, 165.358ms)
+Total instructions time: 205.275ms
+Overhead time: 2.32812ms, -40.0399ms
+Overhead: 1%, -24%
+```
 -----
 ## Referenced projects
-This project drew on the following excellent open-source projects during development, with thanks:
+This project drew on the following open-source projects during development:
 - [**GroundingDINO**](https://github.com/IDEA-Research/GroundingDINO) - the official GroundingDINO repository, providing the base model and algorithm implementation.
 - [**GroundingDINO-TensorRT-and-ONNX-Inference**](https://github.com/wingdzero/GroundingDINO-TensorRT-and-ONNX-Inference) - a reference implementation of TensorRT and ONNX inference deployment for GroundingDINO.
\ No newline at end of file
deform_ort/onnx_inference_deform_optim.py (view @ a1865640)
@@ -7,20 +7,35 @@ import onnxruntime as ort
 import bisect
 import time
 import os
+from typing import Tuple
+import groundingdino.datasets.transforms as T
+from PIL import Image
 """
 Optimizations to the model pre/post-processing and code structure:
 1. Prediction retrieval optimized: prediction_logits = sigmoid(outputs[0][0])
 2. Inputs are prepared up front and passed in directly, removing the tokenizer dependency
 """
-from groundingdino.util.inference import load_image
 so_options = ort.SessionOptions()
-custom_op_lib_path = "../ort_plugin/build/libms_deform_attn_ort.so"
+custom_op_lib_path = "../ort_plugin_fp16_C/build/libms_deform_attn_ort.so"
 so_options.register_custom_ops_library(custom_op_lib_path)
 # Enable ORT graph optimizations
 so_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
+    transform = T.Compose(
+        [
+            T.RandomResize([800], max_size=1333),
+            # T.RandomResize([400], max_size=1333),
+            T.ToTensor(),
+            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ]
+    )
+    image_source = Image.open(image_path).convert("RGB")
+    image = np.asarray(image_source)
+    image_transformed, _ = transform(image_source, None)
+    return image, image_transformed
+def sigmoid(x):
+    return 1 / (1 + np.exp(-x))
@@ -180,7 +195,7 @@ def benchmark_performance(
 if __name__ == '__main__':
     # Configuration
-    model_path = '../weights_400x600/ground_deform.onnx'
+    model_path = '../weights/ground_deform_fp16_all.onnx'
     """
     ../weights/ground_deform.onnx      plain version
     ../weights/ground_deform_sim.onnx  simplified version
@@ -264,6 +279,6 @@ if __name__ == '__main__':
     )
     # Save the result
-    cv2.imwrite('./result.jpg', ori_img)
-    print(f"\n✅ Result saved to: ./result.jpg")
+    cv2.imwrite('../weights/result.jpg', ori_img)
+    print(f"\n✅ Result saved to: ../weights/result.jpg")
     print(f"✅ Detected targets: {phrases} ({len(boxes)} total)")
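The diff above registers the custom `ms_deform_attn` kernel before session creation. As orientation, here is a minimal sketch of that setup; the model and plugin paths are the ones from the diff, and the ROCm provider name is an assumption about the DCU build of ONNX Runtime:

```python
import onnxruntime as ort

# Paths taken from the diff above; adjust to the local checkout.
MODEL = "../weights/ground_deform_fp16_all.onnx"
PLUGIN = "../ort_plugin_fp16_C/build/libms_deform_attn_ort.so"

so = ort.SessionOptions()
so.register_custom_ops_library(PLUGIN)  # makes the ms_deform_attn custom op resolvable
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# Provider list is an assumption: DCU builds of ORT typically expose the GPU
# as "ROCMExecutionProvider", with CPU as fallback.
sess = ort.InferenceSession(MODEL, so, providers=["ROCMExecutionProvider", "CPUExecutionProvider"])
for inp in sess.get_inputs():
    print(inp.name, inp.shape, inp.type)  # img, input_ids, attention_mask, ...
```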
deform_ort/onnx_inference_deform_optim_iobinding.py (view @ a1865640)
@@ -7,6 +7,9 @@ import onnxruntime as ort
 import bisect
 import time
 import os
+from typing import Tuple
+import groundingdino.datasets.transforms as T
+from PIL import Image
 """
 Optimizations to the model pre/post-processing and code structure:
 1. Prediction retrieval optimized: prediction_logits = sigmoid(outputs[0][0])
@@ -14,14 +17,28 @@ import os
 3. IO binding optimization
 """
-from groundingdino.util.inference import load_image
 so_options = ort.SessionOptions()
-custom_op_lib_path = "../ort_plugin/build/libms_deform_attn_ort.so"
+# To see detailed ORT logs, uncomment the line below and set an appropriate log level
+# so_options.enable_profiling = True
+custom_op_lib_path = "../ort_plugin_fp16_C/build/libms_deform_attn_ort.so"
 so_options.register_custom_ops_library(custom_op_lib_path)
 # Enable ORT graph optimizations
 so_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
+    transform = T.Compose(
+        [
+            T.RandomResize([800], max_size=1333),
+            # T.RandomResize([400], max_size=1333),
+            T.ToTensor(),
+            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ]
+    )
+    image_source = Image.open(image_path).convert("RGB")
+    image = np.asarray(image_source)
+    image_transformed, _ = transform(image_source, None)
+    return image, image_transformed
+def sigmoid(x):
+    return 1 / (1 + np.exp(-x))
@@ -67,20 +84,17 @@ def predict(
     t0 = time.time()
-    # 1. Bind only the image that changes for this frame! The text inputs are already sitting in device memory.
+    # 1. Bind only the image that changes for this frame; the other text inputs are already bound
     img_tensor = np.expand_dims(np.asarray(image), axis=0)
+    # Tried converting the input to fp16 (exporting the ONNX with fp16 inputs), but inference got slower
+    # img_tensor = np.expand_dims(np.asarray(image), axis=0).astype(np.float16)
     io_binding.bind_cpu_input('img', img_tensor)
-    # 2. Bind the outputs to fetch
-    io_binding.bind_output('logits')
-    io_binding.bind_output('boxes')
-    # 3. Run inference at full speed
+    # 2. Run inference
     ort_session.run_with_iobinding(io_binding)
-    ort_outputs = io_binding.copy_outputs_to_cpu()
-    # Clear the output bindings, otherwise the next loop iteration leaks memory and errors out
-    io_binding.clear_binding_outputs()
+    # 3. Copy the results from GPU back to CPU
+    ort_outputs = io_binding.copy_outputs_to_cpu()
     infer_time = time.time() - t0
     if not is_benchmark:
@@ -204,7 +218,7 @@ def benchmark_performance(
 if __name__ == '__main__':
     # Configuration
-    model_path = '../weights/ground_deform_fp16.onnx'
+    model_path = '../weights_opt/ground_deform_opt_fp16_all.onnx'
     img_path = '../images/in/car_1.jpg'
     TEXT_PROMPT = "car ."
     BOX_TRESHOLD = 0.35
@@ -241,6 +255,9 @@ if __name__ == '__main__':
     for key in static_keys:
         io_binding.bind_cpu_input(key, TEXT_CACHE[key])
+    io_binding.bind_output('logits')
+    io_binding.bind_output('boxes')
     # Step 1: run the full benchmark (warmup + timed inference)
     performance_result = benchmark_performance(
         ort_session, io_binding, image, TEXT_CACHE,
@@ -281,6 +298,6 @@ if __name__ == '__main__':
     )
     # Save the result
-    cv2.imwrite('./images/out/result.jpg', ori_img)
-    print(f"\n✅ Result saved to: ./images/out/result.jpg")
+    cv2.imwrite('../images/out/result.jpg', ori_img)
+    print(f"\n✅ Result saved to: ../images/out/result.jpg")
     print(f"✅ Detected targets: {phrases} ({len(boxes)} total)")
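The point of the change above is to hoist the static bindings out of the hot loop: outputs and text inputs are bound once, and only the image is rebound per frame. A minimal sketch of that pattern, assuming `sess` is an `ort.InferenceSession` created as in the previous file and `text_cache` is a dict of prepared NumPy arrays:

```python
import numpy as np

io_binding = sess.io_binding()

# Bound once, outside the per-frame loop
for name, arr in text_cache.items():
    io_binding.bind_cpu_input(name, arr)
io_binding.bind_output('logits')
io_binding.bind_output('boxes')

# Per frame: rebind only the image, run, then copy results back to host
img_tensor = np.random.rand(1, 3, 800, 1200).astype(np.float32)  # placeholder frame
io_binding.bind_cpu_input('img', img_tensor)
sess.run_with_iobinding(io_binding)
logits, boxes = io_binding.copy_outputs_to_cpu()
```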
deform_ort/onnx_optimize.py (view @ a1865640)
@@ -2,25 +2,26 @@ import onnx
 from onnxsim import simplify
 from onnxconverter_common import float16
-onnx_model_path = "../weights_400x600/ground_deform.onnx"
-sim_model_path = "../weights_400x600/ground_deform_sim.onnx"
-fp16_model_path = "../weights_400x600/ground_deform_fp16.onnx"
-fp16_all_model_path = "../weights_400x600/ground_deform_fp16_all.onnx"
+onnx_model_path = "../weights/ground_deform.onnx"
+sim_model_path = "../weights_opt/ground_deform_opt.onnx"
+fp16_model_path = "../weights_opt/ground_deform_opt_fp16.onnx"
+fp16_all_model_path = "../weights_opt/ground_deform_opt_fp16_all.onnx"
 custom_op_lib_path = "../ort_plugin_fp16/build/libms_deform_attn_ort.so"
-# ==========================================
-# Step 1: ONNX Simplify (with the custom-op library)
-# ==========================================
-print("1️⃣ Running ONNX Simplify...")
-model = onnx.load(onnx_model_path)
-model_simp, check = simplify(model, custom_lib=custom_op_lib_path)
-if check:
-    onnx.save(model_simp, sim_model_path)
-    print(f"✅ Simplify done! Saved to {sim_model_path}")
-else:
-    print("❌ Simplify validation failed!")
-    exit()
+# # ==========================================
+# # Step 1: ONNX Simplify (with the custom-op library)
+# # ==========================================
+# print("1️⃣ Running ONNX Simplify...")
+# model = onnx.load(onnx_model_path)
+# model_simp, check = simplify(model, custom_lib=custom_op_lib_path)
+# if check:
+#     onnx.save(model_simp, sim_model_path)
+#     print(f"✅ Simplify done! Saved to {sim_model_path}")
+# else:
+#     print("❌ Simplify validation failed!")
+#     exit()
@@ -30,30 +31,28 @@ else:
 # Reload the simplified model
 model_to_fp16 = onnx.load(sim_model_path)
-print("\n2️⃣ Running FP16 mixed-precision conversion...")
 original_cast_nodes = [node.name for node in model_to_fp16.graph.node if node.op_type == "Cast"]
 print(f"🔍 Found {len(original_cast_nodes)} native Cast nodes; all added to the protection list.")
+print("\n2️⃣ Running FP16 mixed-precision conversion...")
 model_fp16 = float16.convert_float_to_float16(
     model_to_fp16,
     op_block_list=["ms_deform_attn"],      # block the custom attention op (when using the fp32 custom-op build)
     node_block_list=original_cast_nodes,   # protect all native Cast nodes
     keep_io_types=True                     # keep the overall model inputs/outputs in FP32
 )
 onnx.save(model_fp16, fp16_model_path)
 print(f"✅ FP16 conversion done (custom op skipped)! Saved to {fp16_model_path}")
 print("\n2️⃣ Running pure FP16 conversion...")
 model_fp16_all = float16.convert_float_to_float16(
     model_to_fp16,
     node_block_list=original_cast_nodes,   # protect all native Cast nodes
     keep_io_types=True                     # keep the overall model inputs/outputs in FP32
 )
 onnx.save(model_fp16_all, fp16_all_model_path)
-print(f"✅ FP16 conversion done! Saved to {fp16_all_model_path}")
+print(f"✅ Pure FP16 conversion done! Saved to {fp16_all_model_path}")
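A quick way to sanity-check what `op_block_list` and `keep_io_types` did to the converted graph is to inspect it with plain `onnx`. A minimal sketch, with the output path taken from the script above:

```python
import onnx
from collections import Counter

m = onnx.load("../weights_opt/ground_deform_opt_fp16.onnx")
ops = Counter(n.op_type for n in m.graph.node)
# The blocked custom op should survive unconverted, with Cast nodes inserted around it.
print("ms_deform_attn nodes:", ops.get("ms_deform_attn", 0))
print("Cast nodes:", ops.get("Cast", 0))
# keep_io_types=True means graph inputs should still be FLOAT (elem_type 1), not FLOAT16 (10).
print([(i.name, i.type.tensor_type.elem_type) for i in m.graph.input])
```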
deform_ort/profile_analyzer.py (new file, 0 → 100644, view @ a1865640)
import json
import sys
from collections import defaultdict

def analyze_profile(json_path):
    print(f"🔍 Parsing profile file: {json_path}\n")
    with open(json_path, 'r') as f:
        data = json.load(f)
    # Handle both possible JSON root layouts
    events = data if isinstance(data, list) else data.get('traceEvents', [])
    # Total time grouped by operator type (e.g. MatMul, Conv)
    op_type_times = defaultdict(float)
    # Total time grouped by concrete node name (e.g. /transformer/encoder/MatMul_1)
    node_name_times = defaultdict(float)
    total_inference_time = 0.0
    for event in events:
        # Only count events that carry a duration (dur) and arguments (args)
        if 'dur' in event and 'args' in event:
            args = event['args']
            # ORT usually records the operator type in args under op_name
            if 'op_name' in args:
                op_type = args['op_name']
                # event['name'] usually contains the full node path
                node_name = event.get('name', 'Unknown_Node')
                # dur in the JSON is in microseconds; convert to milliseconds
                dur_ms = event['dur'] / 1000.0
                op_type_times[op_type] += dur_ms
                node_name_times[node_name] += dur_ms
                total_inference_time += dur_ms
    # Sort both dicts in descending order
    sorted_op_types = sorted(op_type_times.items(), key=lambda x: x[1], reverse=True)
    sorted_nodes = sorted(node_name_times.items(), key=lambda x: x[1], reverse=True)
    print("=" * 50)
    print("🏆 Top 10 by total time per operator type (OpType)")
    print("=" * 50)
    for i, (op, time_ms) in enumerate(sorted_op_types[:10]):
        percentage = (time_ms / total_inference_time) * 100 if total_inference_time > 0 else 0
        print(f"{i + 1:2d}. {op:<20} | time: {time_ms:>8.3f} ms | share: {percentage:>5.2f}%")
    print("\n" + "=" * 50)
    print("🎯 Top 15 individual nodes by time")
    print("=" * 50)
    for i, (node, time_ms) in enumerate(sorted_nodes[:15]):
        percentage = (time_ms / total_inference_time) * 100 if total_inference_time > 0 else 0
        print(f"{i + 1:2d}. time: {time_ms:>8.3f} ms ({percentage:>5.2f}%) | node: {node}")

if __name__ == "__main__":
    # Replace this with the json file name you just generated
    profile_file = "./onnxruntime_profile__2026-04-27_13-58-17.json"
    if len(sys.argv) > 1:
        profile_file = sys.argv[1]
    analyze_profile(profile_file)
\ No newline at end of file
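profile_analyzer.py consumes the JSON trace that ONNX Runtime writes when session profiling is enabled. A minimal sketch of producing such a trace (the model path and feed are placeholders):

```python
import onnxruntime as ort

so = ort.SessionOptions()
so.enable_profiling = True                  # emits an onnxruntime_profile__*.json trace
sess = ort.InferenceSession("model.onnx", so)
# ... run one or more inferences here: sess.run(None, feed) ...
trace_path = sess.end_profiling()           # returns the JSON file name
print(trace_path)                           # pass this path to analyze_profile()
```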
deform_ort/result.jpg (binary image replaced: 0896d47e → a1865640, 1.35 MB → 1.35 MB)
groundingdino/util/inference.py (view @ a1865640)
@@ -39,8 +39,8 @@ def load_model(model_config_path: str, model_checkpoint_path: str, device: str =
 def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
     transform = T.Compose(
         [
-            # T.RandomResize([800], max_size=1333),
-            T.RandomResize([400], max_size=1333),
+            T.RandomResize([800], max_size=1333),
+            # T.RandomResize([400], max_size=1333),
             T.ToTensor(),
             T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         ]
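For readers tracking the 800 vs 400 switch that README section 6.3 describes, this is how the resize target maps an input resolution, assuming the DETR-style convention that GroundingDINO's transforms follow (short side scaled to `size`, long side clamped to `max_size`); the helper name is hypothetical:

```python
def resized_hw(h: int, w: int, size: int = 800, max_size: int = 1333) -> tuple:
    short, long = min(h, w), max(h, w)
    if long / short * size > max_size:   # short-side target would overshoot the cap
        scale = max_size / long
    else:
        scale = size / short
    return round(h * scale), round(w * scale)

print(resized_hw(1080, 1920))  # -> (750, 1333): clamped by max_size
print(resized_hw(800, 1200))   # -> (800, 1200): the 800x1200 case used in the benchmarks
```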
images/out/result.jpg (binary image replaced: 0896d47e → a1865640, 1.35 MB → 1.35 MB)
migraphx_infer/migraphx_export.bash (new file, 0 → 100644, view @ a1865640)
export MIGRAPHX_ENABLE_MIOPEN_CONCAT=1
migraphx-driver perf --onnx \
    ../weights/ground_opt.onnx \
    --fp16 \
    --output \
    ../weights/ground_opt.mxr
\ No newline at end of file
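The script above drives `migraphx-driver`; the same FP16 export can be sketched through the MIGraphX Python bindings. This is an assumption about the upstream `migraphx` module API, not a path the commit itself takes:

```python
import migraphx

prog = migraphx.parse_onnx("../weights/ground_opt.onnx")
migraphx.quantize_fp16(prog)               # mirrors the --fp16 flag
prog.compile(migraphx.get_target("gpu"))
migraphx.save(prog, "../weights/ground_opt.mxr")
```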
migraphx_infer/migraphx_infer.bash (deleted, 100644 → 0, view @ 0896d47e)
MIGRAPHX_LOG=debug migraphx-driver compile \
    --onnx weights/ground_external.onnx \
    --gpu \
    -p dead_code_elimination \
    --output weights/ground.mgx
# -p eliminate_contiguous \
# -p simplify_reshapes \
# -p simplify_algebra \
# -p eliminate_identity \
# -p common_subexpression_elimination \
\ No newline at end of file
migraphx_infer/migraphx_infer.py (view @ a1865640)
@@ -57,7 +57,7 @@ def _mgx_shape_to_numpy(shape):
 # =========================
 # 🚀 MIGraphX inference class (with caching)
 # =========================
 class MIGraphXModel:
-    def __init__(self, onnx_path, cache_path="weights/ground.mxr", force_recompile=False):
+    def __init__(self, onnx_path, cache_path="weights/ground_opt.mxr", force_recompile=False):
         self.cache_path = cache_path
         # ====== Prefer loading the cache ======
@@ -228,10 +228,10 @@ def benchmark(model, tokenizer, image, caption, box_th, text_th, warmup=5, runs=
 # =========================
 if __name__ == "__main__":
-    model_path = "weights/ground_simplified.onnx"
-    cache_path = "weights/ground_simplified.mxr"  # ⭐ cache file
-    img_path = "images/in/car_1.jpg"
+    model_path = "../weights/ground_opt.onnx"
+    cache_path = "../weights/ground_opt.mxr"  # ⭐ cache file
+    img_path = "../images/in/car_1.jpg"
     TEXT_PROMPT = "car ."
     BOX_TRESHOLD = 0.35
migraphx_infer/migraphx_infer1.py (new file, 0 → 100644, view @ a1865640)
import cv2
import numpy as np
import time
import os
import migraphx
from typing import Tuple
import torch
import groundingdino.datasets.transforms as T
from PIL import Image

def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            # T.RandomResize([400], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_source = Image.open(image_path).convert("RGB")
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def _mgx_shape_to_numpy(shape):
    shape_str = str(shape)
    if "int64_type" in shape_str:
        dtype = np.int64
    elif "bool_type" in shape_str:
        dtype = np.bool_
    elif "half_type" in shape_str:
        dtype = np.float16
    else:
        dtype = np.float32
    try:
        dims = list(shape.dims())
    except Exception:
        dims = []
    try:
        lens = list(shape.lens())
    except Exception:
        lens = []
    return dtype, (dims if len(dims) > 0 else lens)

# =========================
# 🚀 MIGraphX inference class (with caching and lifetime management)
# =========================
class MIGraphXModel:
    def __init__(self, onnx_path, cache_path="weights/ground_opt.mxr", force_recompile=False, device_id=0):
        self.cache_path = cache_path
        if os.path.exists(cache_path) and not force_recompile:
            print(f"⚡ Loading compiled model directly: {cache_path}")
            self.model = migraphx.load(cache_path)
        else:
            print("🔍 Building MIGraphX from ONNX")
            self.model = migraphx.parse_onnx(onnx_path)
            print(f"⚙️ Compiling MIGraphX (GPU {device_id})")
            self.model.compile(t=migraphx.get_target("gpu"), device_id=device_id)
            print(f"💾 Saving compiled model to: {cache_path}")
            migraphx.save(self.model, cache_path)
        self.input_shapes = self.model.get_inputs()

    def infer(self, input_dict):
        mgx_inputs = {}
        # [Key fix]: keep the NumPy arrays alive so Python garbage collection
        # does not invalidate the underlying pointers
        self._keep_alive_cache = {}
        provided_names = set(input_dict.keys())
        required_names = {k for k in self.input_shapes.keys()
                          if not str(k).startswith("main:#output")}
        for name in required_names:
            shape = self.input_shapes[name]
            target_dtype, lens = _mgx_shape_to_numpy(shape)
            if name in provided_names:
                # 1. Must be contiguous memory! Arrays converted from PyTorch may have mismatched strides
                arr = np.ascontiguousarray(input_dict[name])
                # 2. Force the dtype
                if arr.dtype != target_dtype:
                    arr = arr.astype(target_dtype)
            else:
                # Pad missing inputs with zeros
                arr = np.zeros(lens, dtype=target_dtype)
            # 3. Stash the array in the dict to keep it alive
            self._keep_alive_cache[name] = arr
            # 4. Safely hand the pointer over to migraphx
            mgx_inputs[name] = migraphx.argument(arr)
        start = time.time()
        result = self.model.run(mgx_inputs)
        infer_time = time.time() - start
        outputs = [np.array(r) for r in result]
        # Inference done; release the memory
        self._keep_alive_cache.clear()
        return outputs, infer_time

# =========================
# Predict (hard-coded inputs, no tokenizer)
# =========================
def predict(model, image, box_threshold, is_benchmark=False):
    input_dict = {
        "img": np.expand_dims(np.asarray(image), axis=0),
        "position_ids": np.array([[0, 0, 1, 0]]),
        "input_ids": np.array([[101, 2482, 1012, 102]]),
        "token_type_ids": np.array([[0, 0, 0, 0]]),
        "text_token_mask": np.array([[
            [True, False, False, False],
            [False, True, True, False],
            [False, True, True, False],
            [False, False, False, True]
        ]]),
        "attention_mask": np.array([[True, True, True, True]])
    }
    outputs, infer_time = model.infer(input_dict)
    if not is_benchmark:
        print(f"Inference time: {infer_time * 1000:.2f} ms")
    logits = sigmoid(outputs[0][0])
    boxes = outputs[1][0]
    max_values = np.max(logits, axis=1)
    mask = max_values > box_threshold
    logits = logits[mask]
    boxes = boxes[mask]
    phrases = ["car"] * len(boxes)
    return boxes, np.max(logits, axis=1), phrases

# =========================
# Benchmark
# =========================
def benchmark(model, image, box_th, warmup=5, runs=10):
    print("\n🔥 Warmup")
    for _ in range(warmup):
        predict(model, image, box_th, True)
    print("\n🚀 Benchmark")
    times = []
    for i in range(runs):
        start = time.time()
        predict(model, image, box_th, True)
        times.append(time.time() - start)
    print(f"\nAverage latency: {np.mean(times) * 1000:.2f} ms")
    print(f"FPS: {1 / np.mean(times):.2f}")

# =========================
# Main
# =========================
# if __name__ == "__main__":
#     model_path = "../weights/ground_opt.onnx"
#     cache_path = "../weights/ground_opt.mxr"
#     img_path = "../images/in/car_1.jpg"
#     BOX_TRESHOLD = 0.35
#     DEVICE_ID = 5  # matches the device: 5 / 0 cases from the earlier error stack; adjust as needed
#     model = MIGraphXModel(
#         model_path,
#         cache_path=cache_path,
#         force_recompile=False,
#         device_id=DEVICE_ID
#     )
#     image_source, image = load_image(img_path)
#     benchmark(model, image, BOX_TRESHOLD)
#     boxes, confs, phrases = predict(model, image, BOX_TRESHOLD)
#     print("Detections:", phrases)

def test_like_perf(model):
    print("\n" + "=" * 60)
    print("🛠️ Emulating the perf tool: test with perfectly aligned dummy data")
    print("=" * 60)
    mgx_inputs = {}
    keep_alive_cache = []  # keep-alive pool
    # 1. Fabricate dummy data in exactly the shapes the model requires
    for name, shape in model.get_inputs().items():
        if str(name).startswith("main:#output"):
            continue
        # Resolve the actual required dtype and shape
        target_dtype, lens = _mgx_shape_to_numpy(shape)
        print(f"  📦 Allocating {name}: shape={lens}, dtype={target_dtype.__name__}")
        # Build an exactly matching all-zeros array (faithfully mimicking migraphx-driver)
        dummy_data = np.zeros(lens, dtype=target_dtype)
        keep_alive_cache.append(dummy_data)
        # Hand over the pointer
        mgx_inputs[name] = migraphx.argument(dummy_data)
    print("\n🚀 Starting dummy inference test...")
    try:
        start = time.time()
        model.run(mgx_inputs)
        print(f"✅ Python-side dummy inference succeeded, no VMFault! Took: {(time.time() - start) * 1000:.2f} ms")
    except Exception as e:
        print(f"❌ Still failing: {e}")

# ------------------
# Call it from main like this:
# ------------------
if __name__ == "__main__":
    model_path = "../weights/ground_opt.onnx"
    cache_path = "../weights/ground_opt.mxr"
    model = migraphx.load(cache_path)  # load the .mxr you know is good
    # Run the emulation test
    test_like_perf(model)
\ No newline at end of file
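The `_keep_alive_cache` trick above exists because `migraphx.argument` wraps the NumPy buffer without owning it, so the array must outlive the `run` call. Example usage of the class defined above (paths and device id are assumptions; adjust to the local setup):

```python
model = MIGraphXModel(
    "../weights/ground_opt.onnx",
    cache_path="../weights/ground_opt.mxr",
    force_recompile=False,
    device_id=0,
)
image_source, image = load_image("../images/in/car_1.jpg")
benchmark(model, image, 0.35)                      # 5 warmup + 10 timed runs
boxes, confs, phrases = predict(model, image, 0.35)
print(phrases, confs)
```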
migraphx_infer/migraphx_perf.bash (new file, 0 → 100644, view @ a1865640)
migraphx-driver perf --batch 1 \
    -n 10 \
    --fp16 \
    --migraphx ../weights/ground_opt.mxr
\ No newline at end of file
migraphx_infer/modify_onnx.py (new file, 0 → 100644, view @ a1865640)
import sys
import numpy as np
from onnx_modifier import ONNXModifier

def change_inf_to_value(om: ONNXModifier):
    records = set()
    for where_node in om.get_nodes("Where"):
        for input_name in where_node.inputs[1:]:
            init = om.get_initializer(input_name)
            if init is None:
                continue
            assert input_name == init.name
            init_name = input_name
            if init_name in records:
                continue
            # info = np.finfo(np.float32)
            info = np.finfo(np.float16)
            data = om.get_initializer_value(init.name)
            if data.size > 1:
                continue
            if data == np.inf:
                om.set_initializer_value(init_name, np.array(info.max, dtype=np.float32))
            elif data == -np.inf:
                om.set_initializer_value(init_name, np.array(info.min, dtype=np.float32))
            else:
                continue
            # print("Changed value:", init_name)
            records.add(init_name)

def optimize_where_ndoes(om: ONNXModifier):
    """Equivalent replacement of Where nodes
    (1) condition is an initializer, X is 0, Y is the input data:
        Where(cond, X, Y) ==> Mul(Y, ~cond)
    (2) condition is an initializer, X is -inf, Y is the input data:
        Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
    (3) condition is a real input, X is -inf, Y is the input data:
        Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
    cases:
    1. Where(cond, -inf, input)
       a. /transformer/encoder/fusion_layers.*/attn/Where
       b. /transformer/encoder/fusion_layers.*/attn/Where_1
       c. /class_embed.0_*/Where: Where(cond, -inf, input)
    2. Where(cond, 0, input):
       a. /transformer/encoder/layers.*/self_attn/Where
       b. /transformer/decoder/layers.*/cross_attn/Where
    """
    for where_node in om.get_nodes("Where"):
        where_name = where_node.name
        # print("Process where node:", where_name)
        x_value = om.get_initializer_value(where_node.inputs[1])
        assert x_value.size == 1
        assert x_value == np.array(0.0, dtype=np.float32) or \
               x_value == np.array(-np.inf, dtype=np.float32)
        cond_init = om.get_initializer(where_node.inputs[0])
        if cond_init is not None:
            cond_value = om.get_initializer_value(where_node.inputs[0])
            if x_value == np.array(0.0, dtype=np.float32):
                # Where(cond, X, Y) ==> Mul(Y, ~cond)
                mul_name = where_name.replace("Where", "NewMul")
                mul_b_init = om.create_initializer(mul_name + "_B",
                                                   (~cond_value).astype(np.float32))
                mul_node = om.create_node("Mul",
                                          mul_name,
                                          [where_node.inputs[2], mul_b_init.name],
                                          [mul_name + "_output_0"],
                                          index=where_node.index)
                next_nodes = where_node.next_nodes
                for next_node in next_nodes:
                    next_node.replace_input(where_node.outputs[0], mul_node.outputs[0])
            elif x_value == np.array(-np.inf, dtype=np.float32):
                # Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
                sub_name = where_name.replace("Where", "NewSub")
                sub_b_init = om.create_initializer(
                    sub_name + "_B",
                    np.where(cond_value.astype(np.float32),
                             np.finfo(np.float16).max, 0.0).astype(np.float32)
                )
                sub_node = om.create_node("Sub",
                                          sub_name,
                                          [where_node.inputs[2], sub_b_init.name],
                                          [sub_name + "_output_0"],
                                          index=where_node.index)
                next_nodes = where_node.next_nodes
                for next_node in next_nodes:
                    next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
        else:
            # Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
            assert x_value == np.array(-np.inf, dtype=np.float32)
            cast_name = where_name.replace("Where", "NewCast")
            mul_name = where_name.replace("Where", "NewMul")
            sub_name = where_name.replace("Where", "NewSub")
            cast_node = om.create_node("Cast",
                                       cast_name,
                                       [where_node.inputs[0]],
                                       [cast_name + "_output_0"],
                                       to=1,
                                       index=where_node.index)
            mul_b_init = om.create_initializer(mul_name + "_B",
                                               np.array([np.finfo(np.float16).max], np.float32))
            mul_node = om.create_node("Mul",
                                      mul_name,
                                      [cast_node.outputs[0], mul_b_init.name],
                                      [mul_name + "_output_0"],
                                      index=cast_node.index + 1)
            sub_node = om.create_node("Sub",
                                      sub_name,
                                      [where_node.inputs[2], mul_node.outputs[0]],
                                      [sub_name + "_output_0"],
                                      index=mul_node.index + 1)
            next_nodes = where_node.next_nodes
            for next_node in next_nodes:
                next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
    om.update_map()

def optimize_transpose_nodes(om: ONNXModifier):
    transpose_list = [
        "/transformer/encoder/Transpose",
        "/transformer/encoder/Transpose_1",
        "/transformer/encoder/Transpose_2",
        "/transformer/encoder/Transpose_3",
        "/transformer/encoder/Transpose_4",
        "/transformer/encoder/Transpose_5",
        "/transformer/encoder/Transpose_6",
        "/transformer/encoder/Transpose_7",
        "/transformer/encoder/Transpose_8",
        "/transformer/encoder/Transpose_9",
        "/transformer/encoder/Transpose_10",
        "/transformer/encoder/Transpose_11",
        "/transformer/decoder/layers.0/Transpose",
        "/transformer/decoder/layers.0/Transpose_1",
        "/transformer/decoder/layers.0/Transpose_2",
        "/transformer/decoder/layers.1/Transpose",
        "/transformer/decoder/layers.1/Transpose_1",
        "/transformer/decoder/layers.1/Transpose_2",
        "/transformer/decoder/layers.2/Transpose",
        "/transformer/decoder/layers.2/Transpose_1",
        "/transformer/decoder/layers.2/Transpose_2",
        "/transformer/decoder/layers.3/Transpose",
        "/transformer/decoder/layers.3/Transpose_1",
        "/transformer/decoder/layers.3/Transpose_2",
        "/transformer/decoder/layers.4/Transpose",
        "/transformer/decoder/layers.4/Transpose_1",
        "/transformer/decoder/layers.4/Transpose_2",
        "/transformer/decoder/layers.5/Transpose",
        "/transformer/decoder/layers.5/Transpose_1",
        "/transformer/decoder/layers.5/Transpose_2",
        "/transformer/Transpose_8",
        "/transformer/decoder/Transpose",
        "/transformer/decoder/Transpose_1",
        "/transformer/decoder/Transpose_2",
        "/transformer/decoder/Transpose_3",
        "/transformer/decoder/Transpose_4",
        "/transformer/decoder/Transpose_5",
        "/transformer/decoder/Transpose_6",
        "/transformer/decoder/Transpose_7",
        "/transformer/decoder/Transpose_8",
        "/transformer/decoder/Transpose_9",
        "/transformer/decoder/Transpose_10",
        "/transformer/decoder/Transpose_11"
    ]
    for name in transpose_list:
        node = om.get_node(name)
        assert node.attrs['perm'] == [1, 0, 2] or node.attrs['perm'] == [1, 0, 2, 3], \
            f"perm={node.attrs['perm']}"
        next_nodes = om.get_next_nodes(node)
        for node_ in next_nodes:
            node_.replace_input(node.outputs[0], node.inputs[0])
    # modify /transformer/encoder/text_layers.*/self_attn/Reshape_4
    # om.set_initializer_value("_v_8735", np.array([-1, 4, 256], np.int64))
    shape_init1 = om.create_initializer(
        "/transformer/encoder/text_layers.x/self_attn/des_shape",
        np.array([1, 4, 256], np.int64)
    )
    for i in range(6):
        reshape_node = om.get_node(f"/transformer/encoder/text_layers.{i}/self_attn/Reshape_4")
        reshape_node.set_input(1, shape_init1.name)
    # modify /transformer/enc_out_class_embed/Transpose
    om.get_node("/transformer/enc_out_class_embed/Transpose").set_attribute("perm", [0, 2, 1])
    # modify /transformer/decoder/Reshape_*
    om.set_initializer_value("_v_5525", np.array([1, 900, -1], np.int64))
    # modify /transformer/decoder/layers.*/self_attn/Reshape_4
    # modify /transformer/decoder/layers.*/ca_text/Reshape_6
    # om.set_initializer_value("_v_6230", np.array([-1, 900, 256], np.int64))
    shape_init3 = om.create_initializer(
        "/transformer/decoder/layers.x/self_attn_ca_text/des_shape",
        np.array([1, 900, 256], np.int64)
    )
    for i in range(6):
        reshape_node1 = om.get_node(f"/transformer/decoder/layers.{i}/self_attn/Reshape_4")
        reshape_node1.set_input(1, shape_init3.name)
        reshape_node2 = om.get_node(f"/transformer/decoder/layers.{i}/ca_text/Reshape_6")
        reshape_node2.set_input(1, shape_init3.name)
    # modify /transformer/decoder/layers.0/Add
    # modify /transformer/decoder/layers.0/Add_1
    init_name = "/transformer/Tile_1_output_0"
    add_value = om.get_initializer_value(init_name)
    om.set_initializer_value(init_name, np.ascontiguousarray(add_value.transpose(1, 0, 2)))
    om.update_map()
    om.infer_shape()

def optmize_sin_cos_block(om: ONNXModifier):
    node_pairs = [
        ("/transformer/decoder/Gather_1", "/transformer/decoder/ref_point_head/layers.0/MatMul"),
        ("/transformer/decoder/Gather_6", "/transformer/decoder/ref_point_head/layers.0_1/MatMul"),
        ("/transformer/decoder/Gather_11", "/transformer/decoder/ref_point_head/layers.0_2/MatMul"),
        ("/transformer/decoder/Gather_16", "/transformer/decoder/ref_point_head/layers.0_3/MatMul"),
        ("/transformer/decoder/Gather_21", "/transformer/decoder/ref_point_head/layers.0_4/MatMul"),
        ("/transformer/decoder/Gather_26", "/transformer/decoder/ref_point_head/layers.0_5/MatMul"),
    ]
    unsqueeze_axes_init1 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/unsqueeze_axes1", np.array([3, 4], np.int64))
    slice_axes_init = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_axes", np.array([4], np.int64))
    slice_steps_init = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_steps", np.array([1], np.int64))
    slice_starts_init1 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_starts1", np.array([0], np.int64))
    slice_ends_init1 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_ends1", np.array([1], np.int64))
    slice_starts_init2 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_steps2", np.array([1], np.int64))
    slice_ends_init2 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_ends2", np.array([2], np.int64))
    reshape_init = om.create_initializer(
        "/transformer/decoder/sin_cos_block/reshape_dst_shape", np.array([1, 900, -1], np.int64))
    for i, (gather_name, matmul_name) in enumerate(node_pairs):
        gather_node = om.get_node(gather_name)
        next_node = om.get_next_nodes(gather_node)[0]
        assert next_node.op_type == "Mul", f"{next_node.op_type} {next_node.name}"
        mul_init_value = om.get_initializer_value(next_node.inputs[1])
        assert mul_init_value.size == 1
        next_node = om.get_next_nodes(next_node)[0]
        assert next_node.op_type == "Unsqueeze"
        next_node.set_inputs([gather_node.inputs[0], unsqueeze_axes_init1.name])
        next_node = om.get_next_nodes(next_node)[0]
        assert next_node.op_type == "Div"
        div_init_value = om.get_initializer_value(next_node.inputs[1])
        new_value = (div_init_value / mul_init_value).reshape(1, 1, 1, 64, 2)
        new_init = om.create_initializer(next_node.name + "_B", new_value)
        next_node.set_input(1, new_init.name)
        next_nodes = om.get_next_nodes(next_node)
        assert len(next_nodes) == 2 and all(x.op_type == 'Slice' for x in next_nodes)
        sin_node, cos_node = None, None
        for j, slice_node in enumerate(next_nodes):
            slice_node.set_inputs([slice_node.inputs[0],
                                   slice_starts_init1.name if j == 0 else slice_starts_init2.name,
                                   slice_ends_init1.name if j == 0 else slice_ends_init2.name,
                                   slice_axes_init.name,
                                   slice_steps_init.name])
            next_node = om.get_next_nodes(slice_node)[0]
            if next_node.op_type == "Sin":
                sin_node = next_node
            elif next_node.op_type == "Cos":
                cos_node = next_node
            else:
                raise RuntimeError("match fail!")
        next_node = om.get_next_nodes(next_node)[0]
        assert next_node.op_type == "Unsqueeze"
        next_node = om.get_next_nodes(next_node)[0]
        assert next_node.op_type == "Concat"
        next_node.set_inputs([sin_node.outputs[0], cos_node.outputs[0]])
        next_node.set_attribute("axis", 4)
        next_node = om.get_next_nodes(next_node)[0]
        assert next_node.op_type == "Reshape"
        next_node.set_input(1, reshape_init.name)
        matmul_node = om.get_node(matmul_name)
        matmul_node.set_input(0, next_node.outputs[0])
        if i == 0:
            mm_b_value = om.get_initializer_value(matmul_node.inputs[1])
            mm_b_value = np.concatenate([mm_b_value[128:256, ...],
                                         mm_b_value[0:128, ...],
                                         mm_b_value[256:, ...]], axis=0)
            om.set_initializer_value(matmul_node.inputs[1], mm_b_value)
    om.update_map()
    om.infer_shape()

def fuse_one_attention(om: ONNXModifier, softmax_name: str, new_mask: bool = None, num_heads: int = 12):
    softmax_node = om.get_node(softmax_name)
    tmp_node = om.get_prev_nodes(softmax_node)[0]
    assert tmp_node.op_type in ["MatMul", "Add"]
    mask = None
    if tmp_node.op_type == "Add":
        mask_node = tmp_node
        tmp_node = om.get_from_node(mask_node.inputs[0])
        if tmp_node.op_type == "Div":
            tmp_node = om.get_from_node(tmp_node.inputs[0])
        assert tmp_node.op_type == "MatMul"
        mask = mask_node.inputs[1]
        assert new_mask is not None
    tmp_node1 = om.get_from_node(tmp_node.inputs[0])
    if tmp_node1.op_type == "Mul":
        tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
    tmp_node2 = om.get_from_node(tmp_node.inputs[1])
    assert tmp_node1.op_type == tmp_node2.op_type == "Transpose"
    tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
    tmp_node2 = om.get_prev_nodes(tmp_node2)[0]
    assert tmp_node1.op_type == tmp_node2.op_type == "Reshape"
    q, k = tmp_node1.inputs[0], tmp_node2.inputs[0]
    tmp_node = om.get_next_nodes(softmax_node)[0]
    assert tmp_node.op_type == "MatMul"
    tmp_node3 = om.get_from_node(tmp_node.inputs[1])
    if tmp_node3 is not None:
        assert tmp_node3.op_type == "Transpose"
        tmp_node3 = om.get_prev_nodes(tmp_node3)[0]
        assert tmp_node3.op_type == "Reshape"
        v = tmp_node3.inputs[0]
    else:
        v_init = om.get_initializer(tmp_node.inputs[1])
        v_init_value = om.get_initializer_value(tmp_node.inputs[1])
        v_init_value = v_init_value[None, ...].transpose(0, 2, 1, 3)
        B, S, H, D = v_init_value.shape
        v_init_value = np.ascontiguousarray(v_init_value.reshape(B, S, H * D))
        om.set_initializer_value(tmp_node.inputs[1], v_init_value)
        v = v_init.name
    tmp_node = om.get_next_nodes(tmp_node)[0]
    assert tmp_node.op_type == "Transpose"
    tmp_node = om.get_next_nodes(tmp_node)[0]
    assert tmp_node.op_type == "Reshape"
    mha_next_node = om.get_next_nodes(tmp_node)[0]
    if mha_next_node.op_type == "Gemm":
        gemm_next_node = om.get_next_nodes(mha_next_node)[0]
        assert gemm_next_node.op_type == "Reshape"
        reshape_next_node = om.get_next_nodes(gemm_next_node)[0]
        assert reshape_next_node.op_type == "Add"
    else:
        assert mha_next_node.op_type == "MatMul"
    name_prefix = '/'.join(softmax_name.split('/')[:-1])
    mha_name = f"{name_prefix}/MultiHeadAttention"
    mha_node = om.create_node("MultiHeadAttention",
                              mha_name,
                              [q, k, v] if mask is None else [q, k, v, new_mask],
                              [mha_name + '_output_0'],
                              num_heads=num_heads,
                              domain="com.microsoft",
                              index=mha_next_node.index - 1)
    mha_next_node.replace_input(mha_next_node.inputs[0], mha_node.outputs[0])
    if mha_next_node.op_type == "Gemm":
        weights = om.get_initializer_value(mha_next_node.inputs[1])
        transB = mha_next_node.attrs["transB"]
        assert transB == 1
        weights = np.ascontiguousarray(weights.transpose(1, 0))
        om.set_initializer_value(mha_next_node.inputs[1], weights)
        new_matmul_name = mha_next_node.name.replace("Gemm", "MatMul(Gemm)")
        new_matmul_node = om.create_node("MatMul",
                                         new_matmul_name,
                                         [mha_node.outputs[0], mha_next_node.inputs[1]],
                                         [new_matmul_name + "_output_0"],
                                         index=mha_next_node.index)
        new_bias_name = mha_next_node.name.replace("Gemm", "Add(Gemm)")
        new_add_node = om.create_node("Add",
                                      new_bias_name,
                                      [new_matmul_node.outputs[0], mha_next_node.inputs[2]],
                                      [new_bias_name + "_output_0"],
                                      index=new_matmul_node.index + 1)
        reshape_next_node.replace_input(gemm_next_node.outputs[0], new_add_node.outputs[0])

def optimize_normal_attention(om: ONNXModifier):
    def create_new_attention_mask():
        mask_next_node = om.get_to_nodes("attention_mask")[0]
        cast_node = om.create_node("Cast",
                                   "Cast_for_attention_mask",
                                   ["attention_mask"],
                                   ["Cast_for_attention_mask_output_0"],
                                   to=1,
                                   index=mask_next_node.index)
        reducesum_node = om.create_node("ReduceSum",
                                        "ReduceSum_for_mask",
                                        [cast_node.outputs[0]],
                                        ["ReduceSum_for_mask_output_0"],
                                        axes=1,
                                        keepdims=0,
                                        index=cast_node.index + 1)
        return reducesum_node.outputs[0]
    # bert
    # for i in range(12):
    #     fuse_one_attention(om, f"/bert/encoder/layer.{i}/attention/self/Softmax", "text_token_mask", num_heads=12)
    new_mask = create_new_attention_mask()
    for i in range(6):
        # /transformer/encoder
        # fuse_one_attention(om, f"/transformer/encoder/text_layers.{i}/self_attn/Softmax", "text_token_mask", num_heads=4)
        # /transformer/decoder
        fuse_one_attention(om, f"/transformer/decoder/layers.{i}/self_attn/Softmax", new_mask, num_heads=8)
        fuse_one_attention(om, f"/transformer/decoder/layers.{i}/ca_text/Softmax", new_mask, num_heads=8)
    om.update_map()

def optimize_backbone_attention(om: ONNXModifier):
    def get_original_mask(mask_name, name_prefix):
        mask_value = om.get_initializer_value(mask_name)
        orig_mask = np.where(mask_value == 0, 1, 0).astype(np.bool_)
        orig_mask_init = om.create_initializer(f"{name_prefix}/mask", orig_mask)
        return orig_mask_init.name

    def _fuse_one_attention(softmax_name: str):
        name_prefix = '/'.join(softmax_name.split('/')[:-1])
        softmax_node = om.get_node(softmax_name)
        tmp_node = om.get_prev_nodes(softmax_node)[0]
        pos_bias_init = None
        if tmp_node.op_type == "Reshape":
            tmp_node = om.get_prev_nodes(tmp_node)[0]
            assert tmp_node.op_type == "Add"
            pos_bias_init = om.get_initializer(tmp_node.inputs[1])
            tmp_node = om.get_prev_nodes(tmp_node)[0]
            assert tmp_node.op_type == "Reshape"
            tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Add"
        mask = get_original_mask(tmp_node.inputs[1], name_prefix)
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "MatMul"
        qk_matmul = tmp_node
        tmp_node = om.get_from_node(qk_matmul.inputs[0])
        assert tmp_node.op_type == "Mul"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Gather"
        q_gather_node = tmp_node
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Reshape"
        reshape_node = tmp_node
        tmp_node = om.get_from_node(qk_matmul.inputs[1])
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Gather"
        k_gather_node = tmp_node
        tmp_node = om.get_next_nodes(softmax_node)[0]
        assert tmp_node.op_type == "MatMul"
        v_gather_node = om.get_from_node(tmp_node.inputs[1])
        assert v_gather_node.op_type == "Gather"
        tmp_node = om.get_next_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_next_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Reshape"
        mha_out = tmp_node.outputs[0]
        old_dst_shape = om.get_initializer_value(reshape_node.inputs[1])
        b, s, _, h, d = old_dst_shape
        new_dst_shape = [b, s, _, h * d]
        new_dst_shape_init = om.create_initializer(
            f"{name_prefix}/qkv_hidden_states_shape", np.array(new_dst_shape, np.int64))
        reshape_node.set_input(1, new_dst_shape_init.name)
        for node in [q_gather_node, k_gather_node, v_gather_node]:
            node.set_input(0, reshape_node.outputs[0])
            node.set_attribute("axis", 2)
        mha_name = f"{name_prefix}/MultiHeadAttention"
        inputs = [q_gather_node.outputs[0], k_gather_node.outputs[0], v_gather_node.outputs[0], mask]
        if pos_bias_init is not None:
            inputs.append(pos_bias_init.name)
        mha_node = om.create_node("MultiHeadAttention",
                                  mha_name,
                                  inputs,
                                  [mha_name + '_output_0'],
                                  num_heads=h,
                                  domain="com.microsoft",
                                  index=softmax_node.index)
        mha_next_node = om.get_to_nodes(mha_out)[0]
        mha_next_node.replace_input(mha_out, mha_node.outputs[0])

    num_layers = 4
    for l in range(num_layers):
        num_blocks = 18 if l == 2 else 2
        for b in range(num_blocks):
            _fuse_one_attention(f"/backbone/backbone.0/layers.{l}/blocks.{b}/attn/softmax/Softmax")

def optimize_ms_deform_attn(om: ONNXModifier):
    def fuse_ms_deform_attn(value, spatial_shapes, level_start_index,
                            sampling_locations, attention_weights, output):
        value_next_node = om.get_to_nodes(value)[0]
        index = value_next_node.index
        name_prefix = '/'.join(value.split('/')[:-1])
        node_name = f"{name_prefix}/MSDeformAttn"
        fusion_node = om.create_node("MSDeformAttn",
                                     node_name,
                                     [value, spatial_shapes, level_start_index,
                                      sampling_locations, attention_weights],
                                     [f"{node_name}_output_0"],
                                     index=index)
        next_nodes = om.get_to_nodes(output)
        for node in next_nodes:
            node.replace_input(output, fusion_node.outputs[0])

    spatial_shapes_int = om.create_initializer(
        "/transformer/spatial_shapes",
        np.array([(100, 150), (50, 75), (25, 38), (13, 19)], dtype=np.int64)
    )
    level_start_index_init = om.create_initializer(
        "/transformer/level_start_index",
        np.array([0, 15000, 18750, 19700], dtype=np.int64)
    )
    for i in range(6):
        fuse_ms_deform_attn(
            f"/transformer/encoder/layers.{i}/self_attn/Reshape_output_0",
            spatial_shapes_int.name,
            level_start_index_init.name,
            f"/transformer/encoder/layers.{i}/self_attn/Add_output_0",
            f"/transformer/encoder/layers.{i}/self_attn/Reshape_3_output_0",
            f"/transformer/encoder/layers.{i}/self_attn/Transpose_9_output_0"
        )
        fuse_ms_deform_attn(
            f"/transformer/decoder/layers.{i}/cross_attn/Reshape_output_0",
            spatial_shapes_int.name,
            level_start_index_init.name,
            f"/transformer/decoder/layers.{i}/cross_attn/Add_output_0",
            f"/transformer/decoder/layers.{i}/cross_attn/Reshape_3_output_0",
            f"/transformer/decoder/layers.{i}/cross_attn/Transpose_9_output_0"
        )
    om.update_map()

def optimize_bidirect_attention(om: ONNXModifier):
    for i in range(6):
        reduce_max_name = f"/transformer/encoder/fusion_layers.{i}/attn/ReduceMax_1"
        reduce_max_node = om.get_node(reduce_max_name)
        next_node = om.get_next_nodes(reduce_max_node)[0]
        assert next_node.op_type == "Sub"
        name_prefix = '/'.join(reduce_max_name.split('/')[:-1])
        matmul_name = f"{name_prefix}/identity_MatMul"
        matmul_init = om.create_initializer(matmul_name + "_B",
                                            np.diag(np.array([1] * 1)).astype(np.float32))
        matmul_node = om.create_node("MatMul",
                                     matmul_name,
                                     [reduce_max_node.outputs[0], matmul_init.name],
                                     [f"{matmul_name}_output_0"],
                                     index=reduce_max_node.index + 1)
        next_node.set_input(1, matmul_node.outputs[0])

def main():
    input_onnx_path = sys.argv[1]
    output_onnx_path = sys.argv[2]
    # input_onnx_path = "ground_sim.onnx"
    # output_onnx_path = "ground_sim_0424_2nd.onnx"
    om = ONNXModifier(input_onnx_path)
    optimize_where_ndoes(om)           # 1. Replace Where nodes
    optimize_transpose_nodes(om)       # 2. Optimize Transpose nodes
    optmize_sin_cos_block(om)          # 3. Optimize the positional encoding
    # om.add_opset_import("com.microsoft", 1)
    # optimize_normal_attention(om)    # 4. Fuse MHA in bert/transformer
    # optimize_ms_deform_attn(om)      # 5. Fuse multi-scale deformable attention
    # optimize_backbone_attention(om)  # 6. Fuse attention in the backbone
    optimize_bidirect_attention(om)    # 7. Optimize bi-directional attention
    om.save(output_onnx_path, save_as_external_data=False)

if __name__ == "__main__":
    main()
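The Where rewrites in `optimize_where_ndoes` are numeric equivalences, so they can be spot-checked directly in NumPy; FP16_MAX stands in for infinity exactly as the script substitutes it:

```python
import numpy as np

cond = np.array([True, False, True, False])
y = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
FP16_MAX = np.finfo(np.float16).max

# (1) Where(cond, 0, Y) == Mul(Y, ~cond)
assert np.array_equal(np.where(cond, 0.0, y), y * (~cond).astype(np.float32))

# (2) Where(cond, -inf, Y) ~= Sub(Y, Where(cond, FP16_MAX, 0))
ref = np.where(cond, -np.inf, y)
sub = y - np.where(cond, FP16_MAX, 0.0).astype(np.float32)
# Masked positions become hugely negative (about -65504) instead of -inf,
# which a following Softmax treats the same way.
assert np.allclose(ref[~cond], sub[~cond]) and (sub[cond] < -6e4).all()
```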
migraphx_infer/modify_onnx1.py
0 → 100644
View file @
a1865640
import
sys
import
numpy
as
np
from
onnx_modifier
import
ONNXModifier
def
change_inf_to_value
(
om
:
ONNXModifier
):
records
=
set
()
for
where_node
in
om
.
get_nodes
(
"Where"
):
for
input_name
in
where_node
.
inputs
[
1
:]:
init
=
om
.
get_initializer
(
input_name
)
if
init
is
None
:
continue
assert
input_name
==
init
.
name
init_name
=
input_name
if
init_name
in
records
:
continue
# info = np.finfo(np.float32)
info
=
np
.
finfo
(
np
.
float16
)
data
=
om
.
get_initializer_value
(
init
.
name
)
if
data
.
size
>
1
:
continue
if
data
==
np
.
inf
:
om
.
set_initializer_value
(
init_name
,
np
.
array
(
info
.
max
,
dtype
=
np
.
float32
))
elif
data
==
-
np
.
inf
:
om
.
set_initializer_value
(
init_name
,
np
.
array
(
info
.
min
,
dtype
=
np
.
float32
))
else
:
continue
# print("Changed value:", init_name)
records
.
add
(
init_name
)
# def optimize_where_ndoes(om: ONNXModifier):
# """Where节点等价替换
# (1) condition为initializer, X为0, Y为输入数据:
# Where(cond, X, Y) ==> Mul(Y, ~cond)
# (2) condition为initializer, X为负无穷, Y为输入数据
# Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
# (3) condition为真实输入, X为负无穷, Y为输入数据
# Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
# cases:
# 1. Where(cond, -inf, input)
# a. /transformer/encoder/fusion_layers.*/attn/Where
# b. /transformer/encoder/fusion_layers.*/attn/Where_1
# c. /class_embed.0_*/Where: Where(cond, -inf, input)
# 2. Where(cond, 0, input):
# a. /transformer/encoder/layers.*/self_attn/Where
# b. /transformer/decoder/layers.*/cross_attn/Where
# """
# for where_node in om.get_nodes("Where"):
# where_name = where_node.name
# # print("Process where node:", where_name)
# x_value = om.get_initializer_value(where_node.inputs[1])
# assert x_value.size == 1
# assert x_value == np.array(0.0, dtype=np.float32) or \
# x_value == np.array(-np.inf, dtype=np.float32)
# cond_init = om.get_initializer(where_node.inputs[0])
# if cond_init is not None:
# cond_value = om.get_initializer_value(where_node.inputs[0])
# if x_value == np.array(0.0, dtype=np.float32):
# # Where(cond, X, Y) ==> Mul(Y, ~cond)
# mul_name = where_name.replace("Where", "NewMul")
# mul_b_init = om.create_initializer(mul_name + "_B",
# (~cond_value).astype(np.float32))
# mul_node = om.create_node("Mul",
# mul_name,
# [where_node.inputs[2], mul_b_init.name],
# [mul_name+"_output_0"],
# index=where_node.index)
# next_nodes = where_node.next_nodes
# for next_node in next_nodes:
# next_node.replace_input(where_node.outputs[0], mul_node.outputs[0])
# elif x_value == np.array(-np.inf, dtype=np.float32):
# # Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
# sub_name = where_name.replace("Where", "NewSub")
# sub_b_init = om.create_initializer(
# sub_name + "_B",
# np.where(cond_value.astype(np.float32),
# np.finfo(np.float16).max, 0.0).astype(np.float32)
# )
# sub_node = om.create_node("Sub",
# sub_name,
# [where_node.inputs[2], sub_b_init.name],
# [sub_name+"_output_0"],
# index=where_node.index)
# next_nodes = where_node.next_nodes
# for next_node in next_nodes:
# next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
# else:
# # Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
# assert x_value == np.array(-np.inf, dtype=np.float32)
# cast_name = where_name.replace("Where", "NewCast")
# mul_name = where_name.replace("Where", "NewMul")
# sub_name = where_name.replace("Where", "NewSub")
# cast_node = om.create_node("Cast",
# cast_name,
# [where_node.inputs[0]],
# [cast_name+"_output_0"],
# to=1,
# index=where_node.index)
# mul_b_init = om.create_initializer(mul_name + "_B",
# np.array([np.finfo(np.float16).max], np.float32))
# mul_node = om.create_node("Mul",
# mul_name,
# [cast_node.outputs[0], mul_b_init.name],
# [mul_name+"_output_0"],
# index=cast_node.index+1)
# sub_node = om.create_node("Sub",
# sub_name,
# [where_node.inputs[2], mul_node.outputs[0]],
# [sub_name+"_output_0"],
# index=mul_node.index+1)
# next_nodes = where_node.next_nodes
# for next_node in next_nodes:
# next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
# om.update_map()
def
optimize_where_ndoes
(
om
:
ONNXModifier
):
"""Where节点等价替换 (加入安全校验版本)"""
for
where_node
in
om
.
get_nodes
(
"Where"
):
where_name
=
where_node
.
name
# 1. 安全获取 X 的值,如果 X 不是常量(initializer),直接跳过不优化
x_init
=
om
.
get_initializer
(
where_node
.
inputs
[
1
])
if
x_init
is
None
:
continue
x_value
=
om
.
get_initializer_value
(
where_node
.
inputs
[
1
])
# 2. 避免 assert 崩溃:如果 size 不为 1,说明不是我们要找的 Attention Mask 节点,跳过
if
x_value
.
size
!=
1
:
continue
# 3. 判断是否符合优化条件(0.0 或 -inf),不符合直接跳过
is_zero
=
(
x_value
==
np
.
array
(
0.0
,
dtype
=
np
.
float32
))
is_neg_inf
=
(
x_value
==
np
.
array
(
-
np
.
inf
,
dtype
=
np
.
float32
))
if
not
(
is_zero
or
is_neg_inf
):
continue
cond_init
=
om
.
get_initializer
(
where_node
.
inputs
[
0
])
if
cond_init
is
not
None
:
cond_value
=
om
.
get_initializer_value
(
where_node
.
inputs
[
0
])
if
is_zero
:
# Where(cond, X, Y) ==> Mul(Y, ~cond)
mul_name
=
where_name
.
replace
(
"Where"
,
"NewMul"
)
mul_b_init
=
om
.
create_initializer
(
mul_name
+
"_B"
,
(
~
cond_value
).
astype
(
np
.
float32
))
mul_node
=
om
.
create_node
(
"Mul"
,
mul_name
,
[
where_node
.
inputs
[
2
],
mul_b_init
.
name
],
[
mul_name
+
"_output_0"
],
index
=
where_node
.
index
)
next_nodes
=
where_node
.
next_nodes
for
next_node
in
next_nodes
:
next_node
.
replace_input
(
where_node
.
outputs
[
0
],
mul_node
.
outputs
[
0
])
elif
is_neg_inf
:
# Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
sub_name
=
where_name
.
replace
(
"Where"
,
"NewSub"
)
sub_b_init
=
om
.
create_initializer
(
sub_name
+
"_B"
,
np
.
where
(
cond_value
.
astype
(
np
.
float32
),
np
.
finfo
(
np
.
float16
).
max
,
0.0
).
astype
(
np
.
float32
)
)
sub_node
=
om
.
create_node
(
"Sub"
,
sub_name
,
[
where_node
.
inputs
[
2
],
sub_b_init
.
name
],
[
sub_name
+
"_output_0"
],
index
=
where_node
.
index
)
next_nodes
=
where_node
.
next_nodes
for
next_node
in
next_nodes
:
next_node
.
replace_input
(
where_node
.
outputs
[
0
],
sub_node
.
outputs
[
0
])
else
:
# Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
# 当 condition 不是 initializer 时,只处理 -inf 的情况
if
not
is_neg_inf
:
continue
cast_name
=
where_name
.
replace
(
"Where"
,
"NewCast"
)
mul_name
=
where_name
.
replace
(
"Where"
,
"NewMul"
)
sub_name
=
where_name
.
replace
(
"Where"
,
"NewSub"
)
cast_node
=
om
.
create_node
(
"Cast"
,
cast_name
,
[
where_node
.
inputs
[
0
]],
[
cast_name
+
"_output_0"
],
to
=
1
,
index
=
where_node
.
index
)
mul_b_init
=
om
.
create_initializer
(
mul_name
+
"_B"
,
np
.
array
([
np
.
finfo
(
np
.
float16
).
max
],
np
.
float32
))
mul_node
=
om
.
create_node
(
"Mul"
,
mul_name
,
[
cast_node
.
outputs
[
0
],
mul_b_init
.
name
],
[
mul_name
+
"_output_0"
],
index
=
cast_node
.
index
+
1
)
sub_node
=
om
.
create_node
(
"Sub"
,
sub_name
,
[
where_node
.
inputs
[
2
],
mul_node
.
outputs
[
0
]],
[
sub_name
+
"_output_0"
],
index
=
mul_node
.
index
+
1
)
next_nodes
=
where_node
.
next_nodes
for
next_node
in
next_nodes
:
next_node
.
replace_input
(
where_node
.
outputs
[
0
],
sub_node
.
outputs
[
0
])
om
.
update_map
()
def optimize_transpose_nodes(om: ONNXModifier):
    transpose_list = [
        "/transformer/encoder/Transpose", "/transformer/encoder/Transpose_1",
        "/transformer/encoder/Transpose_2", "/transformer/encoder/Transpose_3",
        "/transformer/encoder/Transpose_4", "/transformer/encoder/Transpose_5",
        "/transformer/encoder/Transpose_6", "/transformer/encoder/Transpose_7",
        "/transformer/encoder/Transpose_8", "/transformer/encoder/Transpose_9",
        "/transformer/encoder/Transpose_10", "/transformer/encoder/Transpose_11",
        "/transformer/decoder/layers.0/Transpose", "/transformer/decoder/layers.0/Transpose_1",
        "/transformer/decoder/layers.0/Transpose_2",
        "/transformer/decoder/layers.1/Transpose", "/transformer/decoder/layers.1/Transpose_1",
        "/transformer/decoder/layers.1/Transpose_2",
        "/transformer/decoder/layers.2/Transpose", "/transformer/decoder/layers.2/Transpose_1",
        "/transformer/decoder/layers.2/Transpose_2",
        "/transformer/decoder/layers.3/Transpose", "/transformer/decoder/layers.3/Transpose_1",
        "/transformer/decoder/layers.3/Transpose_2",
        "/transformer/decoder/layers.4/Transpose", "/transformer/decoder/layers.4/Transpose_1",
        "/transformer/decoder/layers.4/Transpose_2",
        "/transformer/decoder/layers.5/Transpose", "/transformer/decoder/layers.5/Transpose_1",
        "/transformer/decoder/layers.5/Transpose_2",
        "/transformer/Transpose_8",
        "/transformer/decoder/Transpose", "/transformer/decoder/Transpose_1",
        "/transformer/decoder/Transpose_2", "/transformer/decoder/Transpose_3",
        "/transformer/decoder/Transpose_4", "/transformer/decoder/Transpose_5",
        "/transformer/decoder/Transpose_6", "/transformer/decoder/Transpose_7",
        "/transformer/decoder/Transpose_8", "/transformer/decoder/Transpose_9",
        "/transformer/decoder/Transpose_10", "/transformer/decoder/Transpose_11",
    ]
    for name in transpose_list:
        node = om.get_node(name)
        # Safety check: if the node is missing, this model does not contain the pattern; skip it
        if node is None:
            continue
        if 'perm' in node.attrs and (node.attrs['perm'] == [1, 0, 2] or node.attrs['perm'] == [1, 0, 2, 3]):
            next_nodes = om.get_next_nodes(node)
            for node_ in next_nodes:
                node_.replace_input(node.outputs[0], node.inputs[0])
    # modify /transformer/encoder/text_layers.*/self_attn/Reshape_4
    shape_init1 = om.create_initializer(
        "/transformer/encoder/text_layers.x/self_attn/des_shape",
        np.array([1, 4, 256], np.int64)
    )
    for i in range(6):
        reshape_node = om.get_node(f"/transformer/encoder/text_layers.{i}/self_attn/Reshape_4")
        if reshape_node is not None:
            reshape_node.set_input(1, shape_init1.name)
    # modify /transformer/enc_out_class_embed/Transpose
    trans_node = om.get_node("/transformer/enc_out_class_embed/Transpose")
    if trans_node is not None:
        trans_node.set_attribute("perm", [0, 2, 1])
    # modify /transformer/decoder/Reshape_*
    # Safety check: the hard-coded auto-generated name _v_5525 may be absent; avoid a crash
    init_5525 = om.get_initializer("_v_5525")
    if init_5525 is not None:
        om.set_initializer_value("_v_5525", np.array([1, 900, -1], np.int64))
    # modify /transformer/decoder/layers.*/self_attn/Reshape_4
    # modify /transformer/decoder/layers.*/ca_text/Reshape_6
    shape_init3 = om.create_initializer(
        "/transformer/decoder/layers.x/self_attn_ca_text/des_shape",
        np.array([1, 900, 256], np.int64)
    )
    for i in range(6):
        reshape_node1 = om.get_node(f"/transformer/decoder/layers.{i}/self_attn/Reshape_4")
        if reshape_node1 is not None:
            reshape_node1.set_input(1, shape_init3.name)
        reshape_node2 = om.get_node(f"/transformer/decoder/layers.{i}/ca_text/Reshape_6")
        if reshape_node2 is not None:
            reshape_node2.set_input(1, shape_init3.name)
    # modify /transformer/decoder/layers.0/Add
    # modify /transformer/decoder/layers.0/Add_1
    init_name = "/transformer/Tile_1_output_0"
    tile_init = om.get_initializer(init_name)
    if tile_init is not None:
        add_value = om.get_initializer_value(init_name)
        om.set_initializer_value(init_name, np.ascontiguousarray(add_value.transpose(1, 0, 2)))
    om.update_map()
    # Wrap shape inference so the custom op (MSDeformAttn) cannot crash it
    try:
        om.infer_shape(strict_mode=False)
    except Exception as e:
        print(f"[Warning] infer_shape skipped (possibly caused by custom ops). Details: {e}")
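
# Bypassing a Transpose is just rewiring its consumers to its input:
#   for consumer in om.get_next_nodes(node):
#       consumer.replace_input(node.outputs[0], node.inputs[0])
# This is only shape-legal here because the downstream Reshape targets
# ([1, 4, 256] and [1, 900, 256]) and the transposed Tile initializer are
# rewritten above to match the new batch-first layout.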
def optimize_sin_cos_block(om: ONNXModifier):
    node_pairs = [
        ("/transformer/decoder/Gather_1", "/transformer/decoder/ref_point_head/layers.0/MatMul"),
        ("/transformer/decoder/Gather_6", "/transformer/decoder/ref_point_head/layers.0_1/MatMul"),
        ("/transformer/decoder/Gather_11", "/transformer/decoder/ref_point_head/layers.0_2/MatMul"),
        ("/transformer/decoder/Gather_16", "/transformer/decoder/ref_point_head/layers.0_3/MatMul"),
        ("/transformer/decoder/Gather_21", "/transformer/decoder/ref_point_head/layers.0_4/MatMul"),
        ("/transformer/decoder/Gather_26", "/transformer/decoder/ref_point_head/layers.0_5/MatMul"),
    ]
    # Pre-create the shared initializers
    unsqueeze_axes_init1 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/unsqueeze_axes1", np.array([3, 4], np.int64))
    slice_axes_init = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_axes", np.array([4], np.int64))
    slice_steps_init = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_steps", np.array([1], np.int64))
    slice_starts_init1 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_starts1", np.array([0], np.int64))
    slice_ends_init1 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_ends1", np.array([1], np.int64))
    slice_starts_init2 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_steps2", np.array([1], np.int64))
    slice_ends_init2 = om.create_initializer(
        "/transformer/decoder/sin_cos_block/slice_ends2", np.array([2], np.int64))
    reshape_init = om.create_initializer(
        "/transformer/decoder/sin_cos_block/reshape_dst_shape", np.array([1, 900, -1], np.int64))
    for i, (gather_name, matmul_name) in enumerate(node_pairs):
        gather_node = om.get_node(gather_name)
        matmul_node = om.get_node(matmul_name)
        # Safety check: if this node pair is missing, the block cannot (or need not) be optimized; skip it
        if gather_node is None or matmul_node is None:
            continue
        try:
            next_node = om.get_next_nodes(gather_node)[0]
            if next_node.op_type != "Mul":
                continue
            mul_init_value = om.get_initializer_value(next_node.inputs[1])
            if mul_init_value.size != 1:
                continue
            next_node = om.get_next_nodes(next_node)[0]
            if next_node.op_type != "Unsqueeze":
                continue
            next_node.set_inputs([gather_node.inputs[0], unsqueeze_axes_init1.name])
            next_node = om.get_next_nodes(next_node)[0]
            if next_node.op_type != "Div":
                continue
            div_init_value = om.get_initializer_value(next_node.inputs[1])
            new_value = (div_init_value / mul_init_value).reshape(1, 1, 1, 64, 2)
            new_init = om.create_initializer(next_node.name + "_B", new_value)
            next_node.set_input(1, new_init.name)
            next_nodes = om.get_next_nodes(next_node)
            if len(next_nodes) != 2 or not all(x.op_type == 'Slice' for x in next_nodes):
                continue
            sin_node, cos_node = None, None
            for j, slice_node in enumerate(next_nodes):
                slice_node.set_inputs([
                    slice_node.inputs[0],
                    slice_starts_init1.name if j == 0 else slice_starts_init2.name,
                    slice_ends_init1.name if j == 0 else slice_ends_init2.name,
                    slice_axes_init.name,
                    slice_steps_init.name
                ])
                n_node = om.get_next_nodes(slice_node)[0]
                if n_node.op_type == "Sin":
                    sin_node = n_node
                elif n_node.op_type == "Cos":
                    cos_node = n_node
                else:
                    raise RuntimeError("match fail!")
                n_node = om.get_next_nodes(n_node)[0]
                n_node = om.get_next_nodes(n_node)[0]
            next_node = n_node  # Concat node
            if next_node.op_type != "Concat":
                continue
            next_node.set_inputs([sin_node.outputs[0], cos_node.outputs[0]])
            next_node.set_attribute("axis", 4)
            next_node = om.get_next_nodes(next_node)[0]
            if next_node.op_type != "Reshape":
                continue
            next_node.set_input(1, reshape_init.name)
            matmul_node.set_input(0, next_node.outputs[0])
            if i == 0:
                mm_b_value = om.get_initializer_value(matmul_node.inputs[1])
                mm_b_value = np.concatenate([mm_b_value[128:256, ...],
                                             mm_b_value[0:128, ...],
                                             mm_b_value[256:, ...]], axis=0)
                om.set_initializer_value(matmul_node.inputs[1], mm_b_value)
        except Exception as e:
            # If pattern matching hits any unexpected shape or missing node, silently skip this block
            continue
    om.update_map()
    try:
        om.infer_shape(strict_mode=False)
    except:
        pass
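
# Note on the weight shuffle for i == 0: after fusing, the sin/cos halves are
# concatenated on a new trailing axis instead of the original interleaved
# ordering, which reorders the rows that the ref_point_head MatMul consumes.
# Permuting rows [128:256] ahead of [0:128] compensates, so the MatMul output
# is unchanged. Doing it only once appears to rely on the six ref_point_head
# MatMuls sharing one weight initializer in the simplified graph (an
# assumption from this export, not a general rule).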
def fuse_one_attention(om: ONNXModifier, softmax_name: str, new_mask: str = None, num_heads: int = 12):
    softmax_node = om.get_node(softmax_name)
    tmp_node = om.get_prev_nodes(softmax_node)[0]
    assert tmp_node.op_type in ["MatMul", "Add"]
    mask = None
    if tmp_node.op_type == "Add":
        mask_node = tmp_node
        tmp_node = om.get_from_node(mask_node.inputs[0])
        if tmp_node.op_type == "Div":
            tmp_node = om.get_from_node(tmp_node.inputs[0])
        assert tmp_node.op_type == "MatMul"
        mask = mask_node.inputs[1]
        assert new_mask is not None
    tmp_node1 = om.get_from_node(tmp_node.inputs[0])
    if tmp_node1.op_type == "Mul":
        tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
    tmp_node2 = om.get_from_node(tmp_node.inputs[1])
    assert tmp_node1.op_type == tmp_node2.op_type == "Transpose"
    tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
    tmp_node2 = om.get_prev_nodes(tmp_node2)[0]
    assert tmp_node1.op_type == tmp_node2.op_type == "Reshape"
    q, k = tmp_node1.inputs[0], tmp_node2.inputs[0]
    tmp_node = om.get_next_nodes(softmax_node)[0]
    assert tmp_node.op_type == "MatMul"
    tmp_node3 = om.get_from_node(tmp_node.inputs[1])
    if tmp_node3 is not None:
        assert tmp_node3.op_type == "Transpose"
        tmp_node3 = om.get_prev_nodes(tmp_node3)[0]
        assert tmp_node3.op_type == "Reshape"
        v = tmp_node3.inputs[0]
    else:
        v_init = om.get_initializer(tmp_node.inputs[1])
        v_init_value = om.get_initializer_value(tmp_node.inputs[1])
        v_init_value = v_init_value[None, ...].transpose(0, 2, 1, 3)
        B, S, H, D = v_init_value.shape
        v_init_value = np.ascontiguousarray(v_init_value.reshape(B, S, H * D))
        om.set_initializer_value(tmp_node.inputs[1], v_init_value)
        v = v_init.name
    tmp_node = om.get_next_nodes(tmp_node)[0]
    assert tmp_node.op_type == "Transpose"
    tmp_node = om.get_next_nodes(tmp_node)[0]
    assert tmp_node.op_type == "Reshape"
    mha_next_node = om.get_next_nodes(tmp_node)[0]
    if mha_next_node.op_type == "Gemm":
        gemm_next_node = om.get_next_nodes(mha_next_node)[0]
        assert gemm_next_node.op_type == "Reshape"
        reshape_next_node = om.get_next_nodes(gemm_next_node)[0]
        assert reshape_next_node.op_type == "Add"
    else:
        assert mha_next_node.op_type == "MatMul"
    name_prefix = '/'.join(softmax_name.split('/')[:-1])
    mha_name = f"{name_prefix}/MultiHeadAttention"
    mha_node = om.create_node(
        "MultiHeadAttention", mha_name,
        [q, k, v] if mask is None else [q, k, v, new_mask],
        [mha_name + '_output_0'],
        num_heads=num_heads, domain="com.microsoft",
        index=mha_next_node.index - 1)
    mha_next_node.replace_input(mha_next_node.inputs[0], mha_node.outputs[0])
    if mha_next_node.op_type == "Gemm":
        weights = om.get_initializer_value(mha_next_node.inputs[1])
        transB = mha_next_node.attrs["transB"]
        assert transB == 1
        weights = np.ascontiguousarray(weights.transpose(1, 0))
        om.set_initializer_value(mha_next_node.inputs[1], weights)
        new_matmul_name = mha_next_node.name.replace("Gemm", "MatMul(Gemm)")
        new_matmul_node = om.create_node(
            "MatMul", new_matmul_name,
            [mha_node.outputs[0], mha_next_node.inputs[1]],
            [new_matmul_name + "_output_0"],
            index=mha_next_node.index)
        new_bias_name = mha_next_node.name.replace("Gemm", "Add(Gemm)")
        new_add_node = om.create_node(
            "Add", new_bias_name,
            [new_matmul_node.outputs[0], mha_next_node.inputs[2]],
            [new_bias_name + "_output_0"],
            index=new_matmul_node.index + 1)
        reshape_next_node.replace_input(gemm_next_node.outputs[0], new_add_node.outputs[0])
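
# Minimal usage sketch (illustrative names; the contrib-op domain must be registered):
#   om.add_opset_import("com.microsoft", 1)
#   fuse_one_attention(om, "/transformer/decoder/layers.0/self_attn/Softmax",
#                      new_mask="ReduceSum_for_mask_output_0", num_heads=8)
# The pattern matcher walks the Reshape/Transpose nodes around the Softmax to
# recover the pre-split Q, K, V tensors, then replaces the whole subgraph with
# a single com.microsoft::MultiHeadAttention node.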
def optimize_normal_attention(om: ONNXModifier):
    def create_new_attention_mask():
        mask_next_node = om.get_to_nodes("attention_mask")[0]
        cast_node = om.create_node("Cast", "Cast_for_attention_mask",
                                   ["attention_mask"], ["Cast_for_attention_mask_output_0"],
                                   to=1, index=mask_next_node.index)
        reducesum_node = om.create_node("ReduceSum", "ReduceSum_for_mask",
                                        [cast_node.outputs[0]], ["ReduceSum_for_mask_output_0"],
                                        axes=1, keepdims=0, index=cast_node.index + 1)
        return reducesum_node.outputs[0]

    # bert
    # for i in range(12):
    #     fuse_one_attention(om, f"/bert/encoder/layer.{i}/attention/self/Softmax", "text_token_mask", num_heads=12)
    new_mask = create_new_attention_mask()
    for i in range(6):
        # /transformer/encoder
        # fuse_one_attention(om, f"/transformer/encoder/text_layers.{i}/self_attn/Softmax", "text_token_mask", num_heads=4)
        # /transformer/decoder
        fuse_one_attention(om, f"/transformer/decoder/layers.{i}/self_attn/Softmax", new_mask, num_heads=8)
        fuse_one_attention(om, f"/transformer/decoder/layers.{i}/ca_text/Softmax", new_mask, num_heads=8)
    om.update_map()
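
# The mask handed to MultiHeadAttention here is ReduceSum(Cast(attention_mask)),
# i.e. one valid-token count per batch row. This corresponds to the 1-D
# key_padding_mask form that ONNX Runtime's contrib MultiHeadAttention accepts
# for right-padded sequences (verify against the contrib-op spec of your ORT
# version before relying on it).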
def optimize_backbone_attention(om: ONNXModifier):
    def get_original_mask(mask_name, name_prefix):
        mask_value = om.get_initializer_value(mask_name)
        orig_mask = np.where(mask_value == 0, 1, 0).astype(np.bool_)
        orig_mask_init = om.create_initializer(f"{name_prefix}/mask", orig_mask)
        return orig_mask_init.name

    def _fuse_one_attention(softmax_name: str):
        name_prefix = '/'.join(softmax_name.split('/')[:-1])
        softmax_node = om.get_node(softmax_name)
        tmp_node = om.get_prev_nodes(softmax_node)[0]
        pos_bias_init = None
        if tmp_node.op_type == "Reshape":
            tmp_node = om.get_prev_nodes(tmp_node)[0]
            assert tmp_node.op_type == "Add"
            pos_bias_init = om.get_initializer(tmp_node.inputs[1])
            tmp_node = om.get_prev_nodes(tmp_node)[0]
            assert tmp_node.op_type == "Reshape"
            tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Add"
        mask = get_original_mask(tmp_node.inputs[1], name_prefix)
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "MatMul"
        qk_matmul = tmp_node
        tmp_node = om.get_from_node(qk_matmul.inputs[0])
        assert tmp_node.op_type == "Mul"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Gather"
        q_gather_node = tmp_node
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Reshape"
        reshape_node = tmp_node
        tmp_node = om.get_from_node(qk_matmul.inputs[1])
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_prev_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Gather"
        k_gather_node = tmp_node
        tmp_node = om.get_next_nodes(softmax_node)[0]
        assert tmp_node.op_type == "MatMul"
        v_gather_node = om.get_from_node(tmp_node.inputs[1])
        assert v_gather_node.op_type == "Gather"
        tmp_node = om.get_next_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Transpose"
        tmp_node = om.get_next_nodes(tmp_node)[0]
        assert tmp_node.op_type == "Reshape"
        mha_out = tmp_node.outputs[0]
        old_dst_shape = om.get_initializer_value(reshape_node.inputs[1])
        b, s, _, h, d = old_dst_shape
        new_dst_shape = [b, s, _, h * d]
        new_dst_shape_init = om.create_initializer(
            f"{name_prefix}/qkv_hidden_states_shape", np.array(new_dst_shape, np.int64))
        reshape_node.set_input(1, new_dst_shape_init.name)
        for node in [q_gather_node, k_gather_node, v_gather_node]:
            node.set_input(0, reshape_node.outputs[0])
            node.set_attribute("axis", 2)
        mha_name = f"{name_prefix}/MultiHeadAttention"
        inputs = [q_gather_node.outputs[0], k_gather_node.outputs[0], v_gather_node.outputs[0], mask]
        if pos_bias_init is not None:
            inputs.append(pos_bias_init.name)
        mha_node = om.create_node("MultiHeadAttention", mha_name, inputs,
                                  [mha_name + '_output_0'], num_heads=h,
                                  domain="com.microsoft", index=softmax_node.index)
        mha_next_node = om.get_to_nodes(mha_out)[0]
        mha_next_node.replace_input(mha_out, mha_node.outputs[0])

    num_layers = 4
    for l in range(num_layers):
        num_blocks = 18 if l == 2 else 2
        for b in range(num_blocks):
            _fuse_one_attention(f"/backbone/backbone.0/layers.{l}/blocks.{b}/attn/softmax/Softmax")
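
# The (num_layers, num_blocks) loop mirrors a Swin-style backbone with stage
# depths [2, 2, 18, 2]; only stage 2 has 18 blocks. If your backbone variant
# differs, adjust num_layers / num_blocks to match its node naming.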
def optimize_bidirect_attention(om: ONNXModifier):
    for i in range(6):
        reduce_max_name = f"/transformer/encoder/fusion_layers.{i}/attn/ReduceMax_1"
        reduce_max_node = om.get_node(reduce_max_name)
        # Safety check
        if reduce_max_node is None:
            continue
        next_nodes = om.get_next_nodes(reduce_max_node)
        if not next_nodes:
            continue
        next_node = next_nodes[0]
        if next_node.op_type != "Sub":
            continue
        name_prefix = '/'.join(reduce_max_name.split('/')[:-1])
        matmul_name = f"{name_prefix}/identity_MatMul"
        matmul_init = om.create_initializer(matmul_name + "_B",
                                            np.diag(np.array([1] * 1)).astype(np.float32))
        matmul_node = om.create_node("MatMul", matmul_name,
                                     [reduce_max_node.outputs[0], matmul_init.name],
                                     [f"{matmul_name}_output_0"],
                                     index=reduce_max_node.index + 1)
        next_node.set_input(1, matmul_node.outputs[0])
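
# The 1x1 identity MatMul leaves values unchanged (Y = X @ I); it sits between
# ReduceMax_1 and the Sub purely as a graph barrier, presumably to keep the
# compiler from fusing or reordering that pair (an assumption; the numeric
# result is identical either way).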
# def main():
#     input_onnx_path = sys.argv[1]
#     output_onnx_path = sys.argv[2]
#     # input_onnx_path = "ground_sim.onnx"
#     # output_onnx_path = "ground_sim_0424_2nd.onnx"
#     om = ONNXModifier(input_onnx_path)
#     optimize_where_ndoes(om)           # 1. Replace Where nodes
#     optimize_transpose_nodes(om)       # 2. Optimize Transpose nodes
#     optimize_sin_cos_block(om)         # 3. Optimize positional encoding
#     # om.add_opset_import("com.microsoft", 1)
#     # optimize_normal_attention(om)    # 4. Fuse MHA in bert / transformer
#     # optimize_ms_deform_attn(om)      # 5. Fuse multi-scale deformable attention
#     # optimize_backbone_attention(om)  # 6. Fuse attention in the backbone
#     optimize_bidirect_attention(om)    # 7. Optimize bi-directional attention
#     om.save(output_onnx_path, save_as_external_data=False)


def main():
    # Path to the original (simplified) model
    input_onnx_path = "../weights/ground_deform_sim.onnx"
    # Output path for the optimized model
    output_onnx_path = "../weights_opt/ground_deform_opt.onnx"

    print(f"Loading ONNX model from {input_onnx_path}...")
    om = ONNXModifier(input_onnx_path)

    print("1. Optimizing Where nodes (Crucial for FP16 & MIGraphX)...")
    optimize_where_ndoes(om)

    print("2. Optimizing Transpose nodes...")
    optimize_transpose_nodes(om)

    # print("3. Optimizing Sin/Cos positional encoding...")
    # optimize_sin_cos_block(om)

    # print("4. Optimizing Bidirectional attention...")
    # optimize_bidirect_attention(om)

    print(f"Saving optimized model to {output_onnx_path}...")
    om.save(output_onnx_path, save_as_external_data=False)
    print("Optimization Done!")


if __name__ == "__main__":
    main()
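
# A typical invocation, assuming the repo layout from the README (sketch):
#   cd migraphx_infer && python modify_onnx.py   # or the modify_onnx1.py variant
# reads ../weights/ground_deform_sim.onnx and writes
# ../weights_opt/ground_deform_opt.onnx for the later MIGraphX export step.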
migraphx_infer/onnx_modifier.py
0 → 100644
View file @
a1865640
"""
onnx modifier: provide a conviennt way to modify onnx model
1. add node
2. remove node
3. modify node
4. query node
"""
from
collections
import
defaultdict
,
deque
import
os
import
os.path
as
osp
import
shutil
import
tempfile
from
typing
import
List
,
Dict
,
Set
,
Tuple
,
Optional
,
Union
import
uuid
import
warnings
import
numpy
as
np
import
onnx
from
onnx
import
AttributeProto
,
numpy_helper
from
onnx
import
shape_inference
from
onnx.helper
import
make_attribute
,
make_node
,
make_opsetid
,
make_tensor
,
\
tensor_dtype_to_np_dtype
from
onnxconverter_common
import
float16
import
tqdm
# from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
SUPPORT_DTYPES
=
[
'BOOL'
,
'STRING'
,
'BFLOAT16'
,
'DOUBLE'
,
'FLOAT'
,
'FLOAT16'
,
'INT16'
,
'INT32'
,
'INT4'
,
'INT64'
,
'INT8'
,
'UINT16'
,
'UINT32'
,
'UINT4'
,
'UINT64'
,
'UINT8'
,
]
SUPPORT_DTYPES
.
extend
([
dt
.
lower
()
for
dt
in
SUPPORT_DTYPES
])
class Node:
    def __init__(self, onnx_modifier=None, obj=None, index=None):
        self.onnx_modifier = onnx_modifier
        self.obj = obj
        self.index = index

    @property
    def name(self):
        return self.obj.name

    @property
    def op_type(self):
        return self.obj.op_type

    @property
    def inputs(self):
        return self.obj.input

    @property
    def outputs(self):
        return self.obj.output

    @property
    def input_names(self):
        return self.inputs

    @property
    def output_names(self):
        return self.outputs

    def check_modifier(self):
        if self.onnx_modifier is None:
            raise RuntimeError("onnx_modifier is not initialized")

    @property
    def prev_nodes(self):
        self.check_modifier()
        return self.onnx_modifier.get_prev_nodes(self)

    @property
    def next_nodes(self):
        self.check_modifier()
        return self.onnx_modifier.get_next_nodes(self)

    def replace_input(self, old_name, new_name):
        assert old_name in self.obj.input, \
            f'"{old_name}" not in input name list of node named "{self.name}"'
        for i, in_name in enumerate(self.obj.input):
            if in_name == old_name:
                self.set_input(i, new_name)

    def set_input(self, index, name):
        # assert index < len(self.obj.input), "index out of range"
        # orig_name = self.obj.input[index]
        # self.obj.input[index] = name
        assert index < len(self.onnx_modifier.graph.node[self.index].input), "index out of range"
        orig_name = self.onnx_modifier.graph.node[self.index].input[index]
        self.onnx_modifier.graph.node[self.index].input[index] = name
        self.check_modifier()
        # Can not execute connection.pop_to_node() method directly.
        # When node inputs contain multiple orig_name, need to remain the node in to_nodes.
        if list(self.onnx_modifier.graph.node[self.index].input).count(orig_name) == 0:
            self.onnx_modifier.connection_map[orig_name].pop_to_node(self)
        if name not in self.onnx_modifier.connection_map:
            self.onnx_modifier.connection_map[name] = Connection(name, self.onnx_modifier)
        self.onnx_modifier.connection_map[name].add_to_node(self)

    def set_inputs(self, names):
        assert len(names) == len(self.obj.input), "number of inputs does not match"
        assert all(isinstance(name, str) for name in names), "input names must be strings"
        self.obj.input[:] = names

    def set_output(self, index, name):
        assert index < len(self.obj.output), "index out of range"
        orig_name = self.obj.output[index]
        self.obj.output[index] = name
        self.check_modifier()
        self.onnx_modifier.connection_map[orig_name].clear_from_node()
        if name not in self.onnx_modifier.connection_map:
            self.onnx_modifier.connection_map[name] = Connection(name, self.onnx_modifier)
        self.onnx_modifier.connection_map[name].set_from_node(self)

    def set_outputs(self, names):
        assert len(names) == len(self.obj.output), "number of outputs does not match"
        assert all(isinstance(name, str) for name in names), "output names must be strings"
        self.obj.output[:] = names

    @property
    def attrs(self):
        attrs = {}
        for attr in self.obj.attribute:
            if attr.type == AttributeProto.FLOAT:  # 1
                value = attr.f
            elif attr.type == AttributeProto.INT:  # 2
                value = attr.i
            elif attr.type == AttributeProto.STRING:  # 3
                value = attr.s.decode('utf-8')
            elif attr.type == AttributeProto.TENSOR:  # 4
                value = numpy_helper.to_array(attr.t)
            elif attr.type == AttributeProto.FLOATS:  # 6
                value = list(attr.floats)
            elif attr.type == AttributeProto.INTS:  # 7
                value = list(attr.ints)
            else:
                value = f"Unsupported type: {attr.type}"
            attrs[attr.name] = value
        return attrs

    def set_attribute(self, name, value, name2attr=None):
        if not name2attr:
            name2attr = {}
            for attr in self.obj.attribute:
                name2attr[attr.name] = attr
        if name in name2attr:
            if isinstance(value, float):
                name2attr[name].f = value
                name2attr[name].type = AttributeProto.FLOAT
            elif isinstance(value, int):
                name2attr[name].i = value
                name2attr[name].type = AttributeProto.INT
            elif isinstance(value, str):
                name2attr[name].s = value.encode('utf-8')
                name2attr[name].type = AttributeProto.STRING
            elif isinstance(value, np.ndarray):
                name2attr[name].ClearField("t")
                name2attr[name].t.CopyFrom(numpy_helper.from_array(value))
            elif isinstance(value, list):
                is_all_float = all(isinstance(x, float) for x in value)
                is_all_int = all(isinstance(x, int) for x in value)
                assert is_all_float or is_all_int
                if is_all_float:
                    name2attr[name].ClearField("floats")
                    name2attr[name].floats.extend(value)
                    name2attr[name].type = AttributeProto.FLOATS
                else:
                    name2attr[name].ClearField("ints")
                    name2attr[name].ints.extend(value)
                    name2attr[name].type = AttributeProto.INTS
        else:
            if isinstance(value, np.ndarray):
                value = numpy_helper.from_array(value)
            self.obj.attribute.append(make_attribute(name, value))

    def set_attributes(self, attr_dict):
        name2attr = {}
        for attr in self.obj.attribute:
            name2attr[attr.name] = attr
        for name, value in attr_dict.items():
            self.set_attribute(name, value, name2attr)
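
# Example (illustrative): tweak an existing attribute in place; the setter
# rewrites the matching AttributeProto, or appends a new one if absent:
#   node = om.get_node("/transformer/enc_out_class_embed/Transpose")
#   node.set_attribute("perm", [0, 2, 1])   # existing INTS attribute rewritten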
class Connection:
    def __init__(self, conn_name, onnx_modifier=None):
        self.name = conn_name
        self.onnx_modifier = onnx_modifier
        self.from_node = None
        self.to_nodes = []
        self.to_node_names = set()

    def check_modifier(self):
        if self.onnx_modifier is None:
            raise RuntimeError("onnx_modifier is not initialized")

    def set_from_node(self, node: str | Node):
        if isinstance(node, str):
            self.check_modifier()
            _node = self.onnx_modifier.get_node(node)
            assert _node is not None, f'No node named "{node}" in onnx graph!'
        elif isinstance(node, Node):
            _node = node
        else:
            raise TypeError(f"Connection.set_from_node expects input argument type"
                            f" str or Node, but received: {type(node)}")
        self.from_node = _node

    def clear_from_node(self):
        self.from_node = None

    def add_to_node(self, node: str | Node):
        if isinstance(node, str):
            _name = node
            self.check_modifier()
            _node = self.onnx_modifier.get_node(node)
            assert _node is not None, f'No node named "{node}" in onnx graph!'
        elif isinstance(node, Node):
            _name = node.name
            _node = node
        else:
            raise TypeError(f"Connection.add_to_node expects input argument type"
                            f" str or Node, but received: {type(node)}")
        if _name not in self.to_node_names:
            self.to_node_names.add(_name)
            self.to_nodes.append(_node)

    def pop_to_node(self, node: str | Node):
        if isinstance(node, str):
            _name = node
            self.check_modifier()
            _node = self.onnx_modifier.get_node(node)
            assert _node is not None, f'No node named "{node}" in onnx graph!'
        elif isinstance(node, Node):
            _name = node.name
            _node = node
        else:
            raise TypeError(f"Connection.pop_to_node expects input argument type"
                            f" str or Node, but received: {type(node)}")
        if _name not in self.to_node_names:
            raise ValueError(f'Node "{_name}" not in target nodes of connection "{self.name}"!')
        self.to_node_names.remove(_name)
        for i in range(len(self.to_nodes)):
            if self.to_nodes[i].name == _name:
                return self.to_nodes.pop(i)
        else:
            raise RuntimeError("to_nodes mismatch to_node_names!")
class ONNXModifier:
    def __init__(self, onnx_path):
        self.onnx_path = onnx_path
        self.node_map = {}
        self.initializer_map = {}
        self.sparse_initializer_map = {}
        self.connection_map = {}
        self.value_info_map = {}
        self.parse_onnx(self.onnx_path)

    def parse_onnx(self, onnx_path):
        model = onnx.load(onnx_path)
        self.model = model
        self.domain = model.domain
        self.graph = model.graph
        self.ir_version = model.ir_version
        self.model_version = model.model_version
        self.opset_import = model.opset_import
        self.update_map()
    def add_node_name_if_nameless(self, node: Node):
        if not hasattr(self, "name_suffixes"):
            self.name_suffixes = set()
        if node.name == "" or node.name is None:
            suffix = None
            while True:
                suffix = uuid.uuid4().hex[:8]
                if suffix not in self.name_suffixes:
                    break
            self.name_suffixes.add(suffix)
            node.obj.name = node.op_type + "_" + suffix

    def add_opset_import(self, domain: str, version: int):
        self.model.opset_import.append(make_opsetid(domain, version))

    @property
    def input_names(self):
        return [i.name for i in self.graph.input]

    @property
    def output_names(self):
        return [o.name for o in self.graph.output]
    def add_input(self, name, dtype='float', shape=None):
        assert dtype in set(SUPPORT_DTYPES)
        self.create_value_info(name, dtype=dtype, shape=shape)
        new_input = self.value_info_map.pop(name)
        _new_input = self.graph.value_info.pop()
        assert id(new_input) == id(_new_input)
        assert name == new_input.name
        self.graph.input.append(new_input)
        return new_input

    def add_output(self, name, new_name=None, shape=None):
        if name not in self.value_info_map:
            raise ValueError(f"{name} not in onnx_modifier.value_info_map")
        index = None
        for i, v in enumerate(self.graph.value_info):
            if v.name == name:
                index = i
                break
        else:
            raise ValueError(f"{name} not in model.graph.value_info")
        value_info = self.value_info_map.pop(name)
        assert value_info.name == name
        assert id(value_info) == id(self.graph.value_info[index])
        self.graph.value_info.pop(index)
        if shape is not None:
            tensor_type = onnx.helper.make_tensor_type_proto(
                elem_type=value_info.type.tensor_type.elem_type, shape=shape)
            value_info.type.CopyFrom(tensor_type)
        if new_name is None:
            self.graph.output.append(value_info)
        else:
            from_node = self.get_from_node(name)
            to_nodes = self.get_to_nodes(name)
            for i, output_name in enumerate(from_node.output_names):
                if output_name == name:
                    from_node.set_output(i, new_name)
            for node in to_nodes:
                node.replace_input(name, new_name)
            value_info.name = new_name
            self.graph.output.append(value_info)

    def remove_output(self, name):
        """Remove a model output by name."""
        assert name in self.output_names
        # print("need remove output name:", name)
        index = None
        for i, out in enumerate(self.graph.output):
            # print(f"current(index={i}) output name:", out.name)
            if out.name == name:
                index = i
                break
        else:
            raise RuntimeError(f"ONNX graph does not have an output named '{name}'.")
        self.graph.output.pop(index)
    def get_node(self, name_or_index: Union[str, int]):
        """Get a node instance by node name or index."""
        if isinstance(name_or_index, str):
            if name_or_index in self.node_map:
                return self.node_map.get(name_or_index, None)
        elif isinstance(name_or_index, int):
            if name_or_index < len(self.graph.node):
                return self.node_map.get(self.graph.node[name_or_index].name, None)
            else:
                raise ValueError(f"Node index {name_or_index} out of range")

    def get_nodes(self, *op_types: str):
        """Get node instances by op type."""
        assert len(op_types) >= 1
        op_types_set = set(op_types)
        node_names = [node.name for node in self.graph.node if node.op_type in op_types_set]
        nodes = [self.node_map[name] for name in node_names]
        return nodes

    def get_initializer(self, name: str):
        """Get an initializer by name."""
        return self.initializer_map.get(name)

    def get_connection(self, name: str):
        """Get a connection (edge) by name."""
        return self.connection_map.get(name)

    def get_from_node(self, conn: Union[str, Connection]):
        """Get the node that produces a given connection."""
        if isinstance(conn, str):
            assert conn in self.connection_map, f"Connection {conn} not in connection_map!"
            return self.connection_map[conn].from_node
        elif isinstance(conn, Connection):
            return conn.from_node
        else:
            raise TypeError(f"Invalid connection type {type(conn)}")

    def get_to_nodes(self, conn: Union[str, Connection]):
        """Get the nodes that consume a given connection."""
        if isinstance(conn, str):
            assert conn in self.connection_map, f"Connection {conn} not in connection_map!"
            return self.connection_map[conn].to_nodes
        elif isinstance(conn, Connection):
            return conn.to_nodes
        else:
            raise TypeError(f"Invalid connection type {type(conn)}")

    def get_prev_nodes(self, node: Union[str, Node]):
        """Get the upstream (producer) nodes of a node."""
        if isinstance(node, str):
            node = self.node_map[node]
        elif isinstance(node, Node):
            pass
        else:
            raise TypeError(f"Invalid node type {type(node)}")
        nodes = []
        for conn_name in node.inputs:
            from_node = self.get_from_node(conn_name)
            if from_node:
                nodes.append(from_node)
        return nodes

    def get_next_nodes(self, node: Union[str, Node]):
        """Get the downstream (consumer) nodes of a node."""
        if isinstance(node, str):
            node = self.node_map[node]
        elif isinstance(node, Node):
            pass
        else:
            raise TypeError(f"Invalid node type {type(node)}")
        nodes = []
        for conn_name in node.outputs:
            to_nodes = self.get_to_nodes(conn_name)
            nodes.extend(to_nodes)
        return nodes
    def create_node(self, op_type, op_name, inputs, outputs,
                    doc_string=None, domain=None, index=None, **attrs):
        """Create a new node."""
        onnx_node = make_node(op_type, inputs, outputs, name=op_name,
                              doc_string=doc_string, domain=domain, **attrs)
        if index is None:
            self.graph.node.append(onnx_node)
            index = len(self.graph.node) - 1
        else:
            assert index <= len(self.graph.node), "index out of range"
            self.graph.node.insert(index, onnx_node)
            for i in range(index + 1, len(self.graph.node)):
                node_name = self.graph.node[i].name
                old_idx = self.node_map[node_name].index
                assert old_idx == i - 1, \
                    f"Node {node_name} index conflict: {old_idx} != {i - 1}"
                self.node_map[node_name].index = i
        new_node = Node(self, self.graph.node[index], index)
        self.node_map[op_name] = new_node
        for in_name in new_node.input_names:
            if in_name not in self.value_info_map:
                self.create_value_info(in_name, dtype="float")
            if in_name not in self.connection_map:
                self.connection_map[in_name] = Connection(in_name, self)
            self.connection_map[in_name].add_to_node(new_node)
        for out_name in new_node.output_names:
            if out_name not in self.value_info_map:
                self.create_value_info(out_name, dtype="float")
            if out_name not in self.connection_map:
                self.connection_map[out_name] = Connection(out_name, self)
            self.connection_map[out_name].set_from_node(new_node)
        return new_node
    def create_initializer(self, name, value: np.ndarray):
        """Create an initializer."""
        assert name not in self.initializer_map, f"initializer {name} already exists!"
        init_node = numpy_helper.from_array(value, name=name)
        use_external_data = value.nbytes / 1024 / 1024 / 1024 > 2
        if use_external_data:
            print("use external data:", name)
            init_node.data_location = onnx.TensorProto.EXTERNAL
            location = name.replace('/', '+') + '.data'
            onnx.external_data_helper.set_external_data(init_node, location)
            with tempfile.TemporaryDirectory() as tmp_dir:
                onnx.external_data_helper.save_external_data(init_node, tmp_dir)
                init_node.ClearField("raw_data")
                self.graph.initializer.append(init_node)
                onnx.external_data_helper.load_external_data_for_tensor(
                    self.graph.initializer[-1], tmp_dir)
            del self.graph.initializer[-1].external_data[:]
            self.graph.initializer[-1].ClearField("data_location")
        else:
            self.graph.initializer.append(init_node)
        self.initializer_map[name] = self.graph.initializer[-1]
        return self.graph.initializer[-1]

    def create_value_info(self, name, dtype=None, shape=None):
        if dtype is None:
            elem_type = None
        else:
            assert isinstance(dtype, str)
            assert dtype in set(SUPPORT_DTYPES)
            elem_type = getattr(onnx.TensorProto, dtype.upper())
        value_info = onnx.helper.make_tensor_value_info(name=name, elem_type=elem_type, shape=shape)
        self.graph.value_info.append(value_info)
        self.value_info_map[name] = self.graph.value_info[-1]
        return self.graph.value_info[-1]
    def get_initializer_value(self, name):
        """Get the value of an initializer as a numpy array."""
        init = self.get_initializer(name)
        return numpy_helper.to_array(init)

    def set_initializer_value(self, name, value: np.ndarray):
        """Set a new value for an initializer."""
        init = self.get_initializer(name)
        # Check shape and dtype
        old_shape = list(init.dims)
        new_shape = list(value.shape)
        # old_dtype = TENSOR_TYPE_TO_NP_TYPE.get(init.data_type, None)
        old_dtype = tensor_dtype_to_np_dtype(init.data_type)
        new_dtype = value.dtype
        if old_shape != new_shape:
            warn_message = f"Initializer {name} shape changed: {old_shape} -> {new_shape}"
            warnings.warn(warn_message, RuntimeWarning)
        if old_dtype is not None and old_dtype != new_dtype:
            warn_message = f"Initializer {name} dtype changed: {old_dtype} -> {new_dtype}"
            warnings.warn(warn_message, RuntimeWarning)
        new_tensor_proto = numpy_helper.from_array(value, name=name)
        init.CopyFrom(new_tensor_proto)
    def connect_node(self, node, inputs_map, outputs_map):
        """Connect a node to its upstream and downstream nodes.

        Args:
            node: Node
            inputs_map: [(node0, out_idx0), (node1, out_idx1), ...]
            outputs_map: [(node0, in_idx0), (node1, in_idx1), ...]
        """
        # When connecting A -> B, if A's output name conflicts with B's input name,
        # A's output name wins, i.e. B.input[i] = A.output[j]
        for i, (n, j) in enumerate(inputs_map):
            if isinstance(n, str):
                n = self.node_map[n]
            assert j < len(n.outputs), \
                f"output index {j} out of node {n.name} outputs range"
            node.set_input(i, n.outputs[j])
        for name, (n, i) in zip(node.outputs, outputs_map):
            if isinstance(n, str):
                n = self.node_map[n]
            assert i < len(n.inputs), \
                f"input index {i} out of node {n.name} inputs range"
            n.set_input(i, name)
        # TODO: update self.connection_map
    def pop_node(self, node: Union[str, Node, int], auto_connect=True):
        """Remove a node by name, instance, or index."""
        if isinstance(node, str):
            node = self.node_map.get(node, None)
            if node is None:
                return None
            index = node.index
            assert node.name == self.graph.node[index].name
        elif isinstance(node, int):
            if node >= len(self.graph.node):
                raise ValueError(f"Node index {node} out of range")
            index = node
        elif isinstance(node, Node):
            index = node.index
        else:
            raise ValueError(f"Invalid node name or index: {node}")
        for i in range(index + 1, len(self.graph.node)):
            later_node = self.graph.node[i]
            self.node_map[later_node.name].index -= 1
        # print(f"node_name={self.graph.node[index].name} node_index={index}")
        _node_obj = self.graph.node[index]
        _node = self.get_node(_node_obj.name)
        next_nodes = self.get_next_nodes(_node)
        self.graph.node.pop(index)
        self.node_map.pop(_node_obj.name)
        # automatic connecting edges
        if auto_connect and len(_node.inputs) == 1 and len(_node.outputs) == 1:
            # self.connection_map[_node.inputs[0]].pop_to_node(_node)
            for next_node in next_nodes:
                next_node.replace_input(_node.outputs[0], _node.inputs[0])
                # self.connection_map[_node.inputs[0]].add_to_node(next_node)
            # self.connection_map.pop(_node.outputs[0])
        # update connection_map
        for in_name in _node.input_names:
            if _node.name in self.connection_map[in_name].to_node_names:
                self.connection_map[in_name].pop_to_node(_node)
        for i, out_name in enumerate(_node.output_names):
            self.connection_map[out_name].clear_from_node()
        return _node
    def remove_nodes(self, nodes: List[str | Node], auto_connect=False):
        """Remove multiple nodes at once."""
        indices = set()
        _nodes = []
        invalid_nodes = set()
        for node in nodes:
            if isinstance(node, str):
                if node in self.node_map:
                    node = self.node_map[node]
                    if node.index not in indices:
                        _nodes.append(node)
                        indices.add(node.index)
                else:
                    invalid_nodes.add(node)
            elif isinstance(node, Node):
                if node.index not in indices:
                    _nodes.append(node)
                    indices.add(node.index)
                else:
                    invalid_nodes.add(node)
        _nodes.sort(key=lambda x: x.index, reverse=True)
        use_progress_bar = len(_nodes) > 500
        if use_progress_bar:
            pbar = tqdm.tqdm(total=len(_nodes), desc="Removing nodes")
        for node in _nodes:
            self.pop_node(node, auto_connect=auto_connect)
            if use_progress_bar:
                pbar.update(1)
        if use_progress_bar:
            pbar.close()
        # print(f"{len(nodes) - len(invalid_nodes)} nodes have been removed.")
        # if len(invalid_nodes) > 0:
        #     print(f"find {len(invalid_nodes)} invalid nodes:\n", invalid_nodes)
    def pop_initializer(self, init_name: str, update_node_inputs: bool = True):
        """Remove an initializer by name."""
        _init1 = self.initializer_map.pop(init_name)
        init_index = None
        for i in range(len(self.graph.initializer)):
            if self.graph.initializer[i].name == init_name:
                init_index = i
                break
        else:
            raise ValueError(f"No initializer named {init_name} exists")
        _init2 = self.graph.initializer.pop(init_index)
        assert id(_init1) == id(_init2)
        # if update_node_inputs and init_name in self.connection_map:
        #     to_nodes = self.get_to_nodes(init_name)
        #     self.connection_map.pop(init_name)
        #     for node in to_nodes:
        #         num_inputs = len(node.inputs)
        #         for i in range(num_inputs - 1, -1, -1):
        #             if node.inputs[i] == init_name:
        #                 node.inputs.pop(i)
        return _init1
    def update_map(self):
        """Rebuild node_map, connection_map, the initializer maps and value_info_map."""
        self.node_map.clear()
        self.connection_map.clear()
        self.initializer_map.clear()
        self.sparse_initializer_map.clear()
        self.value_info_map.clear()
        for i, node in enumerate(self.graph.node):
            new_node = Node(self, node, i)
            self.add_node_name_if_nameless(new_node)
            self.node_map[node.name] = new_node
            for conn_name in node.input:
                if conn_name not in self.connection_map:
                    self.connection_map[conn_name] = Connection(conn_name, self)
                self.connection_map[conn_name].add_to_node(new_node)
            for conn_name in node.output:
                if conn_name not in self.connection_map:
                    self.connection_map[conn_name] = Connection(conn_name, self)
                self.connection_map[conn_name].set_from_node(new_node)
        for i, node in enumerate(self.graph.initializer):
            self.initializer_map[node.name] = node
        for i, node in enumerate(self.graph.sparse_initializer):
            self.sparse_initializer_map[node.name] = [node, i]
        for i, conn in enumerate(self.graph.value_info):
            self.value_info_map[conn.name] = conn
    def find_unuseful_nodes(self):
        """Find nodes whose results never reach a model output."""
        end_names = set()
        for output_name in self.output_names:
            end_names.add(self.get_from_node(output_name).name)
        unuseful_names = set()
        for node in self.node_map.values():
            if node.name in end_names:
                continue
            next_nodes = self.get_next_nodes(node)
            if len(next_nodes) == 0:
                unuseful_names.add(node.name)
        model_output_names = set(self.output_names)
        q = deque([self.node_map[name] for name in unuseful_names])
        while len(q) != 0:
            node = q.popleft()
            prev_nodes = self.get_prev_nodes(node)
            for node1 in prev_nodes:
                next_nodes = self.get_next_nodes(node1)
                next_names = set([node2.name for node2 in next_nodes])
                # if (next_names - end_names).issubset(unuseful_names):
                if next_names.issubset(unuseful_names):
                    if node1.name not in unuseful_names and \
                            set(node1.output_names).isdisjoint(model_output_names):
                        q.append(node1)
                        unuseful_names.add(node1.name)
        unuseful_nodes = [self.node_map[name] for name in unuseful_names]
        return unuseful_nodes
    def remove_trash(self):
        """
        1. Remove unused nodes
        2. Remove unused initializers
        3. Remove connections without a producer node
        4. Remove unused model inputs and outputs
        5. Remove unused value_info entries
        """
        self.update_map()
        unuseful_nodes = self.find_unuseful_nodes()
        print(f"Found {len(unuseful_nodes)} unused nodes!")
        for i, node in enumerate(unuseful_nodes):
            print(f"remove unused node {i}:", node.name)
        self.remove_nodes(unuseful_nodes)
        self.update_map()
        all_node_inputs = set()
        for node in self.node_map.values():
            all_node_inputs.update(node.input_names)
        # remove unused initializers
        cnt = 0
        for init_name in list(self.initializer_map.keys()):
            if init_name in all_node_inputs:
                continue
            index = None
            for i, init in enumerate(self.graph.initializer):
                if init.name == init_name:
                    index = i
                    break
            else:
                raise ValueError(f"{init_name} not in model.graph.initializer")
            print(f"remove unused initializer {cnt}:", init_name)
            self.graph.initializer.pop(index)
            cnt += 1
        # remove unused sparse_initializers
        cnt = 0
        for init_name in list(self.sparse_initializer_map.keys()):
            if init_name in all_node_inputs:
                continue
            index = None
            for i, init in enumerate(self.graph.sparse_initializer):
                if init.name == init_name:
                    index = i
                    break
            else:
                raise ValueError(f"{init_name} not in model.graph.sparse_initializer")
            print(f"remove unused sparse initializer {cnt}:", init_name)
            self.graph.sparse_initializer.pop(index)
            cnt += 1
        self.update_map()
        # remove unused inputs and outputs
        for in_name in self.input_names:
            # print(in_name, [n.name for n in self.get_to_nodes(in_name)])
            if len(self.get_to_nodes(in_name)) != 0:
                continue
            for i, _in in enumerate(self.graph.input):
                if in_name == _in.name:
                    self.graph.input.pop(i)
                    break
        for out_name in self.output_names:
            # print(out_name, self.get_from_node(out_name).name)
            if self.get_from_node(out_name) is not None:
                continue
            for i, _out in enumerate(self.graph.output):
                if out_name == _out.name:
                    self.graph.output.pop(i)
                    break
        self.update_map()
        # remove unused value_info
        cnt = 0
        num_value_info = len(self.graph.value_info)
        for i in range(num_value_info - 1, -1, -1):
            v = self.graph.value_info[i]
            if v.name not in self.connection_map:
                self.graph.value_info.pop(i)
                print(f"remove unused value_info {cnt}:", v.name)
                cnt += 1
        self.update_map()
    def infer_shape(self, strict_mode=False):
        for vi in self.graph.value_info:
            if vi.type.HasField("tensor_type"):
                vi.type.tensor_type.ClearField("shape")
        model = shape_inference.infer_shapes(self.model, strict_mode=strict_mode)
        self.model = model
        self.domain = model.domain
        self.graph = model.graph
        self.ir_version = model.ir_version
        self.model_version = model.model_version
        self.opset_import = model.opset_import
        self.update_map()

    def infer_node_shape(self, node):
        input_shapes = []
        input_dtypes = []
        for input_name in node.inputs:
            value_info = self.value_info_map[input_name]
            input_shapes.append(value_info.type.tensor_type.dims)
            input_dtypes.append(value_info.type.tensor_type.type)
        shape_inference.infer_node_outputs(node.obj, input_shapes, input_dtypes)

    def convert_float_to_float16(self):
        self.model = float16.convert_float_to_float16(self.model, keep_io_types=True)

    def save(self, save_path, save_as_external_data=False, all_tensors_to_one_file=True):
        self.remove_trash()
        external_data_name = osp.basename(save_path) + '.data'
        external_data_path = osp.join(osp.dirname(save_path), external_data_name)
        if save_as_external_data and osp.isfile(external_data_path):
            os.remove(external_data_path)
        onnx.save(self.model, save_path,
                  save_as_external_data=save_as_external_data,
                  all_tensors_to_one_file=all_tensors_to_one_file,
                  location=external_data_name,
                  size_threshold=1024,
                  convert_attribute=False)
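
# Minimal end-to-end sketch of the modifier API (paths and node name illustrative):
#   om = ONNXModifier("model.onnx")
#   node = om.get_node("/backbone/SomeRelu")              # query by name
#   if node is not None:
#       om.pop_node(node, auto_connect=True)              # drop node, re-link its edge
#   om.create_initializer("my_shape", np.array([1, 900, -1], np.int64))
#   om.save("model_opt.onnx", save_as_external_data=False)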
migraphx_infer/onnx_sim.py
0 → 100644
View file @
a1865640
import onnx
from onnxsim import simplify
from onnxconverter_common import float16

onnx_model_path = "./weights/ground.onnx"
sim_model_path = "./weights/ground_sim.onnx"

print("1️⃣ Running ONNX Simplify...")
model = onnx.load(onnx_model_path)
model_simp, check = simplify(model)
if check:
    onnx.save(model_simp, sim_model_path)
    print(f"✅ Simplify done! Saved to {sim_model_path}")
else:
    print("❌ Simplify check failed!")
    exit()
\ No newline at end of file