qwen3vl.py 33.4 KB
Newer Older
hejianlin's avatar
hejianlin committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import ctypes
from typing import List, Sequence

from tqdm import tqdm

from libinfinicore_infer import (
    Qwen3vlModel,
    Qwen3vlMetaCStruct,
    TextMetaCStruct,
    VisMetaCStruct,
    Qwen3vlWeightsCStruct,
    Qwen3vlCacheCStruct,
    DataType,
    DeviceType,
)
from infer_task import InferTask, KVCache

from ctypes import POINTER, c_float, c_int, c_uint, c_uint16, c_void_p, byref, c_bool
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import math
import torch
import transformers
PanZezhong's avatar
PanZezhong committed
28

hejianlin's avatar
hejianlin committed
29
30
31
32
33
34
# Stage all torch tensors host-side: weights are loaded on CPU and only their
# raw data pointers are handed to the native weight loader below.
torch.set_default_device("cpu")


class Qwen3vlLangWeightsNaming:
    """Safetensors key names for the Qwen3-VL language-model weights."""

    _PREFIX = "model.language_model"

    def _layer(self, i, suffix):
        # Key for a per-layer tensor inside decoder layer i.
        return f"{self._PREFIX}.layers.{i}.{suffix}"

    def input_embd(self):
        return f"{self._PREFIX}.embed_tokens.weight"

    def output_embd(self):
        # Same tensor key as input_embd: the checkpoint reuses the
        # embedding table for the output projection.
        return f"{self._PREFIX}.embed_tokens.weight"

    def output_norm(self):
        return f"{self._PREFIX}.norm.weight"

    def attn_norm(self, i):
        return self._layer(i, "input_layernorm.weight")

    def attn_q_proj(self, i):
        return self._layer(i, "self_attn.q_proj.weight")

    def attn_q_norm(self, i):
        return self._layer(i, "self_attn.q_norm.weight")

    def attn_k_proj(self, i):
        return self._layer(i, "self_attn.k_proj.weight")

    def attn_k_norm(self, i):
        return self._layer(i, "self_attn.k_norm.weight")

    def attn_o_proj(self, i):
        return self._layer(i, "self_attn.o_proj.weight")

    def attn_v_proj(self, i):
        return self._layer(i, "self_attn.v_proj.weight")

    def mlp_norm(self, i):
        return self._layer(i, "post_attention_layernorm.weight")

    def mlp_gate(self, i):
        return self._layer(i, "mlp.gate_proj.weight")

    def mlp_down(self, i):
        return self._layer(i, "mlp.down_proj.weight")

    def mlp_up(self, i):
        return self._layer(i, "mlp.up_proj.weight")
PanZezhong's avatar
PanZezhong committed
74
75


hejianlin's avatar
hejianlin committed
76
77
78
class Qwen3vlVisWeightsNaming:
    """Safetensors key names for the Qwen3-VL vision-tower weights."""

    _VIS = "model.visual"

    def _block(self, i, suffix):
        # Key for a tensor inside vision transformer block i.
        return f"{self._VIS}.blocks.{i}.{suffix}"

    def _deepstack(self, i, suffix):
        # Key for a tensor inside deepstack merger i.
        return f"{self._VIS}.deepstack_merger_list.{i}.{suffix}"

    def patch_embed_weight(self):
        return f"{self._VIS}.patch_embed.proj.weight"

    def patch_embed_bias(self):
        return f"{self._VIS}.patch_embed.proj.bias"

    def pos_embed_weight(self):
        return f"{self._VIS}.pos_embed.weight"

    def attn_proj_weight(self, i):
        return self._block(i, "attn.proj.weight")

    def attn_proj_bias(self, i):
        return self._block(i, "attn.proj.bias")

    def attn_qkv_weight(self, i):
        return self._block(i, "attn.qkv.weight")

    def attn_qkv_bias(self, i):
        return self._block(i, "attn.qkv.bias")

    def mlp_linear_fc1_weight(self, i):
        return self._block(i, "mlp.linear_fc1.weight")

    def mlp_linear_fc1_bias(self, i):
        return self._block(i, "mlp.linear_fc1.bias")

    def mlp_linear_fc2_weight(self, i):
        return self._block(i, "mlp.linear_fc2.weight")

    def mlp_linear_fc2_bias(self, i):
        return self._block(i, "mlp.linear_fc2.bias")

    def norm1_weight(self, i):
        return self._block(i, "norm1.weight")

    def norm1_bias(self, i):
        return self._block(i, "norm1.bias")

    def norm2_weight(self, i):
        return self._block(i, "norm2.weight")

    def norm2_bias(self, i):
        return self._block(i, "norm2.bias")

    def deepstack_merger_linear_fc1_weight(self, i):
        return self._deepstack(i, "linear_fc1.weight")

    def deepstack_merger_linear_fc1_bias(self, i):
        return self._deepstack(i, "linear_fc1.bias")

    def deepstack_merger_linear_fc2_weight(self, i):
        return self._deepstack(i, "linear_fc2.weight")

    def deepstack_merger_linear_fc2_bias(self, i):
        return self._deepstack(i, "linear_fc2.bias")

    def deepstack_merger_norm_weight(self, i):
        return self._deepstack(i, "norm.weight")

    def deepstack_merger_norm_bias(self, i):
        return self._deepstack(i, "norm.bias")

    def merger_linear_fc1_weight(self):
        return f"{self._VIS}.merger.linear_fc1.weight"

    def merger_linear_fc1_bias(self):
        return f"{self._VIS}.merger.linear_fc1.bias"

    def merger_linear_fc2_weight(self):
        return f"{self._VIS}.merger.linear_fc2.weight"

    def merger_linear_fc2_bias(self):
        return f"{self._VIS}.merger.linear_fc2.bias"

    def merger_norm_weight(self):
        return f"{self._VIS}.merger.norm.weight"

    def merger_norm_bias(self):
        return f"{self._VIS}.merger.norm.bias"
