# export_onnx_batchsize.py — export GroundingDINO (SwinB) to ONNX with a fixed batch size.
import torch
import onnx
from onnxsim import simplify

from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict

# Paths to the GroundingDINO SwinB model config and its pretrained checkpoint.
config_file = './groundingdino/config/GroundingDINO_SwinB_cfg.py'
checkpoint_path = './weights/groundingdino_swinb_cogcoor.pth'

def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
    """Build a GroundingDINO model and load pretrained weights.

    Args:
        model_config_path: Path to the SLConfig python config file.
        model_checkpoint_path: Path to a .pth checkpoint containing a "model" key.
        cpu_only: If True, set the config device to "cpu"; otherwise "cuda".
            Note: this only sets the config field — the returned model itself
            stays on CPU (weights are loaded with map_location="cpu").

    Returns:
        The model in eval mode.
    """
    args = SLConfig.fromfile(model_config_path)
    args.device = "cuda" if not cpu_only else "cpu"

    # Disable gradient checkpointing — presumably to keep the graph traceable
    # for ONNX export; confirm against the exporter's requirements.
    args.use_checkpoint = False
    args.use_transformer_ckpt = False

    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    # strict=False tolerates key mismatches; report them instead of silently
    # discarding the result so weight-loading problems are visible.
    load_result = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"load_state_dict: missing_keys={load_result.missing_keys}, "
              f"unexpected_keys={load_result.unexpected_keys}")
    model.eval()
    return model

# Load the model (CPU is sufficient for tracing/export).
model = load_model(config_file, checkpoint_path, cpu_only=True)

# ===================== Batched export configuration =====================
BATCH_SIZE = 8
# Prompt used at inference time; its tokenization drives all text inputs below.
caption = "car ."


def _fit_to_len(t, target_len, pad_value):
    """Pad (with pad_value) or truncate ``t`` along dim 1 to ``target_len``.

    ``t`` has shape (batch, L); the result has shape (batch, target_len) and
    keeps ``t``'s dtype.
    """
    cur_len = t.shape[1]
    if cur_len < target_len:
        pad = torch.full((t.shape[0], target_len - cur_len), pad_value, dtype=t.dtype)
        return torch.cat([t, pad], dim=1)
    return t[:, :target_len]


# 1. Batched text input: repeat the caption BATCH_SIZE times.
input_ids = model.tokenizer([caption] * BATCH_SIZE, return_tensors="pt", padding="longest")["input_ids"]
seq_len = input_ids.shape[1]  # sequence length (adapts to the caption)

# 2. position_ids for the batch, padded with 0 / truncated to seq_len.
position_ids = _fit_to_len(torch.tensor([[0, 0, 1, 0]]).repeat(BATCH_SIZE, 1), seq_len, 0)

# 3. token_type_ids for the batch, padded with 0 / truncated to seq_len.
token_type_ids = _fit_to_len(torch.tensor([[0, 0, 0, 0]]).repeat(BATCH_SIZE, 1), seq_len, 0)

# 4. attention_mask for the batch, padded with True / truncated to seq_len.
attention_mask = _fit_to_len(
    torch.tensor([[True, True, True, True]]).repeat(BATCH_SIZE, 1), seq_len, True
)

# 5. text_token_mask: pairwise token mask, (BATCH_SIZE, seq_len, seq_len).
text_token_mask = torch.tensor([[[True,  False, False, False],
                                 [False, True,  True,  False],
                                 [False, True,  True,  False],
                                 [False, False, False, True]]]).repeat(BATCH_SIZE, 1, 1)
if text_token_mask.shape[1] < seq_len:
    pad_len = seq_len - text_token_mask.shape[1]
    # Pad both rows and columns with False to reach (seq_len, seq_len).
    pad_rows = torch.zeros(BATCH_SIZE, pad_len, text_token_mask.shape[2], dtype=torch.bool)
    text_token_mask = torch.cat([text_token_mask, pad_rows], dim=1)
    pad_cols = torch.zeros(BATCH_SIZE, seq_len, pad_len, dtype=torch.bool)
    text_token_mask = torch.cat([text_token_mask, pad_cols], dim=2)
