Commit 39a85c88 authored by zk's avatar zk
Browse files

新增migraphx脚本推理

parent a1865640
...@@ -216,6 +216,16 @@ bash migraphx_export.bash ...@@ -216,6 +216,16 @@ bash migraphx_export.bash
bash migraphx_perf.bash bash migraphx_perf.bash
``` ```
4. 使用python脚本测试
```bash
python migraphx_infer.py
# offload=False推理,提前开辟gpu空间,数据放在device推理
python migraphx_infer1.py
# offload=True推理,会慢一些
```
----- -----
## 8\. 测试结果对比 ## 8\. 测试结果对比
...@@ -252,7 +262,8 @@ bash migraphx_perf.bash ...@@ -252,7 +262,8 @@ bash migraphx_perf.bash
| **ORT + Plugin** | +自定义算子<br>+FP16 纯量化方案 B | `ground_deform_fp16_all.onnx` | `ort_plugin_fp16_B` | 105.35 | 9.49 | | **ORT + Plugin** | +自定义算子<br>+FP16 纯量化方案 B | `ground_deform_fp16_all.onnx` | `ort_plugin_fp16_B` | 105.35 | 9.49 |
| **ORT + Plugin** | +自定义算子<br>+FP16 极致优化方案 C | `ground_deform_fp16_all.onnx` | `ort_plugin_fp16_C` | 100.91 | 9.90 | | **ORT + Plugin** | +自定义算子<br>+FP16 极致优化方案 C | `ground_deform_fp16_all.onnx` | `ort_plugin_fp16_C` | 100.91 | 9.90 |
### 8.3 migraphx BW100 测试结果 ### 8.3 migraphx BW150和BW100 测试结果
BW100示例结果:
``` ```
Batch size: 1 Batch size: 1
Rate: 6.05197 inferences/sec Rate: 6.05197 inferences/sec
...@@ -263,6 +274,15 @@ Total instructions time: 205.275ms ...@@ -263,6 +274,15 @@ Total instructions time: 205.275ms
Overhead time: 2.32812ms, -40.0399ms Overhead time: 2.32812ms, -40.0399ms
Overhead: 1%, -24% Overhead: 1%, -24%
``` ```
汇总结果
| 设备 | 推理方式 | FPS | 平均推理时间 (ms) |
| :--- | :--- | :--- | :--- |
| BW150 | migraphx-driver | 14.93 | 66.97 |
| BW150 | Python + MIGraphX(device) | 13.65 | 73.20(包含前后处理) |
| BW100 | migraphx-driver | 13.54 | 73.87 |
| BW100 | Python + MIGraphX(device) | 12.12 | 82.44(包含前后处理) |
----- -----
## 参考项目 ## 参考项目
......
...@@ -214,6 +214,7 @@ if __name__ == '__main__': ...@@ -214,6 +214,7 @@ if __name__ == '__main__':
image_source, image = load_image(img_path) image_source, image = load_image(img_path)
providers = [ providers = [
# 'MIGraphXExecutionProvider',
'ROCMExecutionProvider', 'ROCMExecutionProvider',
'CPUExecutionProvider' 'CPUExecutionProvider'
] ]
......
export MIGRAPHX_ENABLE_MIOPEN_CONCAT=1 export MIGRAPHX_TRACE_COMPILE=1
migraphx-driver perf --onnx \ migraphx-driver perf --onnx \
../weights/ground_opt.onnx \ ../weights/ground_opt_0430.onnx \
--fp16 \ --fp16 \
--output \ --output \
../weights/ground_opt.mxr ../weights/ground_opt_0430.mxr
\ No newline at end of file
# ../weights/ground_opt_0430.mxr > migraphx_log.log 2>&1
\ No newline at end of file
...@@ -3,203 +3,208 @@ import numpy as np ...@@ -3,203 +3,208 @@ import numpy as np
import torch import torch
import time import time
import os import os
import bisect
import migraphx import migraphx
from typing import Tuple, List, Dict
from transformers import BertTokenizer import groundingdino.datasets.transforms as T
from groundingdino.util.inference import load_image from PIL import Image
from groundingdino.models.GroundingDINO.bertwarper import generate_masks_with_special_tokens_and_transfer_map
# ========================= # =========================
# 工具函数 # 预处理
# ========================= # =========================
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
]
)
image_source = Image.open(image_path).convert("RGB")
image = np.asarray(image_source)
image_transformed, _ = transform(image_source, None)
return image, image_transformed
def sigmoid(x): def sigmoid(x):
return 1 / (1 + np.exp(-x)) return 1 / (1 + np.exp(-x))
def preprocess_caption(caption: str) -> str:
result = caption.lower().strip() # =========================
if result.endswith("."): # 文本标签还原逻辑 (移除 Tokenizer 依赖)
return result # =========================
return result + "." def get_phrases_from_posmap(
posmap: np.ndarray, tokens: List[str], left_idx: int = 0, right_idx: int = 255
def to_mgx(x): ):
if x.dtype == np.int64: """
return migraphx.argument(x.astype(np.int64)) 直接用字符串列表映射,抛弃沉重的 Tokenizer
elif x.dtype == np.bool_: """
return migraphx.argument(x.astype(np.bool_)) assert isinstance(posmap, np.ndarray), "posmap must be np.ndarray"
else: if posmap.ndim == 1:
return migraphx.argument(x.astype(np.float32)) # 将指定范围内的元素设为 False
posmap[:left_idx + 1] = False
posmap[right_idx:] = False
def _mgx_shape_to_numpy(shape):
# 将 migraphx input shape 映射到 numpy dtype + lens 以生成零填充张量 # 获取非零元素的索引
shape_str = str(shape) non_zero_idx = np.nonzero(posmap)[0]
if "int64_type" in shape_str: # 提取被激活的单词,并自动过滤掉特殊占位符
dtype = np.int64 words = [tokens[i] for i in non_zero_idx if tokens[i] not in ["[CLS]", "[SEP]", "."]]
elif "bool_type" in shape_str: return " ".join(words).strip()
dtype = np.bool_
elif "half_type" in shape_str:
dtype = np.float16
else: else:
dtype = np.float32 raise NotImplementedError("posmap must be 1-dim")
try:
dims = list(shape.dims())
except Exception: # =========================
dims = [] # 分配输出 GPU 内存 (offload_copy=False 必须)
try: # =========================
lens = list(shape.lens()) def allocate_output_memory(model):
except Exception: output_data = {}
lens = [] for key in model.get_outputs().keys():
# 优先用 dims,dims 为空时才退化到 lens output_data[key] = migraphx.allocate_gpu(
return dtype, (dims if len(dims) > 0 else lens) s=model.get_outputs()[key]
)
return output_data
# ========================= # =========================
# 🚀 MIGraphX 推理类(带缓存) # MIGraphX 模型类
# ========================= # =========================
class MIGraphXModel: class MIGraphXModel:
def __init__(self, onnx_path, cache_path="weights/ground_opt.mxr", force_recompile=False): def __init__(self,
onnx_path,
cache_path="../weights/ground_opt_0430.mxr",
device_id=3,
force_recompile=False):
self.cache_path = cache_path self.cache_path = cache_path
# ====== 优先加载缓存 ======
if os.path.exists(cache_path) and not force_recompile: if os.path.exists(cache_path) and not force_recompile:
print(f"⚡ 直接加载已编译模型: {cache_path}") print(f"⚡ 直接加载缓存模型: {cache_path}")
self.model = migraphx.load(cache_path) self.model = migraphx.load(cache_path)
else: else:
print("🔍 从 ONNX 构建 MIGraphX") print("🔍 从 ONNX 构建模型")
self.model = migraphx.parse_onnx(onnx_path) self.model = migraphx.parse_onnx(onnx_path)
print(self.model)
print("\n=== 输入信息 ===")
# ====================== 2. 打印模型输入输出信息 ====================== for k, v in self.model.get_inputs().items():
print("=== 模型输入信息 ===") print(f"{k}: {v}")
inputs = self.model.get_inputs()
for key, value in inputs.items(): print("\n=== 输出信息 ===")
print(f"{key}: {value}") for k, v in self.model.get_outputs().items():
print(f"{k}: {v}")
print("\n=== 模型输出信息 ===")
outputs = self.model.get_outputs() print("\n⚙️ 编译模型(GPU + offload=false)")
for key, value in outputs.items():
print(f"{key}: {value}")
"""
=== 模型输入信息 ===
text_token_mask: bool_type, {1, 4, 4}, {16, 4, 1}
token_type_ids: int64_type, {1, 4}, {4, 1}
position_ids: int64_type, {1, 4}, {4, 1}
attention_mask: bool_type, {1, 4}, {4, 1}
input_ids: int64_type, {1, 4}, {4, 1}
img: float_type, {1, 3, 800, 1200}, {2880000, 960000, 1200, 1}
=== 模型输出信息 ===
boxes: float_type, {1, 900, 4}, {3600, 4, 1}
logits: float_type, {1, 900, 256}, {230400, 256, 1}
输入节点名称: text_token_mask
输入形状 (N, C, H, W): [1, 4, 4]
"""
# print("\n⚡ 量化模型(FP16)")
# migraphx.quantize_fp16(self.model)
print("⚙️ 编译 MIGraphX(GPU)")
self.model.compile( self.model.compile(
t=migraphx.get_target("gpu"),device_id=5 t=migraphx.get_target("gpu"),
offload_copy=False,
device_id=device_id
) )
# offload_copy=False, fast_math=False, exhaustive_tune=False
# ====== 保存缓存 ====== print(f"💾 保存 mxr: {cache_path}")
print(f"💾 保存编译模型到: {cache_path}")
migraphx.save(self.model, cache_path) migraphx.save(self.model, cache_path)
self.inputs = self.model.get_inputs()
self.outputs = self.model.get_outputs()
self.param_names = self.model.get_parameter_names() self.param_names = self.model.get_parameter_names()
self.input_shapes = self.model.get_inputs()
print("✅ param_names:", self.param_names) print("✅ param_names:", self.param_names)
print("✅ input_shape:", self.input_shapes) print("✅ input_shape:", self.inputs)
try: print("✅ output_shapes keys:", list(self.outputs.keys()))
self.output_shapes = self.model.get_outputs()
print("✅ output_shapes keys:", list(self.output_shapes.keys())) self.output_gpu = allocate_output_memory(self.model)
except Exception: print("✅ 模型初始化完成")
self.output_shapes = None
def infer(self, input_dict): def infer(self, input_dict):
# 只按模型 get_inputs() 定义的输入签名来组装 mgx_data = self.output_gpu.copy()
mgx_inputs = {}
provided_names = set(input_dict.keys()) for name in self.inputs.keys():
# 某些 mxr 会把内部输出别名也暴露到 get_parameter_names/get_inputs 里, data = input_dict[name]
# 这里显式排除 main:#output_*,避免把内部输出当成输入填充。 if data.dtype == np.float64:
required_names = { data = data.astype(np.float32)
k for k in self.input_shapes.keys() mgx_data[name] = migraphx.to_gpu(migraphx.argument(data))
if not str(k).startswith("main:#output")
}
missing = required_names - provided_names
if missing:
print("⚠️ 缺失模型输入,准备按 shape 自动补齐:")
for name in sorted(missing):
shape = self.input_shapes[name]
dtype, lens = _mgx_shape_to_numpy(shape)
mgx_inputs[name] = to_mgx(np.zeros(lens, dtype=dtype))
print(f" - {name}: shape={lens}, dtype={dtype.__name__}")
for name in (required_names & provided_names):
mgx_inputs[name] = to_mgx(input_dict[name])
# 额外的 key 不喂给模型,避免和内部签名冲突
extra = provided_names - required_names
if extra:
print("ℹ️ 有多余输入参数将被忽略:")
for name in sorted(extra):
print(f" - {name}")
start = time.time() start = time.time()
result = self.model.run(mgx_inputs) results = self.model.run(mgx_data)
infer_time = time.time() - start infer_time = time.time() - start
outputs = [np.array(r) for r in result] outputs = [
np.array(migraphx.from_gpu(r))
for r in results
]
return outputs, infer_time return outputs, infer_time
# ========================= # =========================
# 推理函数 # 推理逻辑 (引入真正的后处理还原)
# ========================= # =========================
def predict( def predict(
model, model,
tokenizer,
image, image,
caption, text_cache,
box_threshold, box_threshold,
text_threshold, text_threshold,
remove_combined=False,
is_benchmark=False is_benchmark=False
): ) -> Tuple[np.ndarray, np.ndarray, List[str]]:
# 提前针对car .生成对应输入 # 使用传入的 text_cache 替代硬编码
input_dict = { input_dict = {
"img": np.expand_dims(np.asarray(image), axis=0).astype(np.float32), "img": np.expand_dims(np.asarray(image), axis=0).astype(np.float32),
"position_ids": np.array([[0, 0, 1, 0]], dtype=np.int64), "input_ids": text_cache['input_ids'],
"input_ids": np.array([[101, 2482, 1012, 102]], dtype=np.int64), "attention_mask": text_cache['attention_mask'],
"token_type_ids": np.array([[0, 0, 0, 0]], dtype=np.int64), "position_ids": text_cache['position_ids'],
"text_token_mask": np.array([[ "token_type_ids": text_cache['token_type_ids'],
[True, False, False, False], "text_token_mask": text_cache['text_token_mask']
[False, True, True, False],
[False, True, True, False],
[False, False, False, True]
]], dtype=np.bool_),
"attention_mask": np.array([[True, True, True, True]], dtype=np.bool_)
} }
outputs, infer_time = model.infer(input_dict) outputs, infer_time = model.infer(input_dict)
if not is_benchmark: if not is_benchmark:
print(f"Inference time: {infer_time*1000:.2f} ms") print(f"Inference time: {infer_time:.3f}s")
logits = sigmoid(outputs[0][0]) t0 = time.time()
boxes = outputs[1][0] prediction_logits = sigmoid(outputs[0][0])
prediction_boxes = outputs[1][0]
post_time = time.time() - t0
max_values = np.max(logits, axis=1) if not is_benchmark:
print(f"post time: {post_time:.3f}s")
print(f"\n=== Debug Info ===")
print(f"Prediction logits shape: {prediction_logits.shape}")
print(f"Prediction boxes shape: {prediction_boxes.shape}")
print(f"Max logit value: {np.max(prediction_logits):.4f}")
print(f"Mean logit value: {np.mean(prediction_logits):.4f}")
# 1. 框过滤
max_values = np.max(prediction_logits, axis=1)
mask = max_values > box_threshold mask = max_values > box_threshold
logits = logits[mask] logits = prediction_logits[mask]
boxes = boxes[mask] boxes = prediction_boxes[mask]
phrases = ["object"] * len(boxes) tokens = text_cache['tokens']
input_ids = text_cache['input_ids'][0].tolist()
if remove_combined:
sep_idx = [i for i in range(len(input_ids)) if input_ids[i] in [101, 102, 1012]]
phrases = []
for logit in logits:
max_idx = logit.argmax()
insert_idx = bisect.bisect_left(sep_idx, max_idx)
right_idx = sep_idx[insert_idx]
left_idx = sep_idx[insert_idx - 1]
phrases.append(
get_phrases_from_posmap(logit > text_threshold, tokens, left_idx, right_idx)
)
else:
phrases = [
get_phrases_from_posmap(logit > text_threshold, tokens)
for logit in logits
]
return boxes, np.max(logits, axis=1), phrases return boxes, np.max(logits, axis=1), phrases
...@@ -207,20 +212,62 @@ def predict( ...@@ -207,20 +212,62 @@ def predict(
# ========================= # =========================
# Benchmark # Benchmark
# ========================= # =========================
def benchmark(model, tokenizer, image, caption, box_th, text_th, warmup=5, runs=10): def benchmark_performance(
print("\n🔥 预热") model, image, text_cache, box_threshold, text_threshold,
for _ in range(warmup): warmup_runs=5, test_runs=10
predict(model, tokenizer, image, caption, box_th, text_th, True) ):
print("="*60)
print("\n🚀 测试") print("📊 开始性能测试(包含预热+实际推理)")
times = [] print("="*60)
for i in range(runs):
start = time.time() print(f"\n🔥 预热阶段({warmup_runs} 次)- 不计入性能统计")
predict(model, tokenizer, image, caption, box_th, text_th, True) warmup_start = time.time()
times.append(time.time() - start) for i in range(warmup_runs):
t0 = time.time()
print(f"\n平均耗时: {np.mean(times)*1000:.2f} ms") predict(model, image, text_cache, box_threshold, text_threshold, is_benchmark=True)
print(f"FPS: {1/np.mean(times):.2f}") warmup_time = time.time() - t0
print(f"预热 {i+1}/{warmup_runs} - 耗时: {warmup_time*1000:.2f} ms")
total_warmup_time = time.time() - warmup_start
print(f"\n预热完成 - 总耗时: {total_warmup_time:.3f} s, 平均每次: {total_warmup_time/warmup_runs*1000:.2f} ms")
print(f"\n🚀 实际推理测试阶段({test_runs} 次)- 统计性能指标")
test_start = time.time()
infer_times = []
for i in range(test_runs):
t0 = time.time()
predict(model, image, text_cache, box_threshold, text_threshold, is_benchmark=True)
infer_time = time.time() - t0
infer_times.append(infer_time)
print(f"实际推理 {i+1}/{test_runs} - 耗时: {infer_time*1000:.2f} ms")
total_test_time = time.time() - test_start
avg_infer_time = np.mean(infer_times)
std_infer_time = np.std(infer_times)
max_infer_time = np.max(infer_times)
min_infer_time = np.min(infer_times)
fps = test_runs / total_test_time
print("\n" + "="*60)
print("📈 性能测试报告(仅实际推理阶段)")
print("="*60)
print(f"测试次数: {test_runs} 次")
print(f"总推理耗时: {total_test_time:.3f} s")
print(f"平均推理耗时: {avg_infer_time*1000:.2f} ms (±{std_infer_time*1000:.2f} ms)")
print(f"最大推理耗时: {max_infer_time*1000:.2f} ms")
print(f"最小推理耗时: {min_infer_time*1000:.2f} ms")
print(f"平均FPS: {fps:.2f} 帧/秒")
print("="*60)
return {
"warmup_runs": warmup_runs,
"test_runs": test_runs,
"avg_infer_time_ms": avg_infer_time*1000,
"std_infer_time_ms": std_infer_time*1000,
"max_infer_time_ms": max_infer_time*1000,
"min_infer_time_ms": min_infer_time*1000,
"fps": fps
}
# ========================= # =========================
...@@ -228,31 +275,84 @@ def benchmark(model, tokenizer, image, caption, box_th, text_th, warmup=5, runs= ...@@ -228,31 +275,84 @@ def benchmark(model, tokenizer, image, caption, box_th, text_th, warmup=5, runs=
# ========================= # =========================
if __name__ == "__main__": if __name__ == "__main__":
model_path = "../weights/ground_opt.onnx" model_path = "../weights/ground_opt_0430.onnx"
cache_path = "../weights/ground_opt.mxr" # ⭐ 缓存文件 cache_path = "../weights/ground_opt_0430.mxr"
img_path = "../images/in/car_1.jpg" img_path = "../images/in/car_1.jpg"
TEXT_PROMPT = "car ."
BOX_TRESHOLD = 0.35 BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25 TEXT_TRESHOLD = 0.25
WARMUP_RUNS = 5
TEST_RUNS = 10
# 🚀 加载模型(自动缓存)
model = MIGraphXModel( model = MIGraphXModel(
model_path, model_path,
cache_path=cache_path, cache_path=cache_path,
force_recompile=False # 改成 True 可强制重编译 device_id=5,
force_recompile=False
) )
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
image_source, image = load_image(img_path) image_source, image = load_image(img_path)
benchmark(model, tokenizer, image, TEXT_PROMPT, BOX_TRESHOLD, TEXT_TRESHOLD) # =========================
# 提前计算得到的 Text Cache
# =========================
TEXT_CACHE = {
'input_ids': np.array([[ 101, 2482, 1012, 102]], dtype=np.int64),
'attention_mask': np.array([[ True, True, True, True]], dtype=np.bool_),
'position_ids': np.array([[0, 0, 1, 0]], dtype=np.int64),
'token_type_ids': np.array([[0, 0, 0, 0]], dtype=np.int64),
'text_token_mask': np.array([[[ True, False, False, False],
[False, True, True, False],
[False, True, True, False],
[False, False, False, True]]], dtype=np.bool_),
# 存放 ID 对应的单词,用于快速 decode
'tokens': ["[CLS]", "car", ".", "[SEP]"]
}
benchmark_performance(
model, image, TEXT_CACHE,
BOX_TRESHOLD, TEXT_TRESHOLD,
WARMUP_RUNS, TEST_RUNS
)
print("\n" + "="*60)
print("🎯 执行最终推理(带详细日志+保存结果)")
print("="*60)
# 传入 TEXT_CACHE
boxes, confs, phrases = predict( boxes, confs, phrases = predict(
model, tokenizer, image, model, image, TEXT_CACHE,
TEXT_PROMPT, BOX_TRESHOLD, TEXT_TRESHOLD BOX_TRESHOLD, TEXT_TRESHOLD
) )
print("检测结果:", phrases) print("\n🎯 执行最终推理并保存结果图")
\ No newline at end of file ori_img = cv2.imread(img_path)
img_h = ori_img.shape[0]
img_w = ori_img.shape[1]
for i in range(len(boxes)):
one_box = boxes[i]
one_conf = confs[i]
one_cls = phrases[i]
x1 = int((one_box[0] - one_box[2] / 2) * img_w)
y1 = int((one_box[1] - one_box[3] / 2) * img_h)
x2 = int((one_box[0] + one_box[2] / 2) * img_w)
y2 = int((one_box[1] + one_box[3] / 2) * img_h)
cv2.rectangle(ori_img, (x1, y1), (x2, y2), (0, 0, 255), 2)
# 此时打印的 one_cls 将是真实的类别名称(如 "car")
cv2.putText(
ori_img, f'{one_cls} {one_conf:.2f}',
(x1-15, y1-15),
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
color=(255, 255, 255),
fontScale=1.5,
thickness=3
)
cv2.imwrite('../weights/result_migraphx.jpg', ori_img)
print(f"\n✅ 结果已保存至: ../weights/result_migraphx.jpg")
print(f"✅ 检测到目标: {phrases} (共 {len(boxes)} 个)")
\ No newline at end of file
import cv2 import cv2
import numpy as np import numpy as np
import torch
import time import time
import os import os
import migraphx import migraphx
from typing import Tuple from typing import Tuple
import torch
import groundingdino.datasets.transforms as T import groundingdino.datasets.transforms as T
from PIL import Image from PIL import Image
"""
使用cpu数据做推理
"""
def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]: def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
transform = T.Compose( transform = T.Compose(
[ [
...@@ -25,7 +29,43 @@ def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]: ...@@ -25,7 +29,43 @@ def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
def sigmoid(x): def sigmoid(x):
return 1 / (1 + np.exp(-x)) return 1 / (1 + np.exp(-x))
def get_phrases_from_posmap(
posmap: np.ndarray, tokens: List[str], left_idx: int = 0, right_idx: int = 255
):
"""
【核心优化】直接用字符串列表映射,抛弃沉重的 Tokenizer
"""
assert isinstance(posmap, np.ndarray), "posmap must be np.ndarray"
if posmap.ndim == 1:
# 将指定范围内的元素设为 False
posmap[:left_idx + 1] = False
posmap[right_idx:] = False
# 获取非零元素的索引
non_zero_idx = np.nonzero(posmap)[0]
# 提取被激活的单词,并自动过滤掉特殊占位符
words = [tokens[i] for i in non_zero_idx if tokens[i] not in ["[CLS]", "[SEP]", "."]]
return " ".join(words).strip()
else:
raise NotImplementedError("posmap must be 1-dim")
def preprocess_caption(caption: str) -> str:
result = caption.lower().strip()
if result.endswith("."):
return result
return result + "."
def to_mgx(x):
if x.dtype == np.int64:
return migraphx.argument(x.astype(np.int64))
elif x.dtype == np.bool_:
return migraphx.argument(x.astype(np.bool_))
else:
return migraphx.argument(x.astype(np.float32))
def _mgx_shape_to_numpy(shape): def _mgx_shape_to_numpy(shape):
# 将 migraphx input shape 映射到 numpy dtype + lens 以生成零填充张量
shape_str = str(shape) shape_str = str(shape)
if "int64_type" in shape_str: if "int64_type" in shape_str:
dtype = np.int64 dtype = np.int64
...@@ -43,200 +83,304 @@ def _mgx_shape_to_numpy(shape): ...@@ -43,200 +83,304 @@ def _mgx_shape_to_numpy(shape):
lens = list(shape.lens()) lens = list(shape.lens())
except Exception: except Exception:
lens = [] lens = []
# 优先用 dims,dims 为空时才退化到 lens
return dtype, (dims if len(dims) > 0 else lens) return dtype, (dims if len(dims) > 0 else lens)
# ========================= # =========================
# 🚀 MIGraphX 推理类(带缓存与生命周期管理 # 🚀 MIGraphX 推理类(带缓存)
# ========================= # =========================
class MIGraphXModel: class MIGraphXModel:
def __init__(self, onnx_path, cache_path="weights/ground_opt.mxr", force_recompile=False, device_id=0): def __init__(self, onnx_path, cache_path="../weights/ground_opt_0506.mxr", force_recompile=False):
self.cache_path = cache_path self.cache_path = cache_path
# ====== 优先加载缓存 ======
if os.path.exists(cache_path) and not force_recompile: if os.path.exists(cache_path) and not force_recompile:
print(f"⚡ 直接加载已编译模型: {cache_path}") print(f"⚡ 直接加载已编译模型: {cache_path}")
self.model = migraphx.load(cache_path) self.model = migraphx.load(cache_path)
else: else:
print("🔍 从 ONNX 构建 MIGraphX") print("🔍 从 ONNX 构建 MIGraphX")
self.model = migraphx.parse_onnx(onnx_path) self.model = migraphx.parse_onnx(onnx_path)
# print(self.model)
# ====================== 2. 打印模型输入输出信息 ======================
print("=== 模型输入信息 ===")
inputs = self.model.get_inputs()
for key, value in inputs.items():
print(f"{key}: {value}")
print(f"⚙️ 编译 MIGraphX(GPU {device_id})") print("\n=== 模型输出信息 ===")
self.model.compile(t=migraphx.get_target("gpu"), device_id=device_id) outputs = self.model.get_outputs()
for key, value in outputs.items():
print(f"{key}: {value}")
print("⚙️ 编译 MIGraphX(GPU)")
self.model.compile(
t=migraphx.get_target("gpu"), device_id=3, offload_copy=True
)
# ====== 保存缓存 ======
print(f"💾 保存编译模型到: {cache_path}") print(f"💾 保存编译模型到: {cache_path}")
migraphx.save(self.model, cache_path) migraphx.save(self.model, cache_path)
self.param_names = self.model.get_parameter_names()
self.input_shapes = self.model.get_inputs() self.input_shapes = self.model.get_inputs()
print("✅ param_names:", self.param_names)
print("✅ input_shape:", self.input_shapes)
try:
self.output_shapes = self.model.get_outputs()
print("✅ output_shapes keys:", list(self.output_shapes.keys()))
except Exception:
self.output_shapes = None
def infer(self, input_dict): def infer(self, input_dict):
# 只按模型 get_inputs() 定义的输入签名来组装
mgx_inputs = {} mgx_inputs = {}
# 【关键修复区】:用于保持 NumPy 数组存活,防止 Python 垃圾回收导致底层指针失效
self._keep_alive_cache = {}
provided_names = set(input_dict.keys()) provided_names = set(input_dict.keys())
# 某些 mxr 会把内部输出别名也暴露到 get_parameter_names/get_inputs 里,
# 这里显式排除 main:#output_*,避免把内部输出当成输入填充。
required_names = { required_names = {
k for k in self.input_shapes.keys() k for k in self.input_shapes.keys()
if not str(k).startswith("main:#output") if not str(k).startswith("main:#output")
} }
for name in required_names: missing = required_names - provided_names
shape = self.input_shapes[name] if missing:
target_dtype, lens = _mgx_shape_to_numpy(shape) print("⚠️ 缺失模型输入,准备按 shape 自动补齐:")
for name in sorted(missing):
if name in provided_names: shape = self.input_shapes[name]
# 1. 必须转为连续内存!防止 PyTorch 转过来的 array 内存步长不一致 dtype, lens = _mgx_shape_to_numpy(shape)
arr = np.ascontiguousarray(input_dict[name]) mgx_inputs[name] = to_mgx(np.zeros(lens, dtype=dtype))
# 2. 强制类型转换 print(f" - {name}: shape={lens}, dtype={dtype.__name__}")
if arr.dtype != target_dtype:
arr = arr.astype(target_dtype) for name in (required_names & provided_names):
else: mgx_inputs[name] = to_mgx(input_dict[name])
# 缺失的输入用 0 补齐
arr = np.zeros(lens, dtype=target_dtype) # 额外的 key 不喂给模型,避免和内部签名冲突
extra = provided_names - required_names
# 3. 将数组塞进字典,强行续命! if extra:
self._keep_alive_cache[name] = arr print("ℹ️ 有多余输入参数将被忽略:")
for name in sorted(extra):
# 4. 安全地将指针移交给 migraphx print(f" - {name}")
mgx_inputs[name] = migraphx.argument(arr)
start = time.time() start = time.time()
result = self.model.run(mgx_inputs) result = self.model.run(mgx_inputs)
infer_time = time.time() - start infer_time = time.time() - start
outputs = [np.array(r) for r in result] outputs = [np.array(r) for r in result]
# 推理结束,释放内存
self._keep_alive_cache.clear()
return outputs, infer_time return outputs, infer_time
# ========================= # =========================
# 推理函数 (硬编码输入,无 Tokenizer) # 推理函数
# ========================= # =========================
def predict(model, image, box_threshold, is_benchmark=False): def predict(
model,
image,
caption,
box_threshold,
text_threshold,
is_benchmark=False
):
# 提前针对car .生成对应输入
input_dict = { input_dict = {
"img": np.expand_dims(np.asarray(image), axis=0), "img": np.expand_dims(np.asarray(image), axis=0).astype(np.float32),
"position_ids": np.array([[0, 0, 1, 0]]), "position_ids": np.array([[0, 0, 1, 0]], dtype=np.int64),
"input_ids": np.array([[101, 2482, 1012, 102]]), "input_ids": np.array([[101, 2482, 1012, 102]], dtype=np.int64),
"token_type_ids": np.array([[0, 0, 0, 0]]), "token_type_ids": np.array([[0, 0, 0, 0]], dtype=np.int64),
"text_token_mask": np.array([[ "text_token_mask": np.array([[
[True, False, False, False], [True, False, False, False],
[False, True, True, False], [False, True, True, False],
[False, True, True, False], [False, True, True, False],
[False, False, False, True] [False, False, False, True]
]]), ]], dtype=np.bool_),
"attention_mask": np.array([[True, True, True, True]]) "attention_mask": np.array([[True, True, True, True]], dtype=np.bool_)
} }
outputs, infer_time = model.infer(input_dict) outputs, infer_time = model.infer(input_dict)
if not is_benchmark: if not is_benchmark:
print(f"Inference time: {infer_time*1000:.2f} ms") print(f"Inference time: {infer_time:.3f}s")
logits = sigmoid(outputs[0][0]) t0 = time.time()
boxes = outputs[1][0] prediction_logits = sigmoid(outputs[0][0])
prediction_boxes = outputs[1][0]
post_time = time.time() - t0
max_values = np.max(logits, axis=1) if not is_benchmark:
print(f"post time: {post_time:.3f}s")
print(f"\n=== Debug Info ===")
print(f"Prediction logits shape: {prediction_logits.shape}")
print(f"Prediction boxes shape: {prediction_boxes.shape}")
print(f"Max logit value: {np.max(prediction_logits):.4f}")
print(f"Mean logit value: {np.mean(prediction_logits):.4f}")
max_values = np.max(prediction_logits, axis=1)
mask = max_values > box_threshold mask = max_values > box_threshold
logits = logits[mask] logits = prediction_logits[mask]
boxes = boxes[mask] boxes = prediction_boxes[mask]
phrases = ["car"] * len(boxes) tokens = text_cache['tokens']
input_ids = text_cache['input_ids'][0].tolist()
if remove_combined:
sep_idx = [i for i in range(len(input_ids)) if input_ids[i] in [101, 102, 1012]]
phrases = []
for logit in logits:
max_idx = logit.argmax()
insert_idx = bisect.bisect_left(sep_idx, max_idx)
right_idx = sep_idx[insert_idx]
left_idx = sep_idx[insert_idx - 1]
phrases.append(
get_phrases_from_posmap(logit > text_threshold, tokens, left_idx, right_idx)
)
else:
phrases = [
get_phrases_from_posmap(logit > text_threshold, tokens)
for logit in logits
]
return boxes, np.max(logits, axis=1), phrases return boxes, np.max(logits, axis=1), phrases
# ========================= # =========================
# Benchmark # Benchmark (完全移植 ORT 格式)
# ========================= # =========================
def benchmark(model, image, box_th, warmup=5, runs=10): def benchmark_performance(
print("\n🔥 预热") model, image, caption, box_threshold, text_threshold,
for _ in range(warmup): warmup_runs=5, test_runs=10
predict(model, image, box_th, True) ):
"""
print("\n🚀 测试") 性能测试函数:包含预热和实际推理
times = [] """
for i in range(runs): print("="*60)
start = time.time() print("📊 开始性能测试(包含预热+实际推理)")
predict(model, image, box_th, True) print("="*60)
times.append(time.time() - start)
print(f"\n🔥 预热阶段({warmup_runs} 次)- 不计入性能统计")
warmup_start = time.time()
for i in range(warmup_runs):
t0 = time.time()
predict(model, image, caption, box_threshold, text_threshold, is_benchmark=True)
warmup_time = time.time() - t0
print(f"预热 {i+1}/{warmup_runs} - 耗时: {warmup_time*1000:.2f} ms")
total_warmup_time = time.time() - warmup_start
print(f"\n预热完成 - 总耗时: {total_warmup_time:.3f} s, 平均每次: {total_warmup_time/warmup_runs*1000:.2f} ms")
print(f"\n🚀 实际推理测试阶段({test_runs} 次)- 统计性能指标")
test_start = time.time()
infer_times = []
for i in range(test_runs):
t0 = time.time()
predict(model, image, caption, box_threshold, text_threshold, is_benchmark=True)
infer_time = time.time() - t0
infer_times.append(infer_time)
print(f"实际推理 {i+1}/{test_runs} - 耗时: {infer_time*1000:.2f} ms")
# 计算性能指标
total_test_time = time.time() - test_start
avg_infer_time = np.mean(infer_times)
std_infer_time = np.std(infer_times)
max_infer_time = np.max(infer_times)
min_infer_time = np.min(infer_times)
fps = test_runs / total_test_time
# 输出性能报告
print("\n" + "="*60)
print("📈 性能测试报告(仅实际推理阶段)")
print("="*60)
print(f"测试次数: {test_runs} 次")
print(f"总推理耗时: {total_test_time:.3f} s")
print(f"平均推理耗时: {avg_infer_time*1000:.2f} ms (±{std_infer_time*1000:.2f} ms)")
print(f"最大推理耗时: {max_infer_time*1000:.2f} ms")
print(f"最小推理耗时: {min_infer_time*1000:.2f} ms")
print(f"平均FPS: {fps:.2f} 帧/秒")
print("="*60)
print(f"\n平均耗时: {np.mean(times)*1000:.2f} ms") return {
print(f"FPS: {1/np.mean(times):.2f}") "warmup_runs": warmup_runs,
"test_runs": test_runs,
"avg_infer_time_ms": avg_infer_time*1000,
"std_infer_time_ms": std_infer_time*1000,
"max_infer_time_ms": max_infer_time*1000,
"min_infer_time_ms": min_infer_time*1000,
"fps": fps
}
# ========================= # =========================
# 主函数 # 主函数
# ========================= # =========================
# if __name__ == "__main__": if __name__ == "__main__":
# model_path = "../weights/ground_opt.onnx"
# cache_path = "../weights/ground_opt.mxr"
# img_path = "../images/in/car_1.jpg"
# BOX_TRESHOLD = 0.35
# DEVICE_ID = 5 # 匹配你之前报错堆栈里的 device: 5 / 0 的情况,按需修改
# model = MIGraphXModel( model_path = "../weights/ground_opt_0430.onnx"
# model_path, cache_path = "../weights/ground_opt_0506.mxr" # ⭐ 缓存文件
# cache_path=cache_path,
# force_recompile=False,
# device_id=DEVICE_ID
# )
# image_source, image = load_image(img_path) img_path = "../images/in/car_1.jpg"
# benchmark(model, image, BOX_TRESHOLD) TEXT_PROMPT = "car ."
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25
WARMUP_RUNS = 5
TEST_RUNS = 10
# 🚀 加载模型(自动缓存)
model = MIGraphXModel(
model_path,
cache_path=cache_path,
force_recompile=False # 改成 True 可强制重编译
)
# boxes, confs, phrases = predict(model, image, BOX_TRESHOLD) image_source, image = load_image(img_path)
# print("检测结果:", phrases) # 第一步:运行完整的性能测试(预热+实际推理)
benchmark_performance(
model, image, TEXT_PROMPT,
BOX_TRESHOLD, TEXT_TRESHOLD,
WARMUP_RUNS, TEST_RUNS
)
def test_like_perf(model): # 第二步:执行最终推理并画图保存
print("\n" + "="*60) print("\n" + "="*60)
print("🛠️ 模拟 perf 工具:生成完美对齐的 Dummy 数据测试") print("🎯 执行最终推理(带详细日志+保存结果)")
print("="*60) print("="*60)
mgx_inputs = {} boxes, confs, phrases = predict(
keep_alive_cache = [] # 强行续命池 model, image,
TEXT_PROMPT, BOX_TRESHOLD, TEXT_TRESHOLD
)
# 绘制并保存结果图片
print("\n🎯 执行最终推理并保存结果图")
ori_img = cv2.imread(img_path)
img_h = ori_img.shape[0]
img_w = ori_img.shape[1]
# 1. 严格按照模型要求的形状造假数据 for i in range(len(boxes)):
for name, shape in model.get_inputs().items(): one_box = boxes[i]
if str(name).startswith("main:#output"): one_conf = confs[i]
continue one_cls = phrases[i]
# 解析真实需要的类型和形状
target_dtype, lens = _mgx_shape_to_numpy(shape)
print(f" 📦 分配 {name}: shape={lens}, dtype={target_dtype.__name__}")
# 生成分毫不差的全零矩阵(完美模拟 migraphx-driver)
dummy_data = np.zeros(lens, dtype=target_dtype)
keep_alive_cache.append(dummy_data)
# 移交指针 x1 = int((one_box[0] - one_box[2] / 2) * img_w)
mgx_inputs[name] = migraphx.argument(dummy_data) y1 = int((one_box[1] - one_box[3] / 2) * img_h)
x2 = int((one_box[0] + one_box[2] / 2) * img_w)
y2 = int((one_box[1] + one_box[3] / 2) * img_h)
print("\n🚀 开始 Dummy 推理测试...") cv2.rectangle(ori_img, (x1, y1), (x2, y2), (0, 0, 255), 2)
try: cv2.putText(
start = time.time() ori_img, f'{one_cls} {one_conf:.2f}',
model.run(mgx_inputs) (x1-15, y1-15),
print(f"✅ Python 端 Dummy 推理成功!没有任何 VMFault!耗时: {(time.time()-start)*1000:.2f}ms") fontFace=cv2.FONT_HERSHEY_SIMPLEX,
except Exception as e: color=(255, 255, 255),
print(f"❌ 依然报错: {e}") fontScale=1.5,
thickness=3
# ------------------ )
# 在主函数里这样调用:
# ------------------ # 保存结果
if __name__ == "__main__": cv2.imwrite('../weights/result_migraphx.jpg', ori_img)
model_path = "../weights/ground_opt.onnx" print(f"\n✅ 结果已保存至: ../weights/result_migraphx.jpg")
cache_path = "../weights/ground_opt.mxr" print(f"✅ 检测到目标: {phrases} (共 {len(boxes)} 个)")
\ No newline at end of file
model = migraphx.load(cache_path) # 直接加载你确定没问题的 mxr
# 运行模拟测试
test_like_perf(model)
\ No newline at end of file
import cv2
import numpy as np
import migraphx
"""
本示例演示了如何使用migraphx进行推理,主要步骤如下:
1. 加载模型
2. 获取模型输入输出节点信息
3. 编译模型
4. 为输出节点分配device内存,用于保存输出数据
5. 预处理并转换为NCHW
6. 将输入数据转换为device数据作为输入数据
7. 推理
"""
def ReadImage(pathOfImage,inputShape):
srcImage = cv2.imread(pathOfImage, cv2.IMREAD_COLOR)
# resize并转换为CHW
resizedImage = cv2.resize(srcImage,(inputShape[3], inputShape[2]))
resizedImage_Float = resizedImage.astype("float32") # 转换为float32
srcImage_CHW = np.transpose(resizedImage_Float, (2, 0, 1)) # 转换为CHW
# 预处理
mean = np.array([127.5, 127.5, 127.5])
scale = np.array([0.0078125, 0.0078125, 0.0078125])
inputData = np.zeros(inputShape).astype("float32") # NCHW
for i in range(srcImage_CHW.shape[0]):
inputData[0,i, :, :] = (srcImage_CHW[i, :, :] - mean[i]) * scale[i]
for i in range(inputData.shape[0]):
if i!=0:
inputData[i,:, :, :]=inputData[0,:, :, :]
return inputData
def AllocateOutputMemory(model):
outputData={}
for key in model.get_outputs().keys():
outputData[key] = migraphx.allocate_gpu(s=model.get_outputs()[key])
return outputData
if __name__ == '__main__':
# 加载模型
model = migraphx.parse_onnx("ResNet50.onnx")
# 获取模型输入输出节点信息
print("inputs:")
inputs=model.get_inputs()
for key,value in inputs.items():
print("{}:{}".format(key,value))
print("outputs:")
outputs=model.get_outputs()
for key,value in outputs.items():
print("{}:{}".format(key,value))
inputName=list(model.get_inputs().keys())[0]
inputShape=inputs[inputName].lens()
# 编译
model.compile(t=migraphx.get_target("gpu"),offload_copy=False,device_id=0)
# 为输出节点分配device内存,用于保存输出数据
modelData=AllocateOutputMemory(model)
# 预处理并转换为NCHW
pathOfImage ="Test.jpg"
image = ReadImage(pathOfImage,inputShape)
# 将输入数据转换为device数据作为输入数据
modelData[inputName]=migraphx.to_gpu(migraphx.argument(image))
# 推理
results = model.run(modelData)
# 获取输出节点属性
result=migraphx.from_gpu(results[0]) # 将第一个输出节点的数据拷贝到host端,migraphx.argument类型
outputShape=result.get_shape() # 输出节点的shape,migraphx.shape类型
outputSize=outputShape.lens() # 每一维大小,维度顺序为(N,C,H,W),list类型
numberOfOutput=outputShape.elements() # 输出节点元素的个数
# 转换为numpy
result = np.array(result)
print(result)
migraphx-driver perf --batch 1 \ migraphx-driver perf --batch 1 \
-n 10 \ -n 10 \
--fp16 \ --fp16 \
--migraphx ../weights/ground_opt.mxr --migraphx ../weights/ground_opt_0430.mxr
\ No newline at end of file \ No newline at end of file
...@@ -30,123 +30,37 @@ def change_inf_to_value(om: ONNXModifier): ...@@ -30,123 +30,37 @@ def change_inf_to_value(om: ONNXModifier):
records.add(init_name) records.add(init_name)
# def optimize_where_ndoes(om: ONNXModifier):
# """Where节点等价替换
# (1) condition为initializer, X为0, Y为输入数据:
# Where(cond, X, Y) ==> Mul(Y, ~cond)
# (2) condition为initializer, X为负无穷, Y为输入数据
# Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
# (3) condition为真实输入, X为负无穷, Y为输入数据
# Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
# cases:
# 1. Where(cond, -inf, input)
# a. /transformer/encoder/fusion_layers.*/attn/Where
# b. /transformer/encoder/fusion_layers.*/attn/Where_1
# c. /class_embed.0_*/Where: Where(cond, -inf, input)
# 2. Where(cond, 0, input):
# a. /transformer/encoder/layers.*/self_attn/Where
# b. /transformer/decoder/layers.*/cross_attn/Where
# """
# for where_node in om.get_nodes("Where"):
# where_name = where_node.name
# # print("Process where node:", where_name)
# x_value = om.get_initializer_value(where_node.inputs[1])
# assert x_value.size == 1
# assert x_value == np.array(0.0, dtype=np.float32) or \
# x_value == np.array(-np.inf, dtype=np.float32)
# cond_init = om.get_initializer(where_node.inputs[0])
# if cond_init is not None:
# cond_value = om.get_initializer_value(where_node.inputs[0])
# if x_value == np.array(0.0, dtype=np.float32):
# # Where(cond, X, Y) ==> Mul(Y, ~cond)
# mul_name = where_name.replace("Where", "NewMul")
# mul_b_init = om.create_initializer(mul_name + "_B",
# (~cond_value).astype(np.float32))
# mul_node = om.create_node("Mul",
# mul_name,
# [where_node.inputs[2], mul_b_init.name],
# [mul_name+"_output_0"],
# index=where_node.index)
# next_nodes = where_node.next_nodes
# for next_node in next_nodes:
# next_node.replace_input(where_node.outputs[0], mul_node.outputs[0])
# elif x_value == np.array(-np.inf, dtype=np.float32):
# # Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
# sub_name = where_name.replace("Where", "NewSub")
# sub_b_init = om.create_initializer(
# sub_name + "_B",
# np.where(cond_value.astype(np.float32),
# np.finfo(np.float16).max, 0.0).astype(np.float32)
# )
# sub_node = om.create_node("Sub",
# sub_name,
# [where_node.inputs[2], sub_b_init.name],
# [sub_name+"_output_0"],
# index=where_node.index)
# next_nodes = where_node.next_nodes
# for next_node in next_nodes:
# next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
# else:
# # Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
# assert x_value == np.array(-np.inf, dtype=np.float32)
# cast_name = where_name.replace("Where", "NewCast")
# mul_name = where_name.replace("Where", "NewMul")
# sub_name = where_name.replace("Where", "NewSub")
# cast_node = om.create_node("Cast",
# cast_name,
# [where_node.inputs[0]],
# [cast_name+"_output_0"],
# to=1,
# index=where_node.index)
# mul_b_init = om.create_initializer(mul_name + "_B",
# np.array([np.finfo(np.float16).max], np.float32))
# mul_node = om.create_node("Mul",
# mul_name,
# [cast_node.outputs[0], mul_b_init.name],
# [mul_name+"_output_0"],
# index=cast_node.index+1)
# sub_node = om.create_node("Sub",
# sub_name,
# [where_node.inputs[2], mul_node.outputs[0]],
# [sub_name+"_output_0"],
# index=mul_node.index+1)
# next_nodes = where_node.next_nodes
# for next_node in next_nodes:
# next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
# om.update_map()
def optimize_where_ndoes(om: ONNXModifier): def optimize_where_ndoes(om: ONNXModifier):
"""Where节点等价替换 (加入安全校验版本)""" """Where节点等价替换
(1) condition为initializer, X为0, Y为输入数据:
Where(cond, X, Y) ==> Mul(Y, ~cond)
(2) condition为initializer, X为负无穷, Y为输入数据
Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
(3) condition为真实输入, X为负无穷, Y为输入数据
Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
cases:
1. Where(cond, -inf, input)
a. /transformer/encoder/fusion_layers.*/attn/Where
b. /transformer/encoder/fusion_layers.*/attn/Where_1
c. /class_embed.0_*/Where: Where(cond, -inf, input)
2. Where(cond, 0, input):
a. /transformer/encoder/layers.*/self_attn/Where
b. /transformer/decoder/layers.*/cross_attn/Where
"""
for where_node in om.get_nodes("Where"): for where_node in om.get_nodes("Where"):
where_name = where_node.name where_name = where_node.name
# print("Process where node:", where_name)
# 1. 安全获取 X 的值,如果 X 不是常量(initializer),直接跳过不优化
x_init = om.get_initializer(where_node.inputs[1])
if x_init is None:
continue
x_value = om.get_initializer_value(where_node.inputs[1]) x_value = om.get_initializer_value(where_node.inputs[1])
assert x_value.size == 1
# 2. 避免 assert 崩溃:如果 size 不为 1,说明不是我们要找的 Attention Mask 节点,跳过 assert x_value == np.array(0.0, dtype=np.float32) or \
if x_value.size != 1: x_value == np.array(-np.inf, dtype=np.float32)
continue
# 3. 判断是否符合优化条件(0.0 或 -inf),不符合直接跳过
is_zero = (x_value == np.array(0.0, dtype=np.float32))
is_neg_inf = (x_value == np.array(-np.inf, dtype=np.float32))
if not (is_zero or is_neg_inf):
continue
cond_init = om.get_initializer(where_node.inputs[0]) cond_init = om.get_initializer(where_node.inputs[0])
if cond_init is not None: if cond_init is not None:
cond_value = om.get_initializer_value(where_node.inputs[0]) cond_value = om.get_initializer_value(where_node.inputs[0])
if is_zero: if x_value == np.array(0.0, dtype=np.float32):
# Where(cond, X, Y) ==> Mul(Y, ~cond) # Where(cond, X, Y) ==> Mul(Y, ~cond)
mul_name = where_name.replace("Where", "NewMul") mul_name = where_name.replace("Where", "NewMul")
mul_b_init = om.create_initializer(mul_name + "_B", mul_b_init = om.create_initializer(mul_name + "_B",
...@@ -159,7 +73,7 @@ def optimize_where_ndoes(om: ONNXModifier): ...@@ -159,7 +73,7 @@ def optimize_where_ndoes(om: ONNXModifier):
next_nodes = where_node.next_nodes next_nodes = where_node.next_nodes
for next_node in next_nodes: for next_node in next_nodes:
next_node.replace_input(where_node.outputs[0], mul_node.outputs[0]) next_node.replace_input(where_node.outputs[0], mul_node.outputs[0])
elif is_neg_inf: elif x_value == np.array(-np.inf, dtype=np.float32):
# Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0)) # Where(cond, X, Y) ==> Sub(Y, Where(cond, np.inf, 0))
sub_name = where_name.replace("Where", "NewSub") sub_name = where_name.replace("Where", "NewSub")
sub_b_init = om.create_initializer( sub_b_init = om.create_initializer(
...@@ -177,10 +91,7 @@ def optimize_where_ndoes(om: ONNXModifier): ...@@ -177,10 +91,7 @@ def optimize_where_ndoes(om: ONNXModifier):
next_node.replace_input(where_node.outputs[0], sub_node.outputs[0]) next_node.replace_input(where_node.outputs[0], sub_node.outputs[0])
else: else:
# Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf)) # Where(cond, X, Y) ==> Sub(Y, Mul(Cast(cond, to=float32), np.inf))
# 当 condition 不是 initializer 时,只处理 -inf 的情况 assert x_value == np.array(-np.inf, dtype=np.float32)
if not is_neg_inf:
continue
cast_name = where_name.replace("Where", "NewCast") cast_name = where_name.replace("Where", "NewCast")
mul_name = where_name.replace("Where", "NewMul") mul_name = where_name.replace("Where", "NewMul")
sub_name = where_name.replace("Where", "NewSub") sub_name = where_name.replace("Where", "NewSub")
...@@ -208,6 +119,7 @@ def optimize_where_ndoes(om: ONNXModifier): ...@@ -208,6 +119,7 @@ def optimize_where_ndoes(om: ONNXModifier):
om.update_map() om.update_map()
def optimize_transpose_nodes(om: ONNXModifier): def optimize_transpose_nodes(om: ONNXModifier):
transpose_list = [ transpose_list = [
"/transformer/encoder/Transpose", "/transformer/encoder/Transpose",
...@@ -256,64 +168,50 @@ def optimize_transpose_nodes(om: ONNXModifier): ...@@ -256,64 +168,50 @@ def optimize_transpose_nodes(om: ONNXModifier):
] ]
for name in transpose_list: for name in transpose_list:
node = om.get_node(name) node = om.get_node(name)
# 安全校验:如果找不到这个节点,说明当前模型不需要优化这个点,跳过 assert node.attrs['perm'] == [1, 0 , 2] or node.attrs['perm'] == [1, 0 , 2, 3], \
if node is None: f"perm={node.attrs['perm']}"
continue next_nodes = om.get_next_nodes(node)
for node_ in next_nodes:
if 'perm' in node.attrs and (node.attrs['perm'] == [1, 0 , 2] or node.attrs['perm'] == [1, 0 , 2, 3]): node_.replace_input(node.outputs[0], node.inputs[0])
next_nodes = om.get_next_nodes(node)
for node_ in next_nodes:
node_.replace_input(node.outputs[0], node.inputs[0])
# modify /transformer/encoder/text_layers.*/self_attn/Reshape_4 # modify /transformer/encoder/text_layers.*/self_attn/Reshape_4
# om.set_initializer_value("_v_8735", np.array([-1, 4, 256], np.int64))
shape_init1 = om.create_initializer( shape_init1 = om.create_initializer(
"/transformer/encoder/text_layers.x/self_attn/des_shape", "/transformer/encoder/text_layers.x/self_attn/des_shape",
np.array([1, 4, 256], np.int64) np.array([1, 4, 256], np.int64)
) )
for i in range(6): for i in range(6):
reshape_node = om.get_node(f"/transformer/encoder/text_layers.{i}/self_attn/Reshape_4") reshape_node = om.get_node(f"/transformer/encoder/text_layers.{i}/self_attn/Reshape_4")
if reshape_node is not None: reshape_node.set_input(1, shape_init1.name)
reshape_node.set_input(1, shape_init1.name)
# modify /transformer/enc_out_class_embed/Transpose # modify /transformer/enc_out_class_embed/Transpose
trans_node = om.get_node("/transformer/enc_out_class_embed/Transpose") om.get_node("/transformer/enc_out_class_embed/Transpose").set_attribute("perm", [0, 2, 1])
if trans_node is not None:
trans_node.set_attribute("perm", [0, 2, 1])
# modify /transformer/decoder/Reshape_* # 安全校验:避免写死的随机变量名 _v_5525 引发崩溃 # modify /transformer/decoder/Reshape_*
init_5525 = om.get_initializer("_v_5525") om.set_initializer_value("_v_5525", np.array([1, 900, -1], np.int64))
if init_5525 is not None:
om.set_initializer_value("_v_5525", np.array([1, 900, -1], np.int64))
# modify /transformer/decoder/layers.*/self_attn/Reshape_4 # modify /transformer/decoder/layers.*/self_attn/Reshape_4
# modify /transformer/decoder/layers.*/ca_text/Reshape_6 # modify /transformer/decoder/layers.*/ca_text/Reshape_6
# om.set_initializer_value("_v_6230", np.array([-1, 900, 256], np.int64))
shape_init3 = om.create_initializer( shape_init3 = om.create_initializer(
"/transformer/decoder/layers.x/self_attn_ca_text/des_shape", "/transformer/decoder/layers.x/self_attn_ca_text/des_shape",
np.array([1, 900, 256], np.int64) np.array([1, 900, 256], np.int64)
) )
for i in range(6): for i in range(6):
reshape_node1 = om.get_node(f"/transformer/decoder/layers.{i}/self_attn/Reshape_4") reshape_node1 = om.get_node(f"/transformer/decoder/layers.{i}/self_attn/Reshape_4")
if reshape_node1 is not None: reshape_node1.set_input(1, shape_init3.name)
reshape_node1.set_input(1, shape_init3.name)
reshape_node2 = om.get_node(f"/transformer/decoder/layers.{i}/ca_text/Reshape_6") reshape_node2 = om.get_node(f"/transformer/decoder/layers.{i}/ca_text/Reshape_6")
if reshape_node2 is not None: reshape_node2.set_input(1, shape_init3.name)
reshape_node2.set_input(1, shape_init3.name)
# modify /transformer/decoder/layers.0/Add # modify /transformer/decoder/layers.0/Add
# modify /transformer/decoder/layers.0/Add_1 # modify /transformer/decoder/layers.0/Add_1
init_name = "/transformer/Tile_1_output_0" init_name = "/transformer/Tile_1_output_0"
tile_init = om.get_initializer(init_name) add_value = om.get_initializer_value(init_name)
if tile_init is not None: om.set_initializer_value(init_name, np.ascontiguousarray(add_value.transpose(1, 0, 2)))
add_value = om.get_initializer_value(init_name)
om.set_initializer_value(init_name, np.ascontiguousarray(add_value.transpose(1, 0, 2)))
om.update_map() om.update_map()
om.infer_shape()
# 将形状推断包起来,防止自定义算子(MSDeformAttn)导致推理失败崩溃
try:
om.infer_shape(strict_mode=False)
except Exception as e:
print(f"[Warning] infer_shape 跳过 (可能由于自定义算子引起). 详细信息: {e}")
def optmize_sin_cos_block(om: ONNXModifier): def optmize_sin_cos_block(om: ONNXModifier):
node_pairs = [ node_pairs = [
...@@ -325,87 +223,94 @@ def optmize_sin_cos_block(om: ONNXModifier): ...@@ -325,87 +223,94 @@ def optmize_sin_cos_block(om: ONNXModifier):
("/transformer/decoder/Gather_26", "/transformer/decoder/ref_point_head/layers.0_5/MatMul"), ("/transformer/decoder/Gather_26", "/transformer/decoder/ref_point_head/layers.0_5/MatMul"),
] ]
# 提前创建一些公用的 initializer unsqueeze_axes_init1 = om.create_initializer(
unsqueeze_axes_init1 = om.create_initializer("/transformer/decoder/sin_cos_block/unsqueeze_axes1", np.array([3, 4], np.int64)) "/transformer/decoder/sin_cos_block/unsqueeze_axes1",
slice_axes_init = om.create_initializer("/transformer/decoder/sin_cos_block/slice_axes", np.array([4], np.int64)) np.array([3, 4], np.int64)
slice_steps_init = om.create_initializer("/transformer/decoder/sin_cos_block/slice_steps", np.array([1], np.int64)) )
slice_starts_init1 = om.create_initializer("/transformer/decoder/sin_cos_block/slice_starts1", np.array([0], np.int64)) slice_axes_init = om.create_initializer(
slice_ends_init1 = om.create_initializer("/transformer/decoder/sin_cos_block/slice_ends1", np.array([1], np.int64)) "/transformer/decoder/sin_cos_block/slice_axes",
slice_starts_init2 = om.create_initializer("/transformer/decoder/sin_cos_block/slice_steps2", np.array([1], np.int64)) np.array([4], np.int64)
slice_ends_init2 = om.create_initializer("/transformer/decoder/sin_cos_block/slice_ends2", np.array([2], np.int64)) )
reshape_init = om.create_initializer("/transformer/decoder/sin_cos_block/reshape_dst_shape", np.array([1, 900, -1], np.int64)) slice_steps_init = om.create_initializer(
"/transformer/decoder/sin_cos_block/slice_steps",
np.array([1], np.int64)
)
slice_starts_init1 = om.create_initializer(
"/transformer/decoder/sin_cos_block/slice_starts1",
np.array([0], np.int64)
)
slice_ends_init1 = om.create_initializer(
"/transformer/decoder/sin_cos_block/slice_ends1",
np.array([1], np.int64)
)
slice_starts_init2 = om.create_initializer(
"/transformer/decoder/sin_cos_block/slice_steps2",
np.array([1], np.int64)
)
slice_ends_init2 = om.create_initializer(
"/transformer/decoder/sin_cos_block/slice_ends2",
np.array([2], np.int64)
)
reshape_init = om.create_initializer(
"/transformer/decoder/sin_cos_block/reshape_dst_shape",
np.array([1, 900, -1], np.int64)
)
for i, (gather_name, matmul_name) in enumerate(node_pairs): for i, (gather_name, matmul_name) in enumerate(node_pairs):
gather_node = om.get_node(gather_name) gather_node = om.get_node(gather_name)
matmul_node = om.get_node(matmul_name) next_node = om.get_next_nodes(gather_node)[0]
assert next_node.op_type == "Mul", f"{next_node.op_type} {next_node.name}"
# 【安全校验】:如果找不到这一对节点,说明不需要/无法优化这个 block,直接跳过 mul_init_value = om.get_initializer_value(next_node.inputs[1])
if gather_node is None or matmul_node is None: assert mul_init_value.size == 1
continue next_node = om.get_next_nodes(next_node)[0]
assert next_node.op_type == "Unsqueeze"
try: next_node.set_inputs([gather_node.inputs[0], unsqueeze_axes_init1.name])
next_node = om.get_next_nodes(gather_node)[0] next_node = om.get_next_nodes(next_node)[0]
if next_node.op_type != "Mul": continue assert next_node.op_type == "Div"
mul_init_value = om.get_initializer_value(next_node.inputs[1]) div_init_value = om.get_initializer_value(next_node.inputs[1])
if mul_init_value.size != 1: continue new_value = (div_init_value / mul_init_value).reshape(1, 1, 1, 64, 2)
new_init = om.create_initializer(next_node.name + "_B", new_value)
next_node = om.get_next_nodes(next_node)[0] next_node.set_input(1, new_init.name)
if next_node.op_type != "Unsqueeze": continue
next_node.set_inputs([gather_node.inputs[0], unsqueeze_axes_init1.name]) next_nodes = om.get_next_nodes(next_node)
assert len(next_nodes) == 2 and all(x.op_type == 'Slice' for x in next_nodes)
sin_node, cos_node = None, None
for j, slice_node in enumerate(next_nodes):
slice_node.set_inputs([slice_node.inputs[0],
slice_starts_init1.name if j == 0 else slice_starts_init2.name,
slice_ends_init1.name if j == 0 else slice_ends_init2.name,
slice_axes_init.name,
slice_steps_init.name])
next_node = om.get_next_nodes(slice_node)[0]
if next_node.op_type == "Sin":
sin_node = next_node
elif next_node.op_type == "Cos":
cos_node = next_node
else:
raise RuntimeError("match fail!")
next_node = om.get_next_nodes(next_node)[0] next_node = om.get_next_nodes(next_node)[0]
if next_node.op_type != "Div": continue assert next_node.op_type == "Unsqueeze"
div_init_value = om.get_initializer_value(next_node.inputs[1])
new_value = (div_init_value / mul_init_value).reshape(1, 1, 1, 64, 2)
new_init = om.create_initializer(next_node.name + "_B", new_value)
next_node.set_input(1, new_init.name)
next_nodes = om.get_next_nodes(next_node)
if len(next_nodes) != 2 or not all(x.op_type == 'Slice' for x in next_nodes): continue
sin_node, cos_node = None, None
for j, slice_node in enumerate(next_nodes):
slice_node.set_inputs([slice_node.inputs[0],
slice_starts_init1.name if j == 0 else slice_starts_init2.name,
slice_ends_init1.name if j == 0 else slice_ends_init2.name,
slice_axes_init.name,
slice_steps_init.name])
n_node = om.get_next_nodes(slice_node)[0]
if n_node.op_type == "Sin":
sin_node = n_node
elif n_node.op_type == "Cos":
cos_node = n_node
else:
raise RuntimeError("match fail!")
n_node = om.get_next_nodes(n_node)[0]
n_node = om.get_next_nodes(n_node)[0]
next_node = n_node # Concat node
if next_node.op_type != "Concat": continue
next_node.set_inputs([sin_node.outputs[0], cos_node.outputs[0]])
next_node.set_attribute("axis", 4)
next_node = om.get_next_nodes(next_node)[0] next_node = om.get_next_nodes(next_node)[0]
if next_node.op_type != "Reshape": continue
next_node.set_input(1, reshape_init.name) assert next_node.op_type == "Concat"
next_node.set_inputs([sin_node.outputs[0], cos_node.outputs[0]])
matmul_node.set_input(0, next_node.outputs[0]) next_node.set_attribute("axis", 4)
if i == 0: next_node = om.get_next_nodes(next_node)[0]
mm_b_value = om.get_initializer_value(matmul_node.inputs[1]) assert next_node.op_type == "Reshape"
mm_b_value = np.concatenate([mm_b_value[128:256, ...], next_node.set_input(1, reshape_init.name)
mm_b_value[0:128, ...],
mm_b_value[256:, ...]], matmul_node = om.get_node(matmul_name)
axis=0) matmul_node.set_input(0, next_node.outputs[0])
om.set_initializer_value(matmul_node.inputs[1], mm_b_value) if i == 0:
except Exception as e: mm_b_value = om.get_initializer_value(matmul_node.inputs[1])
# 如果匹配过程中发生任何形状或节点断层的意外,静默跳过这个 block mm_b_value = np.concatenate([mm_b_value[128:256, ...],
continue mm_b_value[0:128, ...],
mm_b_value[256:, ...]],
axis=0)
om.set_initializer_value(matmul_node.inputs[1], mm_b_value)
om.update_map() om.update_map()
try: om.infer_shape()
om.infer_shape(strict_mode=False)
except:
pass
def fuse_one_attention(om: ONNXModifier, softmax_name: str, new_mask: bool = None, num_heads: int = 12): def fuse_one_attention(om: ONNXModifier, softmax_name: str, new_mask: bool = None, num_heads: int = 12):
...@@ -523,7 +428,7 @@ def optimize_normal_attention(om: ONNXModifier): ...@@ -523,7 +428,7 @@ def optimize_normal_attention(om: ONNXModifier):
# fuse_one_attention(om, f"/transformer/encoder/text_layers.{i}/self_attn/Softmax", "text_token_mask", num_heads=4) # fuse_one_attention(om, f"/transformer/encoder/text_layers.{i}/self_attn/Softmax", "text_token_mask", num_heads=4)
# /transformer/decoder # /transformer/decoder
fuse_one_attention(om, f"/transformer/decoder/layers.{i}/self_attn/Softmax", new_mask, num_heads=8) fuse_one_attention(om, f"/transformer/decoder/layers.{i}/self_attn/Softmax", new_mask, num_heads=8)
fuse_one_attention(om, f"/transformer/decoder/layers.{i}/ca_text/Softmax", new_mask, num_heads=8) # fuse_one_attention(om, f"/transformer/decoder/layers.{i}/ca_text/Softmax", new_mask, num_heads=8)
om.update_map() om.update_map()
...@@ -616,22 +521,59 @@ def optimize_backbone_attention(om: ONNXModifier): ...@@ -616,22 +521,59 @@ def optimize_backbone_attention(om: ONNXModifier):
_fuse_one_attention(f"/backbone/backbone.0/layers.{l}/blocks.{b}/attn/softmax/Softmax") _fuse_one_attention(f"/backbone/backbone.0/layers.{l}/blocks.{b}/attn/softmax/Softmax")
def optimize_ms_deform_attn(om: ONNXModifier):
def fuse_ms_deform_attn(value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, output):
value_next_node = om.get_to_nodes(value)[0]
index = value_next_node.index
name_prefix = '/'.join(value.split('/')[:-1])
node_name = f"{name_prefix}/MSDeformAttn"
fusion_node = om.create_node("MSDeformAttn",
node_name,
[value, spatial_shapes, level_start_index,
sampling_locations, attention_weights],
[f"{node_name}_output_0"],
index=index)
next_nodes = om.get_to_nodes(output)
for node in next_nodes:
node.replace_input(output, fusion_node.outputs[0])
spatial_shapes_int = om.create_initializer(
"/transformer/spatial_shapes",
np.array([(100, 150), (50, 75), (25, 38), (13, 19)], dtype=np.int64)
)
level_start_index_init = om.create_initializer(
"/transformer/level_start_index",
np.array([0, 15000, 18750, 19700], dtype=np.int64)
)
for i in range(6):
fuse_ms_deform_attn(
f"/transformer/encoder/layers.{i}/self_attn/Reshape_output_0",
spatial_shapes_int.name,
level_start_index_init.name,
f"/transformer/encoder/layers.{i}/self_attn/Add_output_0",
f"/transformer/encoder/layers.{i}/self_attn/Reshape_3_output_0",
f"/transformer/encoder/layers.{i}/self_attn/Transpose_9_output_0"
)
fuse_ms_deform_attn(
f"/transformer/decoder/layers.{i}/cross_attn/Reshape_output_0",
spatial_shapes_int.name,
level_start_index_init.name,
f"/transformer/decoder/layers.{i}/cross_attn/Add_output_0",
f"/transformer/decoder/layers.{i}/cross_attn/Reshape_3_output_0",
f"/transformer/decoder/layers.{i}/cross_attn/Transpose_9_output_0"
)
om.update_map()
def optimize_bidirect_attention(om: ONNXModifier): def optimize_bidirect_attention(om: ONNXModifier):
for i in range(6): for i in range(6):
reduce_max_name = f"/transformer/encoder/fusion_layers.{i}/attn/ReduceMax_1" reduce_max_name = f"/transformer/encoder/fusion_layers.{i}/attn/ReduceMax_1"
reduce_max_node = om.get_node(reduce_max_name) reduce_max_node = om.get_node(reduce_max_name)
next_node = om.get_next_nodes(reduce_max_node)[0]
# 【安全校验】 assert next_node.op_type == "Sub"
if reduce_max_node is None:
continue
next_nodes = om.get_next_nodes(reduce_max_node)
if not next_nodes:
continue
next_node = next_nodes[0]
if next_node.op_type != "Sub":
continue
name_prefix = '/'.join(reduce_max_name.split('/')[:-1]) name_prefix = '/'.join(reduce_max_name.split('/')[:-1])
matmul_name = f"{name_prefix}/identity_MatMul" matmul_name = f"{name_prefix}/identity_MatMul"
...@@ -646,48 +588,44 @@ def optimize_bidirect_attention(om: ONNXModifier): ...@@ -646,48 +588,44 @@ def optimize_bidirect_attention(om: ONNXModifier):
) )
next_node.set_input(1, matmul_node.outputs[0]) next_node.set_input(1, matmul_node.outputs[0])
# def main():
# input_onnx_path = sys.argv[1]
# output_onnx_path = sys.argv[2]
# # input_onnx_path = "ground_sim.onnx"
# # output_onnx_path = "ground_sim_0424_2nd.onnx"
# om = ONNXModifier(input_onnx_path)
# optimize_where_ndoes(om) # 1. 替换where节点
# optimize_transpose_nodes(om) # 2. 优化transpose节点
# optmize_sin_cos_block(om) # 3. 优化位置编码
# # om.add_opset_import("com.microsoft", 1)
# # optimize_normal_attention(om) # 4. 融合bert、transformer中的mha
# # optimize_ms_deform_attn(om) # 5. 融合多尺度可变形注意力
# # optimize_backbone_attention(om) # 6. 融合backbone中的注意力
# optimize_bidirect_attention(om) # 7. 优化双向注意力
# om.save(output_onnx_path, save_as_external_data=False)
def main(): def optimize_clip_ndoes(om: ONNXModifier):
# 假设你的原始模型路径 """优化串联的两个clip: clip(min)->clip(max)"""
input_onnx_path = "../weights/ground_deform_sim.onnx" pass
# 优化后的模型输出路径
output_onnx_path = "../weights_opt/ground_deform_opt.onnx"
print(f"Loading ONNX model from {input_onnx_path}...")
om = ONNXModifier(input_onnx_path) def optimize_gemm_nodes(om: ONNXModifier):
"""
print("1. Optimizing Where nodes (Crucial for FP16 & MIGraphX)...") input_data
optimize_where_ndoes(om) / | \
mm1 mm2 mm3
print("2. Optimizing Transpose nodes...") """
optimize_transpose_nodes(om) def find_parallel_gemm_nodes():
pass
# print("3. Optimizing Sin/Cos positional encoding...") def merge_parallel_gemm_nodes(gemm_nodes):
# optmize_sin_cos_block(om) pass
# print("4. Optimizing Bidirectional attention...") pass
# optimize_bidirect_attention(om)
def main():
input_onnx_path = sys.argv[1]
output_onnx_path = sys.argv[2]
# input_onnx_path = "ground_sim.onnx"
# output_onnx_path = "ground_sim_0430.onnx"
print(f"Saving optimized model to {output_onnx_path}...") om = ONNXModifier(input_onnx_path)
optimize_where_ndoes(om) # 1. 替换where节点
optimize_transpose_nodes(om) # 2. 优化transpose节点
optmize_sin_cos_block(om) # 3. 优化位置编码
om.add_opset_import("com.microsoft", 1)
optimize_normal_attention(om) # 4. 融合bert、transformer中的mha
# optimize_backbone_attention(om) # 5. 融合backbone中的注意力
optimize_ms_deform_attn(om) # 6. 融合多尺度可变形注意力
optimize_bidirect_attention(om) # 7. 优化双向注意力
om.save(output_onnx_path, save_as_external_data=False) om.save(output_onnx_path, save_as_external_data=False)
print("Optimization Done!")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment