# export_onnx.py — export GroundingDINO (Swin-B) to a static-shape ONNX model.
import torch
import onnx
from onnxsim import simplify

from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict

# Path to the SLConfig-style python config for the Swin-B GroundingDINO variant.
config_file = './groundingdino/config/GroundingDINO_SwinB_cfg.py'
# Path to the pretrained Swin-B checkpoint weights loaded below.
checkpoint_path = './weights/groundingdino_swinb_cogcoor.pth'

def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
    """Build a GroundingDINO model from a config file and load checkpoint weights.

    Args:
        model_config_path: path to the SLConfig python config file.
        model_checkpoint_path: path to the ``.pth`` checkpoint.
        cpu_only: when True the config's target device is "cpu", else "cuda".
            (Weights are always mapped to CPU when read from disk.)

    Returns:
        The model switched to eval() mode. ``strict=False`` means missing or
        unexpected checkpoint keys are tolerated silently.
    """
    cfg = SLConfig.fromfile(model_config_path)
    cfg.device = "cpu" if cpu_only else "cuda"

    # Disable gradient/activation checkpointing — not needed (and not
    # traceable) for ONNX export.
    cfg.use_checkpoint = False
    cfg.use_transformer_ckpt = False

    net = build_model(cfg)
    state = torch.load(model_checkpoint_path, map_location="cpu")
    net.load_state_dict(clean_state_dict(state["model"]), strict=False)
    net.eval()
    return net

# Load the model on CPU for tracing/export.
model = load_model(config_file, checkpoint_path, cpu_only=True)

# Caption used at inference time, plus the matching text masks.
# "car ." tokenizes to 4 tokens — NOTE(review): the hand-written masks below
# assume exactly 4 tokens ([CLS] car . [SEP]); confirm against the tokenizer.
caption = "car ."
input_ids = model.tokenizer([caption], return_tensors="pt")["input_ids"]
position_ids = torch.tensor([[0, 0, 1, 0]])
token_type_ids = torch.tensor([[0, 0, 0, 0]])
attention_mask = torch.tensor([[True, True, True, True]])
# Token-to-token attention mask: first and last (special) tokens attend only
# to themselves; the two phrase tokens attend within the phrase.
text_token_mask = torch.tensor([[[True, False, False, False],
                                 [False, True, True, False],
                                 [False, True, True, False],
                                 [False, False, False, True]]])

# Fixed input resolution for the static export (N=1, C=3, H=400, W=600).
# (A dead duplicate 800x1200 assignment that was immediately overwritten
# has been removed.)
img = torch.randn(1, 3, 400, 600)

# Destination for the raw (unsimplified) exported ONNX model.
onnx_output_path = "weights_400x600/ground.onnx"
# Destination for the onnxsim-simplified model (only used by the commented
# simplification step below).
simplified_onnx_path = "weights/ground_simplified1.onnx"


# Export with fully static shapes (dynamic_axes=None); input order must match
# the model's forward() signature.
torch.onnx.export(
    model,
    f=onnx_output_path,
    args=(img, input_ids, attention_mask, position_ids, token_type_ids, text_token_mask),
    input_names=["img", "input_ids", "attention_mask", "position_ids", "token_type_ids", "text_token_mask"],
    output_names=["logits", "boxes"],
    dynamic_axes=None,  # static-dimension export
    opset_version=17,
    verbose=False  # flip to True for detailed tracing logs when debugging
    # do_constant_folding=True  # constant folding improves later simplification
)
print(f"ONNX模型已成功导出到: {onnx_output_path}")

# # Simplify the exported model with onnxsim.
# print(f"开始简化ONNX模型: {onnx_output_path}")
# try:
#     # Load the raw exported model.
#     onnx_model = onnx.load(onnx_output_path)

#     simplified_model, check = simplify(
#         onnx_model,
#         skip_fuse_bn=True,
#         skip_constant_folding=True,
#         dynamic_input_shape=False,
#         input_shapes={  # pin input shapes so simplification is exact
#             "img": tuple(img.shape),  # was hard-coded (1, 3, 800, 1200) — stale vs the actual 400x600 input
#             "input_ids": tuple(input_ids.shape),
#             "attention_mask": tuple(attention_mask.shape),
#             "position_ids": tuple(position_ids.shape),
#             "token_type_ids": tuple(token_type_ids.shape),
#             "text_token_mask": tuple(text_token_mask.shape)
#         }
#     )

#     # Validate the simplified model before saving.
#     assert check, "简化后的ONNX模型验证失败!"

#     onnx.save(simplified_model, simplified_onnx_path)
#     print(f"ONNX模型简化完成,已保存至: {simplified_onnx_path}")

# except Exception as e:
#     print(f"ONNX简化过程出错: {e}")
#     print("将使用原始未简化的ONNX模型")