PanZezhong's avatar
PanZezhong committed
157
158


hejianlin's avatar
hejianlin committed
159
160
class Qwen3vlMeta(Qwen3vlMetaCStruct):
    """Bridges the HuggingFace-style config dict into the C-side meta struct.

    Resolves the text dtype into both the InfiniCore enum (passed to the C
    struct) and the matching torch dtype, kept on ``self.torch_dtype`` for
    host-side tensor staging.
    """

    # config dtype string -> (C-side DataType enum, torch dtype)
    _DTYPES = {
        "float16": (DataType.INFINI_DTYPE_F16, torch.float16),
        "float32": (DataType.INFINI_DTYPE_F32, torch.float32),
        "bfloat16": (DataType.INFINI_DTYPE_BF16, torch.bfloat16),
    }

    def __init__(self, config, max_tokens=None):
        text_cfg = config["text_config"]
        vis_cfg = config["vision_config"]

        if text_cfg["dtype"] not in self._DTYPES:
            raise ValueError(
                f"Unsupported text dtype: {config['text_config']['dtype']}"
            )
        dt_, self.torch_dtype = self._DTYPES[text_cfg["dtype"]]

        # Explicit max_tokens overrides the model's positional-embedding limit.
        if max_tokens is None:
            max_tokens = text_cfg["max_position_embeddings"]

        super().__init__(
            dtype=dt_,
            image_token_id=config["image_token_id"],
            video_token_id=config["video_token_id"],
            vision_end_token_id=config["vision_end_token_id"],
            vision_start_token_id=config["vision_start_token_id"],
            text_meta=TextMetaCStruct(
                bos_token_id=text_cfg["bos_token_id"],
                eos_token_id=text_cfg["eos_token_id"],
                head_dim=text_cfg["head_dim"],
                hidden_size=text_cfg["hidden_size"],
                initializer_range=text_cfg["initializer_range"],
                intermediate_size=text_cfg["intermediate_size"],
                max_tokens=max_tokens,
                num_attention_heads=text_cfg["num_attention_heads"],
                num_hidden_layers=text_cfg["num_hidden_layers"],
                num_key_value_heads=text_cfg["num_key_value_heads"],
                rms_norm_eps=text_cfg["rms_norm_eps"],
                # mrope_section is a fixed-length (3) C array of section sizes.
                mrope_section=(ctypes.c_ulong * 3)(
                    *text_cfg["rope_scaling"]["mrope_section"]
                ),
                rope_theta=text_cfg["rope_theta"],
                vocab_size=text_cfg["vocab_size"],
            ),
            vis_meta=VisMetaCStruct(
                depth=vis_cfg["depth"],
                # Fixed-length (3) C array of deepstack tap-layer indexes.
                deepstack_visual_indexes=(ctypes.c_ulong * 3)(
                    *vis_cfg["deepstack_visual_indexes"]
                ),
                hidden_size=vis_cfg["hidden_size"],
                in_channels=vis_cfg["in_channels"],
                initializer_range=vis_cfg["initializer_range"],
                intermediate_size=vis_cfg["intermediate_size"],
                num_heads=vis_cfg["num_heads"],
                num_position_embeddings=vis_cfg["num_position_embeddings"],
                out_hidden_size=vis_cfg["out_hidden_size"],
                patch_size=vis_cfg["patch_size"],
                spatial_merge_size=vis_cfg["spatial_merge_size"],
                temporal_patch_size=vis_cfg["temporal_patch_size"],
            ),
        )

PanZezhong's avatar
PanZezhong committed
223