else:
    text_token_mask = text_token_mask[:, :seq_len, :seq_len]

# 6. Dummy batched image input: (BATCH_SIZE, 3, 800, 1200).
img = torch.randn(BATCH_SIZE, 3, 800, 1200)

# Print input shapes to catch a wrong batch size before the (slow) export.
print("=" * 50)
# Interpolate BATCH_SIZE instead of hard-coding "8" so the header stays
# correct when the constant changes.
print(f"输入形状验证(batch_size={BATCH_SIZE}):")
print(f"img: {img.shape}")
print(f"input_ids: {input_ids.shape}")
print(f"attention_mask: {attention_mask.shape}")
print(f"position_ids: {position_ids.shape}")
print(f"token_type_ids: {token_type_ids.shape}")
print(f"text_token_mask: {text_token_mask.shape}")
print("=" * 50)

# The ONNX model can carry dynamic input shapes; when converting to a
# TensorRT engine it is recommended to leave dynamic_axes disabled below.
dynamic_axes = {
    "input_ids": {0: "batch_size", 1: "seq_len"},
    "attention_mask": {0: "batch_size", 1: "seq_len"},
    "position_ids": {0: "batch_size", 1: "seq_len"},
    "token_type_ids": {0: "batch_size", 1: "seq_len"},
    "text_token_mask": {0: "batch_size", 1: "seq_len", 2: "seq_len"},
    "img": {0: "batch_size", 2: "height", 3: "width"},
    "logits": {0: "batch_size"},
    "boxes": {0: "batch_size"},
}

# Output paths for the raw and (optionally) simplified ONNX models.
onnx_output_path = "weights/ground_bs8.onnx"
simplified_onnx_path = "weights/ground_simplified_bs8.onnx"

# Export: positional (model, args, f) per torch.onnx.export's signature.
torch.onnx.export(
    model,
    (img, input_ids, attention_mask, position_ids, token_type_ids, text_token_mask),
    onnx_output_path,
    input_names=["img", "input_ids", "attention_mask", "position_ids", "token_type_ids", "text_token_mask"],
    output_names=["logits", "boxes"],
    # dynamic_axes=dynamic_axes,  # keep commented out when building a TensorRT engine
    opset_version=17,
    verbose=False,
    do_constant_folding=True,  # constant folding improves the later simplification pass
)

# # Simplify the model with onnxsim
# print(f"\nSimplifying ONNX model: {onnx_output_path}")
# try:
#     # Load the exported ONNX model
#     onnx_model = onnx.load(onnx_output_path)
    
#     # Simplify (enable_fuse_bn=True fuses batch-norm layers for a more thorough pass)
#     simplified_model, check = simplify(
#         onnx_model,
#         dynamic_input_shape=False,  # batch size and resolution are fixed, so False
#         input_shapes={  # input shapes for the fixed batch size
#             "img": (BATCH_SIZE, 3, 800, 1200),
#             "input_ids": tuple(input_ids.shape),
#             "attention_mask": tuple(attention_mask.shape),
#             "position_ids": tuple(position_ids.shape),
#             "token_type_ids": tuple(token_type_ids.shape),
#             "text_token_mask": tuple(text_token_mask.shape)
#         }
#     )
    
#     # Validate the simplified model
#     assert check, "Simplified ONNX model failed validation!"
    
#     # Save the simplified model
#     onnx.save(simplified_model, simplified_onnx_path)
#     print(f"ONNX simplification done, saved to: {simplified_onnx_path}")
    
# except Exception as e:
#     print(f"Error during ONNX simplification: {e}")
#     print("Falling back to the original, unsimplified ONNX model")