"""
Usage:
python offline_batch_inference_vlm.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template=qwen2-vl
"""

import argparse
import dataclasses

import sglang as sgl
from sglang.srt.openai_api.adapter import v1_chat_generate_request
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs


def main(server_args: ServerArgs):
    # Create the VLM engine; ServerArgs fields are forwarded as keyword arguments.
    vlm = sgl.Engine(**dataclasses.asdict(server_args))

    # Prepare a multimodal chat message: a text question plus an image URL.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What’s in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true",
                    },
                },
            ],
        }
    ]
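    # Wrap the messages in an OpenAI-compatible chat completion request
    # together with the desired sampling settings.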
    chat_request = ChatCompletionRequest(
        messages=messages,
        model=server_args.model_path,
        temperature=0.8,
        top_p=0.95,
    )
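    # Convert the OpenAI-style chat request into the engine's raw generate
    # inputs (tokenized prompt, image data, and sampling parameters).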
    gen_request, _ = v1_chat_generate_request(
        [chat_request],
        vlm.tokenizer_manager,
    )

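    # Run generation on the prepared inputs (offline, no HTTP server involved).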
    outputs = vlm.generate(
        input_ids=gen_request.input_ids,
        image_data=gen_request.image_data,
        sampling_params=gen_request.sampling_params,
    )

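    # For a single request the engine returns a dict; the generated completion
    # is available under the "text" key.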
    print("===============================")
    print(f"Prompt: {messages[0]['content'][0]['text']}")
    print(f"Generated text: {outputs['text']}")


# The __main__ guard is required because sgl.Engine uses the "spawn" method to create
# subprocesses. Spawn re-imports this file in each child process, so without the guard
# every child would create another Engine and keep spawning processes in an infinite loop.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    main(server_args)