hejianlin's avatar
hejianlin committed
224
225
226
227
228
def load_specific_tensor(model_dir, tensor_name):
    """
    Load a specific tensor from a safetensors model.

    Supports both sharded models (with model.safetensors.index.json) and
    single-file models. When an index file is present it is used to jump
    straight to the shard that holds ``tensor_name``; otherwise (or if the
    index is unusable) every ``.safetensors`` file in ``model_dir`` is
    scanned.

    Args:
        model_dir: Directory containing the model's .safetensors file(s).
        tensor_name: Fully-qualified tensor key to look up.

    Returns:
        The tensor loaded on CPU (safetensors "pt" framework).

    Raises:
        FileNotFoundError: No .safetensors files exist in ``model_dir``.
        KeyError: The tensor is not present in any .safetensors file.
    """
    # Fast path: sharded checkpoints ship an index mapping tensor -> shard,
    # which avoids opening every shard just to probe for one key.
    index_path = os.path.join(model_dir, "model.safetensors.index.json")
    if os.path.isfile(index_path):
        try:
            with open(index_path, "r", encoding="utf-8") as f:
                weight_map = json.load(f).get("weight_map", {})
            shard = weight_map.get(tensor_name)
            if shard is not None:
                shard_path = os.path.join(model_dir, shard)
                with safetensors.safe_open(
                    shard_path, framework="pt", device="cpu"
                ) as f:
                    if tensor_name in f.keys():
                        return f.get_tensor(tensor_name)
        except (OSError, ValueError):
            # Malformed or unreadable index (json.JSONDecodeError is a
            # ValueError): fall back to scanning all shards below.
            pass

    # Fallback: scan every .safetensors file for the tensor.
    safetensors_files = [f for f in os.listdir(model_dir) if f.endswith(".safetensors")]
    if not safetensors_files:
        raise FileNotFoundError(f"No .safetensors files found in {model_dir}")

    for filename in safetensors_files:
        tensor_file = os.path.join(model_dir, filename)
        try:
            with safetensors.safe_open(tensor_file, framework="pt", device="cpu") as f:
                if tensor_name in f.keys():
                    return f.get_tensor(tensor_name)
        except Exception:
            # Best-effort scan: skip shards that fail to open (e.g. a corrupt
            # file) and keep searching the remaining ones.
            continue

    # If we reach here, tensor was not found in any file
    raise KeyError(f"{tensor_name} not found in any .safetensors files")

PanZezhong's avatar
PanZezhong committed
249

