"csrc/gfx93/decode/sparse_fp8/splitkv_mla.cuh" did not exist on "6d68e3d119d129ed3f01cafe82f0a3dd8933b52d"
Commit 7bc21d37 authored by zk's avatar zk
Browse files

Update-MIGraphX-optimization-workflow

parent 3191f720
......@@ -202,18 +202,67 @@ python onnx_inference_deform_optim.py
## 7\. migraphx推理
1. 进入migraphx_infer文件夹
1. 环境准备(http://42.228.13.241:10068/wangwf/groundingdino)
* 下载 DTK-26.04-txpl-temp-0312-ubuntu20.04-x86_64(http://112.11.77.146:65182/jenkins/rocm/26.04/intel/ubuntu20.04/DTK-26.04-txpl-temp-0312-ubuntu20.04-x86_64.tar.gz) ,解压后替换掉原 /opt/dtk 目录。
```bash
tar -zxvf DTK-26.04-txpl-temp-0312-ubuntu20.04-x86_64.tar.gz -C /opt/
rm /opt/dtk # 删除原来的软链接
ln -s /opt/dtk-26.04-txpl-temp-0312 /opt/dtk # 创建新的软链接
```
* 替换hipdnn:
```bash
tar -zxvf package_resize.tar.gz
cd package_resize
cp -r install/lib/hipdnn_plugins /opt/dtk/lib/
cp -r install/lib/libhipdnn_backend.so /opt/dtk/lib/
rm -rf /opt/dtk/include/hipdnn
cp -r install/include/hipdnn /opt/dtk/include/hipdnn
rm -rf /opt/dtk/lib/cmake/hipdnn*
cp -r install/lib/cmake/* /opt/dtk/lib/cmake/
rm -f /opt/dtk/lib/hipdnn_plugins/engines/libmiopen_legacy_plugin.so
```
* 激活dtk
```bash
source /opt/dtk/env.sh
export LD_LIBRARY_PATH=/usr/local/lib/python3.10/dist-packages/torch/lib:${LD_LIBRARY_PATH}
```
* 安装migraphx
```bash
chmod +x ./migraphx-5.1.2+das.opt1.ab9210b.dtk2604-cp310-cp310-manylinux_2_35_x86_64.run
./migraphx-5.1.2+das.opt1.ab9210b.dtk2604-cp310-cp310-manylinux_2_35_x86_64.run
```
* 模型优化
```bash
cd weights
onnxsim ground.onnx ground_sim.onnx
cd ../migraphx_infer
python modify_onnx_0601.py ../weights/ground_sim.onnx ../weights/ground_opt.onnx
```
2. 运行转换onnx脚本
将简化后的onnx转换为要用migraphx推理的onnx(ground_sim.onnx->ground_opt.onnx)
2. 性能测试(编译加运行)
```bash
export MIGRAPHX_ENABLE_GRAPHAPI_REDUCTION=1
export MIGRAPHX_ENABLE_LAYERNORM_FUSION=1
migraphx-driver perf --onnx ground_opt.onnx --fp16 --output ground_opt.mxr
```
或者进入migraphx_infer文件夹,运行
```bash
cd migraphx_infer
bash migraphx_export.bash
```
3. 如果已经得到了mxr文件,直接测试
3. 如果已经得到了mxr文件,也可以直接测试
```bash
bash migraphx_perf.bash
```
......
......@@ -45,12 +45,11 @@ text_token_mask = torch.tensor([[[True, False, False, False],
[False, True, True, False],
[False, False, False, True]]]).to(device)
# img = torch.randn(1, 3, 800, 1200).to(device)
img = torch.randn(1, 3, 400, 600).to(device)
img = torch.randn(1, 3, 800, 1200).to(device)
# img = torch.randn(1, 3, 400, 600).to(device)
# 导出 ONNX
# onnx_output_path = "../weights/ground_deform.onnx"
onnx_output_path = "../weights_400x600/ground_deform.onnx"
onnx_output_path = "../weights/ground_deform.onnx"
torch.onnx.export(
model,
......
......@@ -17,7 +17,7 @@ from PIL import Image
"""
so_options = ort.SessionOptions()
custom_op_lib_path = "../ort_plugin_fp16_C/build/libms_deform_attn_ort.so"
custom_op_lib_path = "../ort_plugin/build/libms_deform_attn_ort.so"
so_options.register_custom_ops_library(custom_op_lib_path)
# 开启ort优化
so_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
......@@ -195,7 +195,7 @@ def benchmark_performance(
if __name__ == '__main__':
# 配置参数
model_path = '../weights/ground_deform_fp16_all.onnx'
model_path = '../weights/ground_deform_sim_fp16.onnx'
"""
../weights/ground_deform.onnx 普通版本
../weights/ground_deform_sim.onnx 简化版本
......
......@@ -3,25 +3,25 @@ from onnxsim import simplify
from onnxconverter_common import float16
onnx_model_path = "../weights/ground_deform.onnx"
sim_model_path = "../weights_opt/ground_deform_opt.onnx"
fp16_model_path = "../weights_opt/ground_deform_opt_fp16.onnx"
fp16_all_model_path = "../weights_opt/ground_deform_opt_fp16_all.onnx"
sim_model_path = "../weights/ground_deform_sim.onnx"
fp16_model_path = "../weights/ground_deform_sim_fp16.onnx"
# fp16_all_model_path = "../weights/ground_deform_opt_fp16_all.onnx"
custom_op_lib_path = "../ort_plugin_fp16/build/libms_deform_attn_ort.so"
custom_op_lib_path = "../ort_plugin/build/libms_deform_attn_ort.so"
# # ==========================================
# # 第一步:ONNX Simplify (附带自定义算子库)
# # ==========================================
# print("1️⃣ 正在进行 ONNX Simplify...")
# model = onnx.load(onnx_model_path)
# model_simp, check = simplify(model, custom_lib=custom_op_lib_path)
print("1️⃣ 正在进行 ONNX Simplify...")
model = onnx.load(onnx_model_path)
model_simp, check = simplify(model, custom_lib=custom_op_lib_path)
# if check:
# onnx.save(model_simp, sim_model_path)
# print(f"✅ Simplify 完成!已保存至 {sim_model_path}")
# else:
# print("❌ Simplify 验证失败!")
# exit()
if check:
onnx.save(model_simp, sim_model_path)
print(f"✅ Simplify 完成!已保存至 {sim_model_path}")
else:
print("❌ Simplify 验证失败!")
exit()
......@@ -47,12 +47,12 @@ print(f"✅ FP16 转换完成(避开自定义算子)!已保存至 {fp16_model_
print("\n2️⃣ 正在进行纯 FP16 精度转换...")
model_fp16_all = float16.convert_float_to_float16(
model_to_fp16,
node_block_list=original_cast_nodes, # 保护所有原生的 Cast 节点
keep_io_types=True # 保持整个模型的总输入/输出还是 FP32
)
onnx.save(model_fp16_all, fp16_all_model_path)
print(f"✅ 纯 FP16 转换完成!已保存至 {fp16_all_model_path}")
# print("\n2️⃣ 正在进行纯 FP16 精度转换...")
# model_fp16_all = float16.convert_float_to_float16(
# model_to_fp16,
# node_block_list=original_cast_nodes, # 保护所有原生的 Cast 节点
# keep_io_types=True # 保持整个模型的总输入/输出还是 FP32
# )
# onnx.save(model_fp16_all, fp16_all_model_path)
# print(f"✅ 纯 FP16 转换完成!已保存至 {fp16_all_model_path}")
......@@ -82,25 +82,6 @@ at::Tensor ms_deform_attn_cuda_forward(
return output;
}
// at::Tensor ms_deform_attn_forward_wrapper(
// const at::Tensor &value,
// const at::Tensor &spatial_shapes,
// const at::Tensor &level_start_index,
// const at::Tensor &sampling_loc,
// const at::Tensor &attn_weight,
// int64_t im2col_step // ✅ 注意这里
// )
// {
// return groundingdino::ms_deform_attn_cuda_forward(
// value,
// spatial_shapes,
// level_start_index,
// sampling_loc,
// attn_weight,
// im2col_step
// );
// }
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
const at::Tensor &value,
const at::Tensor &spatial_shapes,
......@@ -173,16 +154,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
};
}
} // namespace groundingdino
// #include <torch/library.h>
// // 注册 schema
// TORCH_LIBRARY(my_ops, m) {
// m.def("ms_deform_attn(Tensor value, Tensor spatial_shapes, Tensor level_start_index, Tensor sampling_loc, Tensor attn_weight, int im2col_step) -> Tensor");
// }
// // CUDA实现
// TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
// m.impl("ms_deform_attn", groundingdino::ms_deform_attn_forward_wrapper);
// }
\ No newline at end of file
} // namespace groundingdino
\ No newline at end of file
==================================================
Grounding DINO 性能测试报告
==================================================
测试时间: 2026-06-01 13:39:35
测试设备: GPU
GPU型号: K100_AI
预热次数: 5
测试次数: 10
平均推理时延: 186.29 ms
时延标准差: 0.69 ms
最大时延: 188.28 ms
最小时延: 185.77 ms
平均FPS: 5.37 帧/秒
单次推理时延(最后一次): 186.12 ms
images/out/result.jpg