hejianlin's avatar
hejianlin committed
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
def load_Qwen3vl_weights(
    meta: Qwen3vlMeta,
    weights,
    model_path: str,
    ndev: int,
):
    """Load every Qwen3-VL weight from safetensors into the native loader.

    Each tensor is loaded host-side with torch, cast to ``meta.torch_dtype``,
    reshaped/stacked for tensor parallelism where needed, and then its raw
    ``data_ptr()`` is handed to the C-side loader callback, which copies the
    bytes. Tensors are ``del``-ed immediately after each callback to keep
    peak host memory low.

    Args:
        meta: Parsed model hyper-parameters (dims, layer counts, torch dtype).
        weights: Opaque C-side weights handle filled by the loader callbacks.
        model_path: Directory containing the model's .safetensors files.
        ndev: Number of devices attention / MLP weights are sharded across.
    """
    # torch load weights, and reshape for qkv_proj / mlp_gate_up stack, attn / mlp parallel
    # weight loader function load from specific offset according to idev, and transpose
    model_instance = Qwen3vlModel()
    weight_loader = model_instance.create_weight_loader()
    vis_names = Qwen3vlVisWeightsNaming()
    lang_names = Qwen3vlLangWeightsNaming()

    # Text-model dimensions used for the sharding/reshaping below.
    nkvh = meta.text_meta.num_key_value_heads  # number of KV heads
    nh = meta.text_meta.num_attention_heads    # number of query heads
    dh = meta.text_meta.head_dim               # per-head dimension
    d = meta.text_meta.hidden_size             # model hidden size
    di = meta.text_meta.intermediate_size      # FFN intermediate size

    # Head counts and FFN width must divide evenly across devices.
    assert nh % nkvh == 0
    assert nh % ndev == 0
    assert nkvh % ndev == 0
    assert di % ndev == 0

    # -------------------------------
    # Language_model weights
    # -------------------------------
    input_embd = load_specific_tensor(model_path, lang_names.input_embd()).to(
        meta.torch_dtype
    )
    weight_loader.contents.lang_loader.load_input_embd(weights, input_embd.data_ptr())
    del input_embd

    output_norm = load_specific_tensor(model_path, lang_names.output_norm()).to(
        meta.torch_dtype
    )
    weight_loader.contents.lang_loader.load_output_norm(weights, output_norm.data_ptr())
    del output_norm

    # Note: output_embd() resolves to the same key as input_embd() (tied
    # embeddings in the checkpoint); the loader receives its own copy.
    output_embd = load_specific_tensor(model_path, lang_names.output_embd()).to(
        meta.torch_dtype
    )
    weight_loader.contents.lang_loader.load_output_embd(weights, output_embd.data_ptr())
    del output_embd

    for i in range(meta.text_meta.num_hidden_layers):
        attn_norm = load_specific_tensor(model_path, lang_names.attn_norm(i)).to(
            meta.torch_dtype
        )
        weight_loader.contents.lang_loader.load_attn_norm(
            weights, attn_norm.data_ptr(), i
        )
        del attn_norm

        attn_q_proj = load_specific_tensor(model_path, lang_names.attn_q_proj(i))
        attn_k_proj = load_specific_tensor(model_path, lang_names.attn_k_proj(i))
        attn_v_proj = load_specific_tensor(model_path, lang_names.attn_v_proj(i))

        # View projections as (heads, head_dim, hidden) so they can be
        # sliced by head ranges per device.
        _Q = attn_q_proj.reshape(nh, dh, d)
        _K = attn_k_proj.reshape(nkvh, dh, d)
        _V = attn_v_proj.reshape(nkvh, dh, d)

        # Stack per-device as [Q_dev0, K_dev0, V_dev0, Q_dev1, K_dev1, ...]
        # so each device's qkv slab is contiguous.
        qkv_proj = []
        _nh = nh // ndev
        _nkvh = nkvh // ndev
        for _idev in range(ndev):
            qkv_proj.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :])
            qkv_proj.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
            qkv_proj.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
        attn_qkv_proj = torch.cat(qkv_proj, dim=0).to(meta.torch_dtype).contiguous()

        weight_loader.contents.lang_loader.load_attn_qkv_proj(
            weights, attn_qkv_proj.data_ptr(), i
        )
        del attn_qkv_proj

        attn_q_norm = load_specific_tensor(model_path, lang_names.attn_q_norm(i)).to(
            meta.torch_dtype
        )
        weight_loader.contents.lang_loader.load_attn_q_norm(
            weights, attn_q_norm.data_ptr(), i
        )
        del attn_q_norm

        attn_k_norm = load_specific_tensor(model_path, lang_names.attn_k_norm(i)).to(
            meta.torch_dtype
        )
        weight_loader.contents.lang_loader.load_attn_k_norm(
            weights, attn_k_norm.data_ptr(), i
        )
        del attn_k_norm

        attn_o_proj = load_specific_tensor(model_path, lang_names.attn_o_proj(i))
        # Reorder to (ndev, d, nh/ndev*dh) so each device's output-projection
        # slice is a contiguous slab.
        attn_o_proj = (
            attn_o_proj.to(meta.torch_dtype)
            .reshape([d, ndev, nh // ndev * dh])
            .transpose(0, 1)
            .contiguous()
        )
        weight_loader.contents.lang_loader.load_attn_o_proj(
            weights, attn_o_proj.data_ptr(), i
        )
        del attn_o_proj

        mlp_norm = load_specific_tensor(model_path, lang_names.mlp_norm(i)).to(
            meta.torch_dtype
        )
        weight_loader.contents.lang_loader.load_mlp_norm(
            weights, mlp_norm.data_ptr(), i
        )
        del mlp_norm

        mlp_gate = load_specific_tensor(model_path, lang_names.mlp_gate(i))
        mlp_up = load_specific_tensor(model_path, lang_names.mlp_up(i))

        # Interleave per-device as [gate_dev0, up_dev0, gate_dev1, up_dev1, ...]
        # so each device's gate+up slab is contiguous.
        gate_up = []
        _di = di // ndev
        for _idev in range(ndev):
            _start = _idev * _di
            _end = (_idev + 1) * _di
            gate_up.append(mlp_gate[_start:_end, :])
            gate_up.append(mlp_up[_start:_end, :])
        mlp_gate_up = torch.cat(gate_up, dim=0).to(meta.torch_dtype).contiguous()

        weight_loader.contents.lang_loader.load_mlp_gate_up(
            weights, mlp_gate_up.data_ptr(), i
        )
        del mlp_gate_up

        mlp_down = load_specific_tensor(model_path, lang_names.mlp_down(i))
        # Same per-device slab layout as attn_o_proj: (ndev, d, di/ndev).
        mlp_down = (
            mlp_down.to(meta.torch_dtype)
            .reshape([d, ndev, di // ndev])
            .transpose(0, 1)
            .contiguous()
        )
        weight_loader.contents.lang_loader.load_mlp_down(
            weights, mlp_down.data_ptr(), i
        )
        del mlp_down

    # -------------------------------
    # Vision head weights
    # -------------------------------
    # Vision-tower tensors are loaded unsharded: cast to the target dtype
    # and handed to the loader as-is.
    patch_embed_weight = load_specific_tensor(
        model_path, vis_names.patch_embed_weight()
    ).to(meta.torch_dtype)
    weight_loader.contents.vis_loader.load_patch_embed_weight(
        weights, patch_embed_weight.data_ptr()
    )
    del patch_embed_weight

    patch_embed_bias = load_specific_tensor(
        model_path, vis_names.patch_embed_bias()
    ).to(meta.torch_dtype)
    weight_loader.contents.vis_loader.load_patch_embed_bias(
        weights, patch_embed_bias.data_ptr()
    )
    del patch_embed_bias

    pos_embed_weight = load_specific_tensor(
        model_path, vis_names.pos_embed_weight()
    ).to(meta.torch_dtype)
    weight_loader.contents.vis_loader.load_pos_embed_weight(
        weights, pos_embed_weight.data_ptr()
    )
    del pos_embed_weight

    # Per-block vision transformer weights.
    for i in range(meta.vis_meta.depth):
        attn_proj_weight = load_specific_tensor(
            model_path, vis_names.attn_proj_weight(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_attn_proj_weight(
            weights, attn_proj_weight.data_ptr(), i
        )
        del attn_proj_weight

        attn_proj_bias = load_specific_tensor(
            model_path, vis_names.attn_proj_bias(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_attn_proj_bias(
            weights, attn_proj_bias.data_ptr(), i
        )
        del attn_proj_bias

        attn_qkv_weight = load_specific_tensor(
            model_path, vis_names.attn_qkv_weight(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_attn_qkv_weight(
            weights, attn_qkv_weight.data_ptr(), i
        )
        del attn_qkv_weight

        attn_qkv_bias = load_specific_tensor(model_path, vis_names.attn_qkv_bias(i)).to(
            meta.torch_dtype
        )
        weight_loader.contents.vis_loader.load_attn_qkv_bias(
            weights, attn_qkv_bias.data_ptr(), i
        )
        del attn_qkv_bias

        mlp_linear_fc1_weight = load_specific_tensor(
            model_path, vis_names.mlp_linear_fc1_weight(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_mlp_linear_fc1_weight(
            weights, mlp_linear_fc1_weight.data_ptr(), i
        )
        del mlp_linear_fc1_weight

        mlp_linear_fc1_bias = load_specific_tensor(
            model_path, vis_names.mlp_linear_fc1_bias(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_mlp_linear_fc1_bias(
            weights, mlp_linear_fc1_bias.data_ptr(), i
        )
        del mlp_linear_fc1_bias

        mlp_linear_fc2_weight = load_specific_tensor(
            model_path, vis_names.mlp_linear_fc2_weight(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_mlp_linear_fc2_weight(
            weights, mlp_linear_fc2_weight.data_ptr(), i
        )
        del mlp_linear_fc2_weight

        mlp_linear_fc2_bias = load_specific_tensor(
            model_path, vis_names.mlp_linear_fc2_bias(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_mlp_linear_fc2_bias(
            weights, mlp_linear_fc2_bias.data_ptr(), i
        )
        del mlp_linear_fc2_bias

        norm1_weight = load_specific_tensor(model_path, vis_names.norm1_weight(i)).to(
            meta.torch_dtype
        )
        weight_loader.contents.vis_loader.load_norm1_weight(
            weights, norm1_weight.data_ptr(), i
        )
        del norm1_weight

        norm1_bias = load_specific_tensor(model_path, vis_names.norm1_bias(i)).to(
            meta.torch_dtype
        )
        weight_loader.contents.vis_loader.load_norm1_bias(
            weights, norm1_bias.data_ptr(), i
        )
        del norm1_bias

        norm2_weight = load_specific_tensor(model_path, vis_names.norm2_weight(i)).to(
            meta.torch_dtype
        )
        weight_loader.contents.vis_loader.load_norm2_weight(
            weights, norm2_weight.data_ptr(), i
        )
        del norm2_weight

        norm2_bias = load_specific_tensor(model_path, vis_names.norm2_bias(i)).to(
            meta.torch_dtype
        )
        weight_loader.contents.vis_loader.load_norm2_bias(
            weights, norm2_bias.data_ptr(), i
        )
        del norm2_bias

    # Deepstack mergers: one per entry in deepstack_visual_indexes.
    for i in range(len(meta.vis_meta.deepstack_visual_indexes)):
        deepstack_merger_linear_fc1_weight = load_specific_tensor(
            model_path, vis_names.deepstack_merger_linear_fc1_weight(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc1_weight(
            weights, deepstack_merger_linear_fc1_weight.data_ptr(), i
        )
        del deepstack_merger_linear_fc1_weight

        deepstack_merger_linear_fc1_bias = load_specific_tensor(
            model_path, vis_names.deepstack_merger_linear_fc1_bias(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc1_bias(
            weights, deepstack_merger_linear_fc1_bias.data_ptr(), i
        )
        del deepstack_merger_linear_fc1_bias

        deepstack_merger_linear_fc2_weight = load_specific_tensor(
            model_path, vis_names.deepstack_merger_linear_fc2_weight(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc2_weight(
            weights, deepstack_merger_linear_fc2_weight.data_ptr(), i
        )
        del deepstack_merger_linear_fc2_weight

        deepstack_merger_linear_fc2_bias = load_specific_tensor(
            model_path, vis_names.deepstack_merger_linear_fc2_bias(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_deepstack_merger_linear_fc2_bias(
            weights, deepstack_merger_linear_fc2_bias.data_ptr(), i
        )
        del deepstack_merger_linear_fc2_bias

        deepstack_merger_norm_weight = load_specific_tensor(
            model_path, vis_names.deepstack_merger_norm_weight(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_deepstack_merger_norm_weight(
            weights, deepstack_merger_norm_weight.data_ptr(), i
        )
        del deepstack_merger_norm_weight

        deepstack_merger_norm_bias = load_specific_tensor(
            model_path, vis_names.deepstack_merger_norm_bias(i)
        ).to(meta.torch_dtype)
        weight_loader.contents.vis_loader.load_deepstack_merger_norm_bias(
            weights, deepstack_merger_norm_bias.data_ptr(), i
        )
        del deepstack_merger_norm_bias

    # Final merger (vision -> language hidden size projection) weights.
    merger_linear_fc1_weight = load_specific_tensor(
        model_path, vis_names.merger_linear_fc1_weight()
    ).to(meta.torch_dtype)
    weight_loader.contents.vis_loader.load_merger_linear_fc1_weight(
        weights, merger_linear_fc1_weight.data_ptr()
    )
    del merger_linear_fc1_weight

    merger_linear_fc1_bias = load_specific_tensor(
        model_path, vis_names.merger_linear_fc1_bias()
    ).to(meta.torch_dtype)
    weight_loader.contents.vis_loader.load_merger_linear_fc1_bias(
        weights, merger_linear_fc1_bias.data_ptr()
    )
    del merger_linear_fc1_bias

    merger_linear_fc2_weight = load_specific_tensor(
        model_path, vis_names.merger_linear_fc2_weight()
    ).to(meta.torch_dtype)
    weight_loader.contents.vis_loader.load_merger_linear_fc2_weight(
        weights, merger_linear_fc2_weight.data_ptr()
    )
    del merger_linear_fc2_weight

    merger_linear_fc2_bias = load_specific_tensor(
        model_path, vis_names.merger_linear_fc2_bias()
    ).to(meta.torch_dtype)
    weight_loader.contents.vis_loader.load_merger_linear_fc2_bias(
        weights, merger_linear_fc2_bias.data_ptr()
    )
    del merger_linear_fc2_bias

    merger_norm_weight = load_specific_tensor(
        model_path, vis_names.merger_norm_weight()
    ).to(meta.torch_dtype)
    weight_loader.contents.vis_loader.load_merger_norm_weight(
        weights, merger_norm_weight.data_ptr()
    )
    del merger_norm_weight

    merger_norm_bias = load_specific_tensor(
        model_path, vis_names.merger_norm_bias()
    ).to(meta.torch_dtype)
    weight_loader.contents.vis_loader.load_merger_norm_bias(
        weights, merger_norm_bias.data_ptr()
    )
    del merger_norm_bias


class Qwen3vlBatchedTask:
    """Flattened, ctypes-ready batch of inference requests for Qwen3-VL.

    Gathers per-request tokens, positions, sampling parameters and KV-cache
    pointers from ``tasks``, plus optional image/video pixel inputs, and
    exposes them as the positional tuple expected by the C ``infer_batch``
    entry point (see ``input_args``).
    """

    def __init__(
        self,
        tasks: List[InferTask],
        all_pixel_values=None,
        all_image_grid_thw=None,
        all_pixel_values_videos=None,
        all_video_grid_thw=None,
    ):
        self.tasks = tasks
        self.nreq = len(tasks)

        # Per-request scalar fields.
        token_lists = [t.tokens for t in tasks]
        self.req_lens_list = [len(toks) for toks in token_lists]
        self.req_pos_list = [t.pos for t in tasks]
        self.kv_cache_ptrs = [t.kvcache().data() for t in tasks]
        self.temperaturas_list = [t.temperature for t in tasks]
        self.topks_list = [t.topk for t in tasks]
        self.topps_list = [t.topp for t in tasks]

        # Flatten all request tokens into one contiguous array.
        flat_tokens = [tok for toks in token_lists for tok in toks]
        self.ntok = len(flat_tokens)

        # ctypes arrays handed to the C interface. References are kept on
        # ``self`` so the underlying buffers stay alive during the call.
        self.tokens = (c_uint * self.ntok)(*flat_tokens)
        self.req_lens = (c_uint * self.nreq)(*self.req_lens_list)
        self.req_pos = (c_uint * self.nreq)(*self.req_pos_list)
        self.kv_caches = (POINTER(Qwen3vlCacheCStruct) * self.nreq)(*self.kv_cache_ptrs)
        self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list)
        self.topks = (c_uint * self.nreq)(*self.topks_list)
        self.topps = (c_float * self.nreq)(*self.topps_list)

        # Visual-encoder inputs default to "absent".
        self.pixel_values = None
        self.total_patches = 0
        self.image_grid_thw = None
        self.num_images = 0
        self.pixel_values_videos = None
        self.total_patches_videos = 0
        self.video_grid_thw = None
        self.num_videos = 0
        self.patch_features = 0

        if all_pixel_values is not None:
            # Accept either a list of per-task tensors or one pre-concatenated
            # tensor of shape (total_patches, features).
            concat_pixel_values = (
                torch.cat(all_pixel_values, dim=0)
                if isinstance(all_pixel_values, list)
                else all_pixel_values
            )
            self.total_patches = concat_pixel_values.shape[0]
            self.patch_features = concat_pixel_values.shape[1]
            # Keep the tensor referenced so the raw pointer stays valid.
            self.flat_pixels = (
                concat_pixel_values.flatten().to(torch.bfloat16).contiguous()
            )
            self.pixel_values = self.flat_pixels.data_ptr()

        if all_image_grid_thw is not None:
            concat_grid_thw = (
                torch.cat(all_image_grid_thw, dim=0)
                if isinstance(all_image_grid_thw, list)
                else all_image_grid_thw
            )  # (total_images, 3)
            self.num_images = concat_grid_thw.shape[0]
            self.flat_grid = (
                concat_grid_thw.flatten().to(torch.int32).contiguous().tolist()
            )
            self.image_grid_thw = (c_uint * len(self.flat_grid))(*self.flat_grid)

        if all_pixel_values_videos is not None:
            # Mirror the image path: accept a list or a single tensor.
            concat_pixel_values_videos = (
                torch.cat(all_pixel_values_videos, dim=0)
                if isinstance(all_pixel_values_videos, list)
                else all_pixel_values_videos
            )  # (total_patches_videos, features)
            self.total_patches_videos = concat_pixel_values_videos.shape[0]
            self.patch_features_videos = concat_pixel_values_videos.shape[1]
            self.flat_pixels_videos = (
                concat_pixel_values_videos.flatten().to(torch.bfloat16).contiguous()
            )
            # BUG FIX: torch.Tensor has no ``.ctypes`` attribute (that is a
            # NumPy API and raised AttributeError here); use data_ptr() exactly
            # like the image path above.
            self.pixel_values_videos = self.flat_pixels_videos.data_ptr()

        if all_video_grid_thw is not None:
            concat_grid_thw_videos = (
                torch.cat(all_video_grid_thw, dim=0)
                if isinstance(all_video_grid_thw, list)
                else all_video_grid_thw
            )  # (total_videos, 3)
            self.num_videos = concat_grid_thw_videos.shape[0]
            flat_grid_videos = (
                concat_grid_thw_videos.flatten().to(torch.int32).contiguous().tolist()
            )
            self.video_grid_thw = (c_uint * len(flat_grid_videos))(*flat_grid_videos)

    def input_args(self):
        """Positional argument tuple for ``Qwen3vlModel.infer_batch``."""
        return (
            self.tokens,
            self.ntok,
            self.pixel_values,
            self.total_patches,
            self.image_grid_thw,
            self.num_images,
            self.pixel_values_videos,
            self.total_patches_videos,
            self.video_grid_thw,
            self.num_videos,
            self.patch_features,
            self.req_lens,
            self.nreq,
            self.req_pos,
            self.kv_caches,
            self.temperaturas,
            self.topks,
            self.topps,
        )

PanZezhong's avatar
PanZezhong committed
736

hejianlin's avatar
hejianlin committed
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
# TODO: handle the visual encoder's cache and the image/video inputs.
class Qwen3vlForCauslLM:
    """Qwen3-VL causal language model wrapper over the C inference library."""

    def __init__(
        self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None
    ):
        """Load config/processor/tokenizer and build the C model instance.

        Args:
            model_dir_path: HuggingFace-style model directory (config.json,
                safetensors, processor/tokenizer files).
            device: target ``DeviceType`` backend.
            ndev: number of devices to shard across (device ids 0..ndev-1).
            max_tokens: optional context-length override forwarded to the meta.

        Raises:
            ValueError: if ``config["model_type"]`` is not ``"qwen3_vl"``.
        """
        with open(os.path.join(model_dir_path, "config.json"), "r") as f:
            config = json.load(f)
            self.config = config
        eos_token_id = self.config["text_config"]["eos_token_id"]
        # Normalize to a list so membership tests work for both config forms.
        self.eos_token_id = (
            [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
        )

        if config["model_type"] == "qwen3_vl":
            self.meta = Qwen3vlMeta(config, max_tokens=max_tokens)
            self.processor = transformers.AutoProcessor.from_pretrained(model_dir_path)
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
        else:
            raise ValueError("Unsupported model architecture")

        print(f"Creating model on {ndev} devices...")
        load_start_time = time.time()
        dev_ids = (c_int * ndev)(*range(ndev))

        self.model_instance = Qwen3vlModel()
        weights = self.model_instance.create_weights(
            byref(self.meta), device, ndev, dev_ids, c_bool(True)
        )
        print("Loading weights...")
        # Load weights from host
        load_Qwen3vl_weights(self.meta, weights, model_dir_path, ndev)
        # Create model instance
        self.model_ptr = self.model_instance.create_model(
            byref(self.meta),
            weights,
        )
        load_end_time = time.time()
        print(f"Time used: {load_end_time - load_start_time:.3f}s")

    def max_context_len(self):
        """Maximum number of tokens the model context supports."""
        return self.meta.text_meta.max_tokens

    def create_kv_cache(self):
        """Allocate a fresh KV cache for one request."""
        return self.model_instance.create_cache(self.model_ptr)

    def drop_kv_cache(self, kv_cache):
        """Release a KV cache previously returned by ``create_kv_cache``."""
        self.model_instance.drop_cache(self.model_ptr, kv_cache)

    def batch_infer_one_round(
        self,
        tasks: List[InferTask],
        all_pixel_values=None,
        all_image_grid_thw=None,
        all_pixel_values_videos=None,
        all_video_grid_thw=None,
    ):
        """Run one decoding round for a batch; returns one token id per task."""
        output = (c_uint * len(tasks))()
        batch_inputs = Qwen3vlBatchedTask(
            tasks,
            all_pixel_values,
            all_image_grid_thw,
            all_pixel_values_videos,
            all_video_grid_thw,
        )
        self.model_instance.infer_batch(
            self.model_ptr,
            *(batch_inputs.input_args()),
            output,
        )
        return list(output)

    def generate(
        self, input_content, max_steps=0, topp_=1.0, topk_=1, temperature_=1.0
    ):
        """Generate a completion for a single chat-formatted user turn.

        Args:
            input_content: chat ``content`` payload (text and/or image/video
                entries) passed to the processor's chat template.
            max_steps: decoding-step budget; 0 means up to the context length.
            topp_, topk_, temperature_: sampling parameters.

        Returns:
            Tuple of (generated text, average milliseconds per decode step;
            -1 when fewer than two steps ran).
        """
        inputs = self.processor.apply_chat_template(
            conversation=[{"role": "user", "content": input_content}],
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        tokens = inputs["input_ids"][0].tolist()
        pixel_values = inputs["pixel_values"] if "pixel_values" in inputs else None
        image_grid_thw = (
            inputs["image_grid_thw"] if "image_grid_thw" in inputs else None
        )
        pixel_values_videos = (
            inputs["pixel_values_videos"] if "pixel_values_videos" in inputs else None
        )
        video_grid_thw = (
            inputs["video_grid_thw"] if "video_grid_thw" in inputs else None
        )

        infer_task = InferTask(
            0,
            tokens,
            self.max_context_len(),
            temperature_,
            topk_,
            topp_,
            self.eos_token_id,
        )
        infer_task.bind_kvcache(KVCache(self))

        steps = 0
        total_time = 0.0
        output_content = ""

        for step_i in range(max_steps if max_steps > 0 else self.max_context_len()):
            start_time = time.time()
            output_tokens = self.batch_infer_one_round(
                [infer_task],
                pixel_values,
                image_grid_thw,
                pixel_values_videos,
                video_grid_thw,
            )
            end_time = time.time()
            steps += 1
            output_str = self.tokenizer.decode(output_tokens[0])
            output_content += output_str
            print(output_str, end="", flush=True)

            # Visual inputs are consumed by the prefill step only.
            pixel_values = None
            image_grid_thw = None
            pixel_values_videos = None
            video_grid_thw = None

            if output_tokens[0] in self.eos_token_id:
                break
            infer_task.next(output_tokens[0])

            if step_i > 0:
                # Exclude the first (prefill) step from the timing average.
                total_time += end_time - start_time

        print("\n")
        # BUG FIX: total_time covers steps-1 measured rounds (the prefill
        # round is excluded above), so divide by steps-1, not steps.
        avg_time = total_time * 1000 / (steps - 1) if steps > 1 else -1
        print(f"Time per step: {avg_time:.3f}ms")

        infer_task._kv_cache.drop(self)
        return output_content, avg_time

    def destroy_model_instance(self):
        """Destroy the underlying C model instance, freeing its resources."""
        self.model_instance.destroy_model(self.model_ptr)
        print("Model destroyed")


def test():
    """CLI smoke test: parse args, load a model and run one generation."""
    # The original usage string omitted --iluvatar even though the flag was
    # accepted; a dispatch dict keeps flags and usage text in one place.
    usage = (
        "Usage: python qwen3vl.py [--cpu | --nvidia | --cambricon | --ascend"
        " | --metax | --moore | --iluvatar] <path/to/model_dir>"
        " [n_device] [img_url]"
    )
    device_flags = {
        "--cpu": DeviceType.DEVICE_TYPE_CPU,
        "--nvidia": DeviceType.DEVICE_TYPE_NVIDIA,
        "--cambricon": DeviceType.DEVICE_TYPE_CAMBRICON,
        "--ascend": DeviceType.DEVICE_TYPE_ASCEND,
        "--metax": DeviceType.DEVICE_TYPE_METAX,
        "--moore": DeviceType.DEVICE_TYPE_MOORE,
        "--iluvatar": DeviceType.DEVICE_TYPE_ILUVATAR,
    }
    if len(sys.argv) < 3 or sys.argv[1] not in device_flags:
        print(usage)
        sys.exit(1)

    device_type = device_flags[sys.argv[1]]
    model_path = sys.argv[2]
    ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
    img_url = sys.argv[4] if len(sys.argv) > 4 else None

    model = Qwen3vlForCauslLM(model_path, device_type, ndev, max_tokens=1024)
    # Image prompt when a URL is supplied, otherwise a text-only prompt.
    input_content = (
        [
            {"type": "text", "text": "Describe this image."},
            {"type": "image", "url": img_url},
        ]
        if img_url is not None
        else [{"type": "text", "text": "山东最高的山是?"}]
    )
    model.generate(input_content)
    model.destroy_model_instance()


if __name__ == "__main__":
    test()