1.35 MB | W: | H:

images/out/result.jpg

1.35 MB | W: | H:

images/out/result.jpg
images/out/result.jpg
images/out/result.jpg
images/out/result.jpg
  • 2-up
  • Swipe
  • Onion skin
CUDA_VISIBLE_DEVICES=1 python demo/infer_torch.py \
HIP_VISIBLE_DEVICES=4 python demo/infer_torch.py \
-c groundingdino/config/GroundingDINO_SwinB_cfg.py \
-p weights/groundingdino_swinb_cogcoor.pth \
-i images/in/car_1.jpg \
......
......@@ -3,9 +3,9 @@ MIGRAPHX_ENABLE_GRAPHAPI_REDUCTION=1
MIGRAPHX_ENABLE_LAYERNORM_FUSION=1
migraphx-driver perf --onnx \
../test0525/ground_opt_0509.onnx \
../weights/ground_opt_0601.onnx \
--fp16 \
--output \
../test0525/ground_opt_0515.mxr
../weights/ground_opt_0601.mxr
# ../weights/ground_opt_0430.mxr > migraphx_log.log 2>&1
\ No newline at end of file
......@@ -275,8 +275,8 @@ def benchmark_performance(
# =========================
if __name__ == "__main__":
model_path = "../weights/ground_opt_0430.onnx"
cache_path = "../weights/ground_opt_0515_1.mxr"
model_path = "../weights/ground_opt_0601.onnx"
cache_path = "../weights/ground_opt_0601.mxr"
img_path = "../images/in/car_1.jpg"
BOX_TRESHOLD = 0.35
......
migraphx-driver perf --batch 1 \
-n 10 \
--fp16 \
--migraphx ../weights/ground_opt_0515_1.mxr
\ No newline at end of file
--migraphx ../weights/ground_opt_0601.mxr
\ No newline at end of file
import sys
import numpy as np
from onnx_modifier import ONNXModifier
def change_inf_to_value(om: ONNXModifier):
    """Replace one-element +/-inf initializers feeding ``Where`` nodes with finite values.

    Every initializer wired into the second or third input of a ``Where`` node
    is inspected. A one-element initializer equal to ``+inf`` is overwritten
    with the largest finite float16 value, and one equal to ``-inf`` with the
    most negative finite float16 value; both replacements are written back as
    float32 arrays. Inputs that are not initializers, initializers that do not
    hold exactly one element, and finite values are left untouched.

    Args:
        om: Graph wrapper providing ``get_nodes``, ``get_initializer``,
            ``get_initializer_value`` and ``set_initializer_value``.
    """
    # float16 limits are used even though the values are stored as float32 --
    # presumably so the constants stay finite after an fp16 conversion
    # (TODO confirm against the deployment flow).
    fp16_limits = np.finfo(np.float16)
    handled = set()
    for where_node in om.get_nodes("Where"):
        for input_name in where_node.inputs[1:]:
            init = om.get_initializer(input_name)
            if init is None:
                continue
            assert input_name == init.name
            # An initializer shared by several Where nodes is rewritten once.
            if input_name in handled:
                continue
            data = np.asarray(om.get_initializer_value(input_name))
            # Only one-element tensors are rewritten. Empty tensors are skipped
            # as well: using ``empty_array == np.inf`` as an ``if`` condition
            # has no well-defined truth value.
            if data.size != 1:
                continue
            value = data.reshape(-1)[0]
            if value == np.inf:
                finite = fp16_limits.max
            elif value == -np.inf:
                finite = fp16_limits.min
            else:
                continue
            om.set_initializer_value(input_name, np.array(finite, dtype=np.float32))
            handled.add(input_name)
def optimize_where_ndoes(om: ONNXModifier):
"""Replace Where nodes with equivalent arithmetic nodes.
(1) condition is an initializer, X is 0, Y is the input data:
Where(cond, X, Y) ==> Mul(Y, ~cond)
(2) condition is an initializer, X is negative infinity, Y is the input data:
Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
(3) condition is a real (runtime) input, X is negative infinity, Y is the input data:
Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
Note: where the formulas above say np.inf, the code below uses the largest
finite float16 value (np.finfo(np.float16).max) stored as float32.
cases:
1. Where(cond, -inf, input)
a. /transformer/encoder/fusion_layers.*/attn/Where
b. /transformer/encoder/fusion_layers.*/attn/Where_1
c. /class_embed.0_*/Where: Where(cond, -inf, input)
2. Where(cond, 0, input):
a. /transformer/encoder/layers.*/self_attn/Where
b. /transformer/decoder/layers.*/cross_attn/Where
"""
# Visit every Where node currently in the graph.
for where_node in om.get_nodes("Where"):
where_name = where_node.name
# print("Process where node:", where_name)
# X (input 1) is read as an initializer value; the assertions below only
# accept a one-element value equal to 0.0 or -inf.
x_value = om.get_initializer_value(where_node.inputs[1])
assert x_value.size == 1
assert x_value == np.array(0.0, dtype=np.float32) or \
x_value == np.array(-np.inf, dtype=np.float32)
# A non-None result means the condition (input 0) is a constant initializer.
cond_init = om.get_initializer(where_node.inputs[0])
if cond_init is not None:
cond_value = om.get_initializer_value(where_node.inputs[0])
if x_value == np.array(0.0, dtype=np.float32):
# Where(cond, X, Y) ==> Mul(Y, ~cond)
# The inverted condition is baked into a float32 constant multiplier.
mul_name = where_name.replace("Where", "NewMul")
mul_b_init = om.create_initializer(mul_name + "_B",
(~cond_value).astype(np.float32))
mul_node = om.create_node("Mul",
mul_name,
[where_node.inputs[2], mul_b_init.name],
[mul_name+"_output_0"],
index=where_node.index)
# Point every consumer of the Where output at the new Mul output.
next_nodes = where_node.next_nodes
for next_node in next_nodes:
next_node.replace_input(where_node.outputs[0], mul_node.outputs[0])
elif x_value == np.array(-np.inf, dtype=np.float32):
# Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
# The subtrahend is precomputed: float16 max where cond is set, 0 elsewhere.
sub_name = where_name.replace("Where", "NewSub")
sub_b_init = om.create_initializer(
sub_name + "_B",
np.where(cond_value.astype(np.float32),
np.finfo(np.float16).max, 0.0).astype(np.float32)
)
sub_node = om.create_node("Sub",
sub_name,
[where_node.inputs[2], sub_b_init.name],
[sub_name+"_output_0"],
index=where_node.index)
# Point every consumer of the Where output at the new Sub output.
next_nodes = where_node.next_nodes
for next_node in next_nodes:
next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
else:
# Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
# Runtime condition: only the X == -inf form is supported here.
assert x_value == np.array(-np.inf, dtype=np.float32)
cast_name = where_name.replace("Where", "NewCast")
mul_name = where_name.replace("Where", "NewMul")
sub_name = where_name.replace("Where", "NewSub")
# to=1 is the Cast target type (float32 per the docstring above).
cast_node = om.create_node("Cast",
cast_name,
[where_node.inputs[0]],
[cast_name+"_output_0"],
to=1,
index=where_node.index)
# Scale the 0/1 mask by float16 max (stored as a float32 one-element array).
mul_b_init = om.create_initializer(mul_name + "_B",
np.array([np.finfo(np.float16).max], np.float32))
mul_node = om.create_node("Mul",
mul_name,
[cast_node.outputs[0], mul_b_init.name],
[mul_name+"_output_0"],
index=cast_node.index+1)
sub_node = om.create_node("Sub",
sub_name,
[where_node.inputs[2], mul_node.outputs[0]],
[sub_name+"_output_0"],
index=mul_node.index+1)
# Point every consumer of the Where output at the new Sub output.
next_nodes = where_node.next_nodes
for next_node in next_nodes:
next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
# NOTE(review): the original Where nodes are only disconnected from their
# consumers here, not explicitly removed; presumably update_map() (or the
# save step) refreshes/prunes the graph -- confirm in onnx_modifier.
om.update_map()
def optimize_transpose_nodes(om: ONNXModifier):
    """Bypass axis-swapping Transpose nodes and patch the shapes that relied on them.

    Steps, in order:
      1. For each listed Transpose (``perm`` must be ``[1, 0, 2]`` or
         ``[1, 0, 2, 3]``), rewire its consumers to read the Transpose input
         directly, so the swap of the first two axes no longer happens.
      2. Point ``text_layers.{0..5}/self_attn/Reshape_4`` at a new fixed target
         shape ``[1, 4, 256]``.
      3. Set ``perm`` of ``/transformer/enc_out_class_embed/Transpose`` to
         ``[0, 2, 1]``.
      4. Overwrite the target shape used by ``/transformer/decoder/Reshape``
         with ``[1, 900, -1]``.
      5. Point the decoder ``self_attn/Reshape_4`` and ``ca_text/Reshape_6``
         nodes of layers 0..5 at a new fixed target shape ``[1, 900, 256]``.
      6. Swap the first two axes of the ``/transformer/Tile_1_output_0``
         initializer (used by decoder ``layers.0/Add`` and ``Add_1``).
      7. Refresh the modifier's maps and re-run shape inference.

    The fixed sizes (4, 256, 900) are specific to the exported model this
    script targets.

    Args:
        om: Graph wrapper for the ONNX model being edited in place.
    """

    def numbered(base: str, count: int):
        # Exporter naming scheme: "X", "X_1", "X_2", ...
        return [base if k == 0 else f"{base}_{k}" for k in range(count)]

    # Same 43 nodes, in the same processing order, as the original explicit list.
    transpose_list = (
        numbered("/transformer/encoder/Transpose", 12)
        + [name
           for layer in range(6)
           for name in numbered(f"/transformer/decoder/layers.{layer}/Transpose", 3)]
        + ["/transformer/Transpose_8"]
        + numbered("/transformer/decoder/Transpose", 12)
    )

    for name in transpose_list:
        node = om.get_node(name)
        assert node.attrs['perm'] == [1, 0, 2] or node.attrs['perm'] == [1, 0, 2, 3], \
            f"perm={node.attrs['perm']}"
        # Consumers now read the un-transposed tensor.
        for consumer in om.get_next_nodes(node):
            consumer.replace_input(node.outputs[0], node.inputs[0])

    # modify /transformer/encoder/text_layers.*/self_attn/Reshape_4
    text_shape_init = om.create_initializer(
        "/transformer/encoder/text_layers.x/self_attn/des_shape",
        np.array([1, 4, 256], np.int64)
    )
    for i in range(6):
        reshape_node = om.get_node(f"/transformer/encoder/text_layers.{i}/self_attn/Reshape_4")
        reshape_node.set_input(1, text_shape_init.name)

    # modify /transformer/enc_out_class_embed/Transpose
    om.get_node("/transformer/enc_out_class_embed/Transpose").set_attribute("perm", [0, 2, 1])

    # modify /transformer/decoder/Reshape_*
    # The shape initializer carries an auto-generated name (it was "_v_5525" in
    # one particular export), so look it up through the Reshape node instead of
    # hard-coding a name that changes whenever the model is re-exported.
    dst_shape_name = om.get_node("/transformer/decoder/Reshape").inputs[1]
    om.set_initializer_value(dst_shape_name, np.array([1, 900, -1], np.int64))

    # modify /transformer/decoder/layers.*/self_attn/Reshape_4
    # modify /transformer/decoder/layers.*/ca_text/Reshape_6
    decoder_shape_init = om.create_initializer(
        "/transformer/decoder/layers.x/self_attn_ca_text/des_shape",
        np.array([1, 900, 256], np.int64)
    )
    for i in range(6):
        self_attn_reshape = om.get_node(f"/transformer/decoder/layers.{i}/self_attn/Reshape_4")
        self_attn_reshape.set_input(1, decoder_shape_init.name)
        ca_text_reshape = om.get_node(f"/transformer/decoder/layers.{i}/ca_text/Reshape_6")
        ca_text_reshape.set_input(1, decoder_shape_init.name)

    # modify /transformer/decoder/layers.0/Add
    # modify /transformer/decoder/layers.0/Add_1
    init_name = "/transformer/Tile_1_output_0"
    add_value = om.get_initializer_value(init_name)
    om.set_initializer_value(init_name, np.ascontiguousarray(add_value.transpose(1, 0, 2)))

    om.update_map()
    om.infer_shape()
def optmize_sin_cos_block(om: ONNXModifier):
"""Rewrite the six decoder sin/cos positional-embedding sub-graphs.
For every (Gather, MatMul) pair below, the chain
Gather -> Mul(scalar) -> Unsqueeze -> Div -> 2 x Slice -> Sin/Cos -> Unsqueeze -> Concat -> Reshape
is matched by following consumers and then edited in place:
* the Unsqueeze reads the Gather's input directly, with axes [3, 4];
* the Div divisor becomes (old divisor / Mul scalar) reshaped to (1, 1, 1, 64, 2);
* the two Slice nodes select [0, 1) and [1, 2) along axis 4;
* the Concat takes the Sin and Cos outputs directly, along axis 4;
* the Reshape target becomes [1, 900, -1];
* the paired MatMul reads the Reshape output as its first input.
For the first pair only, the MatMul weight rows are reordered
(rows 128:256, then 0:128, then the rest).
"""
# (Gather feeding one sin/cos block, MatMul that consumes the block's output).
node_pairs = [
("/transformer/decoder/Gather_1", "/transformer/decoder/ref_point_head/layers.0/MatMul"),
("/transformer/decoder/Gather_6", "/transformer/decoder/ref_point_head/layers.0_1/MatMul"),
("/transformer/decoder/Gather_11", "/transformer/decoder/ref_point_head/layers.0_2/MatMul"),
("/transformer/decoder/Gather_16", "/transformer/decoder/ref_point_head/layers.0_3/MatMul"),
("/transformer/decoder/Gather_21", "/transformer/decoder/ref_point_head/layers.0_4/MatMul"),
("/transformer/decoder/Gather_26", "/transformer/decoder/ref_point_head/layers.0_5/MatMul"),
]
# Constant inputs shared by all six rewritten blocks.
unsqueeze_axes_init1 = om.create_initializer(
"/transformer/decoder/sin_cos_block/unsqueeze_axes1",
np.array([3, 4], np.int64)
)
slice_axes_init = om.create_initializer(
"/transformer/decoder/sin_cos_block/slice_axes",
np.array([4], np.int64)
)
slice_steps_init = om.create_initializer(
"/transformer/decoder/sin_cos_block/slice_steps",
np.array([1], np.int64)
)
slice_starts_init1 = om.create_initializer(
"/transformer/decoder/sin_cos_block/slice_starts1",
np.array([0], np.int64)
)
slice_ends_init1 = om.create_initializer(
"/transformer/decoder/sin_cos_block/slice_ends1",
np.array([1], np.int64)
)
# NOTE(review): this initializer is used as the *starts* of the second slice,
# but its graph name says "slice_steps2" -- looks like a naming slip; confirm
# nothing refers to it by name before renaming.
slice_starts_init2 = om.create_initializer(
"/transformer/decoder/sin_cos_block/slice_steps2",
np.array([1], np.int64)
)
slice_ends_init2 = om.create_initializer(
"/transformer/decoder/sin_cos_block/slice_ends2",
np.array([2], np.int64)
)
reshape_init = om.create_initializer(
"/transformer/decoder/sin_cos_block/reshape_dst_shape",
np.array([1, 900, -1], np.int64)
)
for i, (gather_name, matmul_name) in enumerate(node_pairs):
gather_node = om.get_node(gather_name)
# Gather -> Mul by a one-element constant (folded into the Div divisor below).
next_node = om.get_next_nodes(gather_node)[0]
assert next_node.op_type == "Mul", f"{next_node.op_type} {next_node.name}"
mul_init_value = om.get_initializer_value(next_node.inputs[1])
assert mul_init_value.size == 1
# Unsqueeze now reads the Gather's own input, bypassing Gather and Mul.
next_node = om.get_next_nodes(next_node)[0]
assert next_node.op_type == "Unsqueeze"
next_node.set_inputs([gather_node.inputs[0], unsqueeze_axes_init1.name])
# Fold the removed Mul into the divisor: x * m / d == x / (d / m).
next_node = om.get_next_nodes(next_node)[0]
assert next_node.op_type == "Div"
div_init_value = om.get_initializer_value(next_node.inputs[1])
new_value = (div_init_value / mul_init_value).reshape(1, 1, 1, 64, 2)
new_init = om.create_initializer(next_node.name + "_B", new_value)
next_node.set_input(1, new_init.name)
# The Div output must feed exactly two Slice nodes.
next_nodes = om.get_next_nodes(next_node)
assert len(next_nodes) == 2 and all(x.op_type == 'Slice' for x in next_nodes)
sin_node, cos_node = None, None
for j, slice_node in enumerate(next_nodes):
# The first consumer gets [0, 1), the second [1, 2), both on axis 4 with step 1.
slice_node.set_inputs([slice_node.inputs[0],
slice_starts_init1.name if j == 0 else slice_starts_init2.name,
slice_ends_init1.name if j == 0 else slice_ends_init2.name,
slice_axes_init.name,
slice_steps_init.name])
# Record which branch is Sin and which is Cos.
next_node = om.get_next_nodes(slice_node)[0]
if next_node.op_type == "Sin":
sin_node = next_node
elif next_node.op_type == "Cos":
cos_node = next_node
else:
raise RuntimeError("match fail!")
# Sin/Cos -> Unsqueeze -> Concat: the Concat is rewired to take the Sin and
# Cos outputs directly (the Unsqueeze is bypassed) and to join on axis 4.
next_node = om.get_next_nodes(next_node)[0]
assert next_node.op_type == "Unsqueeze"
next_node = om.get_next_nodes(next_node)[0]
assert next_node.op_type == "Concat"
next_node.set_inputs([sin_node.outputs[0], cos_node.outputs[0]])
next_node.set_attribute("axis", 4)
next_node = om.get_next_nodes(next_node)[0]
assert next_node.op_type == "Reshape"
next_node.set_input(1, reshape_init.name)
# The paired MatMul consumes the rewritten Reshape output directly.
matmul_node = om.get_node(matmul_name)
matmul_node.set_input(0, next_node.outputs[0])
# The weight rows are reordered once, on the first pair only -- presumably the
# six MatMuls share one weight initializer (TODO confirm).
if i == 0:
mm_b_value = om.get_initializer_value(matmul_node.inputs[1])
mm_b_value = np.concatenate([mm_b_value[128:256, ...],
mm_b_value[0:128, ...],
mm_b_value[256:, ...]],
axis=0)
om.set_initializer_value(matmul_node.inputs[1], mm_b_value)
om.update_map()
om.infer_shape()
def fuse_one_attention(om: ONNXModifier, softmax_name: str, new_mask: str = None, num_heads: int = 12):
"""Fuse the attention sub-graph around one Softmax into a MultiHeadAttention node.
The pattern is matched by walking backwards from the Softmax (optional mask
Add, optional Div, Q.K MatMul, Transpose, Reshape) and forwards (MatMul with
V, Transpose, Reshape, then a Gemm or MatMul output projection). A
"com.microsoft" MultiHeadAttention node is created from the un-reshaped
Q/K/V tensors and wired into the output projection.
Args:
om: graph wrapper for the ONNX model being edited.
softmax_name: name of the Softmax node that anchors the pattern.
new_mask: name of the mask tensor passed as the 4th input of the fused
node; required (asserted) when the pattern contains a mask Add.
num_heads: value of the fused node's num_heads attribute.
"""
softmax_node = om.get_node(softmax_name)
# Node producing the Softmax input: the Q.K MatMul, or an Add applying a mask.
tmp_node = om.get_prev_nodes(softmax_node)[0]
assert tmp_node.op_type in ["MatMul", "Add"]
mask = None
if tmp_node.op_type == "Add":
mask_node = tmp_node
tmp_node = om.get_from_node(mask_node.inputs[0])
# An optional Div may sit between the MatMul and the mask Add.
if tmp_node.op_type == "Div":
tmp_node = om.get_from_node(tmp_node.inputs[0])
assert tmp_node.op_type == "MatMul"
# `mask` only flags that the pattern is masked; the fused node receives
# `new_mask`, not this original tensor.
mask = mask_node.inputs[1]
assert new_mask is not None
# Q path: optional Mul, then Transpose <- Reshape.
tmp_node1 = om.get_from_node(tmp_node.inputs[0])
if tmp_node1.op_type == "Mul":
tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
# K path: Transpose <- Reshape.
tmp_node2 = om.get_from_node(tmp_node.inputs[1])
assert tmp_node1.op_type == tmp_node2.op_type == "Transpose"
tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
tmp_node2 = om.get_prev_nodes(tmp_node2)[0]
assert tmp_node1.op_type == tmp_node2.op_type == "Reshape"
# Q and K are the tensors entering the per-head Reshape nodes.
q, k = tmp_node1.inputs[0], tmp_node2.inputs[0]
# Softmax -> MatMul with V.
tmp_node = om.get_next_nodes(softmax_node)[0]
assert tmp_node.op_type == "MatMul"
tmp_node3 = om.get_from_node(tmp_node.inputs[1])
if tmp_node3 is not None:
# V produced by the graph: Reshape -> Transpose -> MatMul.
assert tmp_node3.op_type == "Transpose"
tmp_node3 = om.get_prev_nodes(tmp_node3)[0]
assert tmp_node3.op_type == "Reshape"
v = tmp_node3.inputs[0]
else:
# V is a constant initializer: add a leading axis, swap axes 1 and 2, and
# merge the last two axes so it is laid out as (B, S, H*D); the
# initializer is overwritten in place.
v_init = om.get_initializer(tmp_node.inputs[1])
v_init_value = om.get_initializer_value(tmp_node.inputs[1])
v_init_value = v_init_value[None, ...].transpose(0, 2, 1, 3)
B, S, H, D = v_init_value.shape
v_init_value = np.ascontiguousarray(v_init_value.reshape(B, S, H*D))
om.set_initializer_value(tmp_node.inputs[1], v_init_value)
v = v_init.name
# Attention output: MatMul -> Transpose -> Reshape -> output projection.
tmp_node = om.get_next_nodes(tmp_node)[0]
assert tmp_node.op_type == "Transpose"
tmp_node = om.get_next_nodes(tmp_node)[0]
assert tmp_node.op_type == "Reshape"
mha_next_node = om.get_next_nodes(tmp_node)[0]
if mha_next_node.op_type == "Gemm":
# A Gemm projection must be followed by Reshape -> Add.
gemm_next_node = om.get_next_nodes(mha_next_node)[0]
assert gemm_next_node.op_type == "Reshape"
reshape_next_node = om.get_next_nodes(gemm_next_node)[0]
assert reshape_next_node.op_type == "Add"
else:
assert mha_next_node.op_type == "MatMul"
# The fused node is named after the Softmax's parent scope.
name_prefix = '/'.join(softmax_name.split('/')[:-1])
mha_name = f"{name_prefix}/MultiHeadAttention"
mha_node = om.create_node("MultiHeadAttention",
mha_name,
[q, k, v] if mask is None else [q, k, v, new_mask],
[mha_name+'_output_0'],
num_heads=num_heads,
domain="com.microsoft",
index=mha_next_node.index-1)
# The output projection now reads the fused node's output.
mha_next_node.replace_input(mha_next_node.inputs[0], mha_node.outputs[0])
if mha_next_node.op_type == "Gemm":
# Replace the Gemm (transB == 1 is asserted) by MatMul + Add: the weight
# initializer is transposed in place, and the Add after the Gemm's
# Reshape is rewired to read the new Add output directly.
weights = om.get_initializer_value(mha_next_node.inputs[1])
transB = mha_next_node.attrs["transB"]
assert transB == 1
weights = np.ascontiguousarray(weights.transpose(1, 0))
om.set_initializer_value(mha_next_node.inputs[1], weights)
new_matmul_name = mha_next_node.name.replace("Gemm", "MatMul(Gemm)")
new_matmul_node = om.create_node("MatMul",
new_matmul_name,
[mha_node.outputs[0], mha_next_node.inputs[1]],
[new_matmul_name + "_output_0"],
index=mha_next_node.index)
new_bias_name = mha_next_node.name.replace("Gemm", "Add(Gemm)")
new_add_node = om.create_node("Add",
new_bias_name,
[new_matmul_node.outputs[0], mha_next_node.inputs[2]],
[new_bias_name + "_output_0"],
index=new_matmul_node.index+1)
reshape_next_node.replace_input(gemm_next_node.outputs[0], new_add_node.outputs[0])
def optimize_normal_attention(om: ONNXModifier):
"""Fuse the decoder self-attention and text cross-attention blocks (layers 0..5).
A mask tensor is first derived from the "attention_mask" graph tensor
(Cast followed by ReduceSum) and then passed to fuse_one_attention for each
decoder layer's self_attn and ca_text Softmax, with num_heads=8.
"""
def create_new_attention_mask():
# Builds Cast(to=1) -> ReduceSum(axes=1, keepdims=0) on "attention_mask",
# inserted at the position of the mask's first consumer, and returns the
# name of the ReduceSum output.
mask_next_node = om.get_to_nodes("attention_mask")[0]
cast_node = om.create_node("Cast",
"Cast_for_attention_mask",
["attention_mask"],
["Cast_for_attention_mask_output_0"],
to=1,
index=mask_next_node.index)
# NOTE(review): `axes` is passed as a node attribute; newer ONNX opsets
# take ReduceSum axes as an input instead -- confirm against the opset
# of the model being modified.
reducesum_node = om.create_node("ReduceSum",
"ReduceSum_for_mask",
[cast_node.outputs[0]],
["ReduceSum_for_mask_output_0"],
axes=1,
keepdims=0,
index=cast_node.index+1)
return reducesum_node.outputs[0]
# bert
# for i in range(12):
# fuse_one_attention(om, f"/bert/encoder/layer.{i}/attention/self/Softmax", "text_token_mask", num_heads=12)
new_mask = create_new_attention_mask()
for i in range(6):
# /transformer/encoder
# fuse_one_attention(om, f"/transformer/encoder/text_layers.{i}/self_attn/Softmax", "text_token_mask", num_heads=4)
# /transformer/decoder
fuse_one_attention(om, f"/transformer/decoder/layers.{i}/self_attn/Softmax", new_mask, num_heads=8)
fuse_one_attention(om, f"/transformer/decoder/layers.{i}/ca_text/Softmax", new_mask, num_heads=8)
om.update_map()
def optimize_backbone_attention(om: ONNXModifier):
"""Fuse the backbone attention blocks into MultiHeadAttention nodes.
Processes /backbone/backbone.0/layers.{0..3}/blocks.{b}/attn/softmax/Softmax,
where layer 2 has 18 blocks and every other layer has 2.
"""
def get_original_mask(mask_name, name_prefix):
# Builds a boolean initializer that is True where the additive initializer
# `mask_name` equals 0, and returns the new initializer's name.
mask_value = om.get_initializer_value(mask_name)
orig_mask = np.where(mask_value==0, 1, 0).astype(np.bool_)
orig_mask_init = om.create_initializer(f"{name_prefix}/mask", orig_mask)
return orig_mask_init.name
def _fuse_one_attention(softmax_name: str):
# Matches one attention block around `softmax_name` and inserts a fused
# "com.microsoft" MultiHeadAttention node.
name_prefix = '/'.join(softmax_name.split('/')[:-1])
softmax_node = om.get_node(softmax_name)
tmp_node = om.get_prev_nodes(softmax_node)[0]
pos_bias_init = None
# When the Softmax input comes from a Reshape, the chain holds two Adds
# (Add <- Reshape <- Add <- Reshape <- Softmax); the Add nearest the Softmax
# supplies `pos_bias_init` and the earlier Add supplies the mask source.
# NOTE(review): the names suggest a position bias and a window mask, but
# which Add carries which cannot be told from this code -- confirm on the graph.
if tmp_node.op_type == "Reshape":
tmp_node = om.get_prev_nodes(tmp_node)[0]
assert tmp_node.op_type == "Add"
pos_bias_init = om.get_initializer(tmp_node.inputs[1])
tmp_node = om.get_prev_nodes(tmp_node)[0]
assert tmp_node.op_type == "Reshape"
tmp_node = om.get_prev_nodes(tmp_node)[0]
assert tmp_node.op_type == "Add"
mask = get_original_mask(tmp_node.inputs[1], name_prefix)
tmp_node = om.get_prev_nodes(tmp_node)[0]
assert tmp_node.op_type == "MatMul"
qk_matmul = tmp_node
# Q path: Reshape -> Transpose -> Gather -> Mul -> MatMul.
tmp_node = om.get_from_node(qk_matmul.inputs[0])
assert tmp_node.op_type == "Mul"
tmp_node = om.get_prev_nodes(tmp_node)[0]
assert tmp_node.op_type == "Gather"
q_gather_node = tmp_node
tmp_node = om.get_prev_nodes(tmp_node)[0]
assert tmp_node.op_type == "Transpose"
tmp_node = om.get_prev_nodes(tmp_node)[0]
assert tmp_node.op_type == "Reshape"
reshape_node = tmp_node
# K path: Gather -> Transpose -> MatMul.
tmp_node = om.get_from_node(qk_matmul.inputs[1])
assert tmp_node.op_type == "Transpose"
tmp_node = om.get_prev_nodes(tmp_node)[0]
assert tmp_node.op_type == "Gather"
k_gather_node = tmp_node
# V path and attention output: Softmax -> MatMul(V Gather) -> Transpose -> Reshape.
tmp_node = om.get_next_nodes(softmax_node)[0]
assert tmp_node.op_type == "MatMul"
v_gather_node = om.get_from_node(tmp_node.inputs[1])
assert v_gather_node.op_type == "Gather"
tmp_node = om.get_next_nodes(tmp_node)[0]
assert tmp_node.op_type == "Transpose"
tmp_node = om.get_next_nodes(tmp_node)[0]
assert tmp_node.op_type == "Reshape"
mha_out = tmp_node.outputs[0]
# The QKV Reshape target has five entries; the last two (h, d) are merged
# into one axis of size h*d, the first three are kept as they are.
old_dst_shape = om.get_initializer_value(reshape_node.inputs[1])
b, s, _, h, d = old_dst_shape
new_dst_shape = [b, s, _, h*d]
new_dst_shape_init = om.create_initializer(f"{name_prefix}/qkv_hidden_states_shape",
np.array(new_dst_shape, np.int64))
reshape_node.set_input(1, new_dst_shape_init.name)
# The three Gathers now read the Reshape output directly (the Transpose is
# bypassed) and select along axis 2.
for node in [q_gather_node, k_gather_node, v_gather_node]:
node.set_input(0, reshape_node.outputs[0])
node.set_attribute("axis", 2)
mha_name = f"{name_prefix}/MultiHeadAttention"
# Fused node inputs: Q, K, V, mask, plus the optional bias initializer.
inputs = [q_gather_node.outputs[0],
k_gather_node.outputs[0],
v_gather_node.outputs[0],
mask]
if pos_bias_init is not None:
inputs.append(pos_bias_init.name)
mha_node = om.create_node("MultiHeadAttention",
mha_name,
inputs,
[mha_name+'_output_0'],
num_heads=h,
domain="com.microsoft",
index=softmax_node.index)
# Only the first consumer of the old attention output is rewired.
mha_next_node = om.get_to_nodes(mha_out)[0]
mha_next_node.replace_input(mha_out, mha_node.outputs[0])
num_layers = 4
for l in range(num_layers):
# Layer 2 has 18 blocks; every other layer has 2.
num_blocks = 18 if l == 2 else 2
for b in range(num_blocks):
_fuse_one_attention(f"/backbone/backbone.0/layers.{l}/blocks.{b}/attn/softmax/Softmax")
def optimize_ms_deform_attn(om: ONNXModifier,
                            spatial_shapes=((100, 150), (50, 75), (25, 38), (13, 19))):
    """Replace each multi-scale deformable attention sub-graph with one ``MSDeformAttn`` node.

    For encoder ``self_attn`` and decoder ``cross_attn`` of layers 0..5 a fused
    node is created whose inputs are: value, spatial shapes, level start
    indices, sampling locations and attention weights. All consumers of the
    sub-graph's final output (``Transpose_9_output_0``) are rewired to the
    fused node's output.

    Args:
        om: Graph wrapper for the ONNX model being edited in place.
        spatial_shapes: ``(height, width)`` of every feature level, in level
            order. The default reproduces the previously hard-coded values;
            pass other sizes for a model exported at a different input
            resolution. The level start indices are derived from these
            shapes as the running sum of ``height * width``.

    Raises:
        ValueError: if ``spatial_shapes`` is not a sequence of (h, w) pairs.
    """
    shapes = np.array(spatial_shapes, dtype=np.int64)
    if shapes.ndim != 2 or shapes.shape[1] != 2:
        raise ValueError("spatial_shapes must be a sequence of (height, width) pairs")

    # Flattened levels are laid out back to back, so level k starts at the sum
    # of the sizes of levels 0..k-1 (defaults give [0, 15000, 18750, 19700]).
    level_sizes = shapes[:, 0] * shapes[:, 1]
    level_starts = np.concatenate(([0], np.cumsum(level_sizes)[:-1])).astype(np.int64)

    def fuse_ms_deform_attn(value, spatial_shapes_name, level_start_index_name,
                            sampling_locations, attention_weights, output):
        # Insert the fused node where the first consumer of `value` sits.
        value_next_node = om.get_to_nodes(value)[0]
        name_prefix = '/'.join(value.split('/')[:-1])
        node_name = f"{name_prefix}/MSDeformAttn"
        fusion_node = om.create_node("MSDeformAttn",
                                     node_name,
                                     [value, spatial_shapes_name, level_start_index_name,
                                      sampling_locations, attention_weights],
                                     [f"{node_name}_output_0"],
                                     index=value_next_node.index)
        # Every consumer of the old sub-graph output reads the fused output.
        for consumer in om.get_to_nodes(output):
            consumer.replace_input(output, fusion_node.outputs[0])

    spatial_shapes_init = om.create_initializer("/transformer/spatial_shapes", shapes)
    level_start_index_init = om.create_initializer("/transformer/level_start_index",
                                                   level_starts)

    for i in range(6):
        for scope in (f"/transformer/encoder/layers.{i}/self_attn",
                      f"/transformer/decoder/layers.{i}/cross_attn"):
            fuse_ms_deform_attn(
                f"{scope}/Reshape_output_0",
                spatial_shapes_init.name,
                level_start_index_init.name,
                f"{scope}/Add_output_0",
                f"{scope}/Reshape_3_output_0",
                f"{scope}/Transpose_9_output_0"
            )
    om.update_map()
def optimize_bidirect_attention(om: ONNXModifier):
    """Insert a 1x1 identity MatMul after ``ReduceMax_1`` in each fusion layer.

    For fusion layers 0..5 of the encoder, the output of
    ``attn/ReduceMax_1`` is routed through a MatMul with a ``[[1.0]]`` float32
    weight, and the following ``Sub`` node takes that MatMul output as its
    second input. The values are unchanged by the multiplication; presumably
    the extra node steers how the backend fuses this pattern (TODO confirm).

    Args:
        om: Graph wrapper for the ONNX model being edited in place.
    """
    for layer_idx in range(6):
        scope = f"/transformer/encoder/fusion_layers.{layer_idx}/attn"
        reduce_max = om.get_node(f"{scope}/ReduceMax_1")

        # The first consumer of ReduceMax_1 must be the Sub that gets rewired.
        sub_node = om.get_next_nodes(reduce_max)[0]
        assert sub_node.op_type == "Sub"

        identity_name = f"{scope}/identity_MatMul"
        identity_weight = om.create_initializer(f"{identity_name}_B",
                                                np.eye(1, dtype=np.float32))
        identity_node = om.create_node("MatMul",
                                       identity_name,
                                       [reduce_max.outputs[0], identity_weight.name],
                                       [identity_name + "_output_0"],
                                       index=reduce_max.index + 1)
        sub_node.set_input(1, identity_node.outputs[0])
def main():
    """Command-line entry point: ``<script> <input.onnx> <output.onnx>``.

    Loads the input model, applies the enabled graph rewrites in order and
    saves the result to the output path without external data files.
    Exits with a usage message when the two paths are not supplied.
    """
    if len(sys.argv) < 3:
        sys.exit(f"usage: python {sys.argv[0]} <input.onnx> <output.onnx>")
    input_onnx_path = sys.argv[1]
    output_onnx_path = sys.argv[2]
    # input_onnx_path = "ground_sim.onnx"
    # output_onnx_path = "ground_sim_0424_2nd.onnx"
    om = ONNXModifier(input_onnx_path)
    optimize_where_ndoes(om)  # 1. replace Where nodes
    optimize_transpose_nodes(om)  # 2. optimize Transpose nodes
    optmize_sin_cos_block(om)  # 3. optimize the positional encoding
    # om.add_opset_import("com.microsoft", 1)
    # optimize_normal_attention(om)  # 4. fuse the MHA in bert / transformer
    # optimize_ms_deform_attn(om)  # 5. fuse multi-scale deformable attention
    # optimize_backbone_attention(om)  # 6. fuse the backbone attention
    optimize_bidirect_attention(om)  # 7. optimize bidirectional attention
    om.save(output_onnx_path, save_as_external_data=False)


if __name__ == "__main__":
    main()
......@@ -188,7 +188,9 @@ def optimize_transpose_nodes(om: ONNXModifier):
om.get_node("/transformer/enc_out_class_embed/Transpose").set_attribute("perm", [0, 2, 1])
# modify /transformer/decoder/Reshape_*
om.set_initializer_value("_v_5525", np.array([1, 900, -1], np.int64))
dst_shape_name = om.get_node("/transformer/decoder/Reshape").inputs[1]
# om.set_initializer_value("_v_5525", np.array([1, 900, -1], np.int64))
om.set_initializer_value(dst_shape_name, np.array([1, 900, -1], np.int64))
# modify /transformer/decoder/layers.*/self_attn/Reshape_4
# modify /transformer/decoder/layers.*/ca_text/Reshape_6
......@@ -313,7 +315,8 @@ def optmize_sin_cos_block(om: ONNXModifier):
om.infer_shape()
def fuse_one_attention(om: ONNXModifier, softmax_name: str, new_mask: bool = None, num_heads: int = 12):
def fuse_one_attention(om: ONNXModifier, softmax_name: str, padding_mask: bool = None,
attn_mask: bool = None, num_heads: int = 12, block_type: str = "bert"):
softmax_node = om.get_node(softmax_name)
tmp_node = om.get_prev_nodes(softmax_node)[0]
assert tmp_node.op_type in ["MatMul", "Add"]
......@@ -326,17 +329,34 @@ def fuse_one_attention(om: ONNXModifier, softmax_name: str, new_mask: bool = Non
tmp_node = om.get_from_node(tmp_node.inputs[0])
assert tmp_node.op_type == "MatMul"
mask = mask_node.inputs[1]
assert new_mask is not None
if padding_mask is not None and attn_mask is not None:
raise ValueError("padding_mask and attn_mask cannot be provided at the same time")
if padding_mask is None and attn_mask is None:
raise ValueError("padding_mask or attn_mask must be provided")
tmp_node1 = om.get_from_node(tmp_node.inputs[0])
if tmp_node1.op_type == "Mul":
tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
tmp_node2 = om.get_from_node(tmp_node.inputs[1])
if tmp_node2.op_type == "Mul":
tmp_node2 = om.get_prev_nodes(tmp_node2)[0]
assert tmp_node1.op_type == tmp_node2.op_type == "Transpose"
tmp_node1 = om.get_prev_nodes(tmp_node1)[0]
tmp_node2 = om.get_prev_nodes(tmp_node2)[0]
assert tmp_node1.op_type == tmp_node2.op_type == "Reshape"
q, k = tmp_node1.inputs[0], tmp_node2.inputs[0]
if attn_mask is not None:
q, k = tmp_node1.outputs[0], tmp_node2.outputs[0]
q_dst_shape_value = om.get_initializer_value(tmp_node1.inputs[1])
if q_dst_shape_value.size == 3:
q_dst_shape_value_new = np.array([1, *q_dst_shape_value.tolist()], np.int64)
om.set_initializer_value(tmp_node1.inputs[1], q_dst_shape_value_new)
k_dst_shape_value = om.get_initializer_value(tmp_node2.inputs[1])
if k_dst_shape_value.size == 3:
k_dst_shape_value_new = np.array([1, *k_dst_shape_value.tolist()], np.int64)
om.set_initializer_value(tmp_node2.inputs[1], k_dst_shape_value_new)
else:
q, k = tmp_node1.inputs[0], tmp_node2.inputs[0]
tmp_node = om.get_next_nodes(softmax_node)[0]
assert tmp_node.op_type == "MatMul"
......@@ -345,7 +365,14 @@ def fuse_one_attention(om: ONNXModifier, softmax_name: str, new_mask: bool = Non
assert tmp_node3.op_type == "Transpose"
tmp_node3 = om.get_prev_nodes(tmp_node3)[0]
assert tmp_node3.op_type == "Reshape"
v = tmp_node3.inputs[0]
if attn_mask is not None:
v = tmp_node3.outputs[0]
v_dst_shape_value = om.get_initializer_value(tmp_node3.inputs[1])
if v_dst_shape_value.size == 3:
v_dst_shape_value_new = np.array([1, *v_dst_shape_value.tolist()], np.int64)
om.set_initializer_value(tmp_node3.inputs[1], v_dst_shape_value_new)
else:
v = tmp_node3.inputs[0]
else:
v_init = om.get_initializer(tmp_node.inputs[1])
v_init_value = om.get_initializer_value(tmp_node.inputs[1])
......@@ -359,54 +386,79 @@ def fuse_one_attention(om: ONNXModifier, softmax_name: str, new_mask: bool = Non
assert tmp_node.op_type == "Transpose"
tmp_node = om.get_next_nodes(tmp_node)[0]
assert tmp_node.op_type == "Reshape"
mha_next_node = om.get_next_nodes(tmp_node)[0]
if mha_next_node.op_type == "Gemm":
gemm_next_node = om.get_next_nodes(mha_next_node)[0]
# if softmax_name == "/transformer/encoder/text_layers.0/self_attn/Softmax":
# breakpoint()
if attn_mask is not None:
mha_next_node = tmp_node
mha_out_shape_value = om.get_initializer_value(tmp_node.inputs[1])
if mha_out_shape_value.size == 2:
mha_out_shape_value_new = np.array([1, -1, mha_out_shape_value[-1].item()], np.int64)
om.set_initializer_value(tmp_node.inputs[1], mha_out_shape_value_new)
else:
mha_next_node = om.get_next_nodes(tmp_node)[0]
assert mha_next_node.op_type in ["Gemm", "MatMul"]
gemm_node = None
if om.get_next_nodes(tmp_node)[0].op_type == "Gemm":
gemm_node = om.get_next_nodes(tmp_node)[0]
gemm_next_node = om.get_next_nodes(gemm_node)[0]
assert gemm_next_node.op_type == "Reshape"
reshape_next_node = om.get_next_nodes(gemm_next_node)[0]
assert reshape_next_node.op_type == "Add"
else:
assert mha_next_node.op_type == "MatMul"
name_prefix = '/'.join(softmax_name.split('/')[:-1])
mha_name = f"{name_prefix}/MultiHeadAttention"
mha_node = om.create_node("MultiHeadAttention",
mha_name,
[q, k, v] if mask is None else [q, k, v, new_mask],
[mha_name+'_output_0'],
num_heads=num_heads,
domain="com.microsoft",
index=mha_next_node.index-1)
node_inputs = [q, k, v]
if mask is None:
mha_type = "MultiHeadAttention"
else:
if padding_mask is not None:
mha_type = "MultiHeadAttention"
node_inputs.append(padding_mask)
elif attn_mask is not None:
mha_type = "MultiHeadAttentionWithAttnMask"
node_inputs.append(attn_mask)
else:
raise ValueError("padding_mask or attn_mask must be provided")
mha_name = f"{name_prefix}/{mha_type}"
mha_node = om.create_node(mha_type,
mha_name,
node_inputs,
[mha_name+'_output_0'],
num_heads=num_heads,
domain="com.microsoft",
index=mha_next_node.index-1)
mha_next_node.replace_input(mha_next_node.inputs[0], mha_node.outputs[0])
if mha_next_node.op_type == "Gemm":
weights = om.get_initializer_value(mha_next_node.inputs[1])
transB = mha_next_node.attrs["transB"]
if gemm_node is not None:
weights = om.get_initializer_value(gemm_node.inputs[1])
transB = gemm_node.attrs["transB"]
assert transB == 1
weights = np.ascontiguousarray(weights.transpose(1, 0))
om.set_initializer_value(mha_next_node.inputs[1], weights)
new_matmul_name = mha_next_node.name.replace("Gemm", "MatMul(Gemm)")
om.set_initializer_value(gemm_node.inputs[1], weights)
new_matmul_name = gemm_node.name.replace("Gemm", "MatMul(Gemm)")
new_matmul_node = om.create_node("MatMul",
new_matmul_name,
[mha_node.outputs[0], mha_next_node.inputs[1]],
[new_matmul_name + "_output_0"],
index=mha_next_node.index)
new_bias_name = mha_next_node.name.replace("Gemm", "Add(Gemm)")
new_matmul_name,
[mha_next_node.outputs[0] if attn_mask is not None else mha_node.outputs[0],
gemm_node.inputs[1]],
[new_matmul_name + "_output_0"],
index=gemm_node.index)
new_bias_name = gemm_node.name.replace("Gemm", "Add(Gemm)")
new_add_node = om.create_node("Add",
new_bias_name,
[new_matmul_node.outputs[0], mha_next_node.inputs[2]],
[new_bias_name + "_output_0"],
index=new_matmul_node.index+1)
new_bias_name,
[new_matmul_node.outputs[0], gemm_node.inputs[2]],
[new_bias_name + "_output_0"],
index=new_matmul_node.index+1)
reshape_next_node.replace_input(gemm_next_node.outputs[0], new_add_node.outputs[0])
def optimize_normal_attention(om: ONNXModifier):
def create_new_attention_mask():
def _create_new_padding_mask():
mask_next_node = om.get_to_nodes("attention_mask")[0]
cast_node = om.create_node("Cast",
"Cast_for_attention_mask",
"Cast_for_padding_mask",
["attention_mask"],
["Cast_for_attention_mask_output_0"],
["Cast_for_padding_mask_output_0"],
# to=1, # float32
to=6, # int32
index=mask_next_node.index)
......@@ -419,17 +471,32 @@ def optimize_normal_attention(om: ONNXModifier):
index=cast_node.index+1)
return reducesum_node.outputs[0]
def _create_new_attn_mask(_num_heads: int):
cast_node = om.get_node("/bert/Cast")
tile_node = om.create_node("Tile",
f"Tile_for_attn_mask_{_num_heads}heads",
["/bert/Unsqueeze_output_0",
om.create_initializer(f"{_num_heads}heads_repeats",
np.array([1, _num_heads, 1, 1], np.int64)).name],
[f"Tile_for_attn_mask_{_num_heads}heads_output_0"],
index=cast_node.index)
return tile_node.outputs[0]
padding_mask = _create_new_padding_mask()
attn_mask1 = _create_new_attn_mask(12)
attn_mask2 = _create_new_attn_mask(4)
# bert
# for i in range(12):
# fuse_one_attention(om, f"/bert/encoder/layer.{i}/attention/self/Softmax", "text_token_mask", num_heads=12)
for i in range(12):
fuse_one_attention(om, f"/bert/encoder/layer.{i}/attention/self/Softmax", attn_mask=attn_mask1, num_heads=12)
new_mask = create_new_attention_mask()
for i in range(6):
# /transformer/encoder
# fuse_one_attention(om, f"/transformer/encoder/text_layers.{i}/self_attn/Softmax", "text_token_mask", num_heads=4)
fuse_one_attention(om, f"/transformer/encoder/text_layers.{i}/self_attn/Softmax", attn_mask=attn_mask2, num_heads=4)
# /transformer/decoder
fuse_one_attention(om, f"/transformer/decoder/layers.{i}/self_attn/Softmax", new_mask, num_heads=8)
fuse_one_attention(om, f"/transformer/decoder/layers.{i}/ca_text/Softmax", new_mask, num_heads=8)
fuse_one_attention(om, f"/transformer/decoder/layers.{i}/self_attn/Softmax", padding_mask=padding_mask, num_heads=8)
fuse_one_attention(om, f"/transformer/decoder/layers.{i}/ca_text/Softmax", padding_mask=padding_mask, num_heads=8)
om.update_map()
......@@ -437,25 +504,32 @@ def optimize_normal_attention(om: ONNXModifier):
def optimize_backbone_attention(om: ONNXModifier):
def get_original_mask(mask_name, name_prefix):
mask_value = om.get_initializer_value(mask_name)
orig_mask = np.where(mask_value==0, 1, 0).astype(np.bool_)
orig_mask_init = om.create_initializer(f"{name_prefix}/mask", orig_mask)
return orig_mask_init.name
# mask_value = mask_value.astype(np.int64)
assert mask_value.ndim == 5 and mask_value.shape[2] == 1
mask_value = mask_value.reshape(mask_value.shape[0] * mask_value.shape[1], mask_value.shape[3], mask_value.shape[4])
# = np.where(mask_value==0, 1, 0).astype(np.bool_)
new_mask_init = om.create_initializer(f"{name_prefix}/mask", mask_value)
return new_mask_init.name
def _fuse_one_attention(softmax_name: str):
name_prefix = '/'.join(softmax_name.split('/')[:-1])
def _fuse_one_attention_with_bias(softmax_name: str):
name_prefix = '/'.join(softmax_name.split('/')[:-2])
softmax_node = om.get_node(softmax_name)
tmp_node = om.get_prev_nodes(softmax_node)[0]
pos_bias_init = None
mask = None
if tmp_node.op_type == "Reshape":
tmp_node = om.get_prev_nodes(tmp_node)[0]
assert tmp_node.op_type == "Add"
pos_bias_init = om.get_initializer(tmp_node.inputs[1])
mask = get_original_mask(tmp_node.inputs[1], name_prefix)
tmp_node = om.get_prev_nodes(tmp_node)[0]
assert tmp_node.op_type == "Reshape"
tmp_node = om.get_prev_nodes(tmp_node)[0]
assert tmp_node.op_type == "Add"
mask = get_original_mask(tmp_node.inputs[1], name_prefix)
pos_bias_init_value = om.get_initializer_value(tmp_node.inputs[1])
assert pos_bias_init_value.shape[0] == 1
om.set_initializer_value(tmp_node.inputs[1], pos_bias_init_value.squeeze(axis=0))
pos_bias_init = tmp_node.inputs[1]
tmp_node = om.get_prev_nodes(tmp_node)[0]
assert tmp_node.op_type == "MatMul"
......@@ -486,26 +560,27 @@ def optimize_backbone_attention(om: ONNXModifier):
assert tmp_node.op_type == "Transpose"
tmp_node = om.get_next_nodes(tmp_node)[0]
assert tmp_node.op_type == "Reshape"
mha_out = tmp_node.outputs[0]
# mha_out = tmp_node.outputs[0]
mha_out = tmp_node.inputs[0]
old_dst_shape = om.get_initializer_value(reshape_node.inputs[1])
b, s, _, h, d = old_dst_shape
new_dst_shape = [b, s, _, h*d]
new_dst_shape_init = om.create_initializer(f"{name_prefix}/qkv_hidden_states_shape",
np.array(new_dst_shape, np.int64))
reshape_node.set_input(1, new_dst_shape_init.name)
# new_dst_shape = [b, s, _, h*d]
# new_dst_shape_init = om.create_initializer(f"{name_prefix}/qkv_hidden_states_shape",
# np.array(new_dst_shape, np.int64))
# reshape_node.set_input(1, new_dst_shape_init.name)
for node in [q_gather_node, k_gather_node, v_gather_node]:
node.set_input(0, reshape_node.outputs[0])
node.set_attribute("axis", 2)
mha_name = f"{name_prefix}/MultiHeadAttention"
mha_name = f"{name_prefix}/MultiHeadAttentionWithAttnMask"
inputs = [q_gather_node.outputs[0],
k_gather_node.outputs[0],
v_gather_node.outputs[0],
mask]
if pos_bias_init is not None:
inputs.append(pos_bias_init.name)
mha_node = om.create_node("MultiHeadAttention",
pos_bias_init]
if mask is not None:
inputs.append(mask)
mha_node = om.create_node("MultiHeadAttentionWithAttnMask",
mha_name,
inputs,
[mha_name+'_output_0'],
......@@ -519,7 +594,7 @@ def optimize_backbone_attention(om: ONNXModifier):
for l in range(num_layers):
num_blocks = 18 if l == 2 else 2
for b in range(num_blocks):
_fuse_one_attention(f"/backbone/backbone.0/layers.{l}/blocks.{b}/attn/softmax/Softmax")
_fuse_one_attention_with_bias(f"/backbone/backbone.0/layers.{l}/blocks.{b}/attn/softmax/Softmax")
def optimize_ms_deform_attn(om: ONNXModifier):
......@@ -614,7 +689,7 @@ def main():
input_onnx_path = sys.argv[1]
output_onnx_path = sys.argv[2]
# input_onnx_path = "ground_sim.onnx"
# output_onnx_path = "ground_sim_0520_new.onnx"
# output_onnx_path = "ground_sim_0529.onnx"
om = ONNXModifier(input_onnx_path)
optimize_where_ndoes(om) # 1. 替换where节点
......@@ -622,9 +697,9 @@ def main():
optmize_sin_cos_block(om) # 3. 优化位置编码
om.add_opset_import("com.microsoft", 1)
optimize_normal_attention(om) # 4. 融合bert、transformer中的mha
# optimize_backbone_attention(om) # 5. 融合backbone中的注意力
optimize_backbone_attention(om) # 5. 融合backbone中的注意力
optimize_ms_deform_attn(om) # 6. 融合多尺度可变形注意力
optimize_bidirect_attention(om) # 7. 优化双向注意力
# optimize_bidirect_attention(om) # 7. 优化双向注意力
om.save(output_onnx_path, save_as_external_data=False)
......
......@@ -240,7 +240,7 @@ def benchmark_performance(
if __name__ == '__main__':
# 配置参数
model_path = 'weights_400x600/ground.onnx'
model_path = 'weights/ground.onnx'
img_path = 'images/in/car_1.jpg'
TEXT_PROMPT = "car ."
BOX_TRESHOLD = 0.35
......
......@@ -38,7 +38,8 @@ install_torch()
import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
# from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME, CppExtension, CUDAExtension
# groundingdino version info
version = "0.1.0"
......@@ -82,7 +83,26 @@ def get_extensions():
extra_compile_args = {"cxx": []}
define_macros = []
if CUDA_HOME is not None and (torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ):
# if CUDA_HOME is not None and (torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ):
cuda_home = CUDA_HOME or os.environ.get("CUDA_HOME") or ROCM_HOME or os.environ.get("ROCM_HOME") or os.environ.get("HIP_HOME")
print("DEBUG imported CUDA_HOME:", CUDA_HOME)
print("DEBUG imported ROCM_HOME:", ROCM_HOME)
print("DEBUG env CUDA_HOME:", os.environ.get("CUDA_HOME"))
print("DEBUG env ROCM_HOME:", os.environ.get("ROCM_HOME"))
print("DEBUG env HIP_HOME:", os.environ.get("HIP_HOME"))
print("DEBUG cuda_home used:", cuda_home)
print("DEBUG torch.cuda.is_available:", torch.cuda.is_available())
print("DEBUG TORCH_CUDA_ARCH_LIST:", os.environ.get("TORCH_CUDA_ARCH_LIST"))
print("DEBUG FORCE_CUDA:", os.environ.get("FORCE_CUDA"))
if cuda_home is not None and (
torch.cuda.is_available()
or "TORCH_CUDA_ARCH_LIST" in os.environ
or os.getenv("FORCE_CUDA", "0") == "1"
):
print("Compiling with CUDA")
extension = CUDAExtension
sources += source_cuda
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment