"magic_pdf/para/para_split_v2.py" did not exist on "83753cbd774385535a606bc05d0edcdf12d9058a"
llava_onevision_server.py 7.98 KB
Newer Older
1
2
3
"""
Usage:

1. Launch the server:

python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8

2. Run this client against it:

python3 llava_onevision_server.py
"""

9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import base64
import io
import os
import sys
import time

import numpy as np
import openai
import requests
from decord import VideoReader, cpu
from PIL import Image

# pip install httpx==0.23.3
# pip install decord
# pip install protobuf==3.20.0


def download_video(url, cache_dir, timeout=60):
    """Download ``url`` to ``<cache_dir>/jobs.mp4`` and return that path.

    The cache directory is created if missing, and any existing ``jobs.mp4``
    in it is overwritten. The response body is streamed to disk in chunks
    instead of being held in memory in full.

    Args:
        url: HTTP(S) URL of the video to fetch.
        cache_dir: Directory to save the file into.
        timeout: Seconds passed to ``requests.get`` as its connect/read
            timeout. Defaults to 60.

    Returns:
        The local file path as a string.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    file_path = os.path.join(cache_dir, "jobs.mp4")
    os.makedirs(cache_dir, exist_ok=True)

    # requests never times out unless told to; without a timeout a stalled
    # connection would block this script indefinitely.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Check the status before opening the file so an HTTP error leaves
        # no empty/partial file behind.
        response.raise_for_status()
        with open(file_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)

    print(f"File downloaded and saved to: {file_path}")
    return file_path


def create_openai_client(base_url):
    """Return an ``openai.Client`` that talks to the server at ``base_url``.

    A placeholder API key (``"EMPTY"``) is supplied; presumably the local
    server does not validate it.
    """
    client = openai.Client(base_url=base_url, api_key="EMPTY")
    return client


def image_stream_request_test(client):
    """Stream a single-image description request and echo tokens to stdout.

    Sends the SGLang logo image plus a text prompt to the ``default`` model
    with ``stream=True`` (``temperature=0.7``, up to 1024 tokens). Each
    non-``None`` delta is written to stdout as soon as it arrives, and a
    separator line is printed at the end.

    Args:
        client: An OpenAI-compatible client exposing
            ``chat.completions.create``.
    """
    print("----------------------Image Stream Request Test----------------------")

    image_part = {
        "type": "image_url",
        "image_url": {
            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
        },
    }
    text_part = {
        "type": "text",
        "text": "Please describe this image. Please list the benchmarks and the models.",
    }
    chunks = client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": [image_part, text_part]}],
        temperature=0.7,
        max_tokens=1024,
        stream=True,
    )

    for chunk in chunks:
        delta = chunk.choices[0].delta.content
        if delta is None:
            # Role-only / terminal chunks carry no text.
            continue
        sys.stdout.write(delta)
        sys.stdout.flush()

    print("-" * 30)


81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def multi_image_stream_request_test(client):
    """Stream a two-image description request and echo tokens to stdout.

    Sends the SGLang logo and the ``test/lang/example_image.png`` sample
    image, each tagged with ``"modalities": "multi-images"``, followed by a
    text prompt (``temperature=0.7``, up to 1024 tokens, ``stream=True``).
    Each non-``None`` delta is written to stdout as it arrives, followed by a
    separator line.

    Args:
        client: An OpenAI-compatible client exposing
            ``chat.completions.create``.
    """
    print(
        "----------------------Multi-Images Stream Request Test----------------------"
    )
    stream_request = client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
                        },
                        # Presumably tells the server these images form one
                        # multi-image input (vs. video frames) — confirm in SGLang.
                        "modalities": "multi-images",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
                        },
                        "modalities": "multi-images",
                    },
                    {
                        "type": "text",
                        "text": "I have shown you two images. Please describe the two images to me.",
                    },
                ],
            },
        ],
        temperature=0.7,
        max_tokens=1024,
        stream=True,
    )

    for chunk in stream_request:
        content = chunk.choices[0].delta.content
        if content is not None:
            sys.stdout.write(content)
            sys.stdout.flush()

    print("-" * 30)


128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def video_stream_request_test(client, video_path):
    """Stream a description of sampled video frames and echo it to stdout.

    Frames are prepared by ``prepare_video_messages(video_path)``; generation
    uses ``temperature=0`` and up to 1024 tokens with ``stream=True``. The
    streamed text is framed by separator lines before and after.

    Args:
        client: An OpenAI-compatible client exposing
            ``chat.completions.create``.
        video_path: Local path of the video file to sample.
    """
    print("------------------------Video Stream Request Test----------------------")
    video_messages = prepare_video_messages(video_path)

    completion_stream = client.chat.completions.create(
        model="default",
        messages=video_messages,
        temperature=0,
        max_tokens=1024,
        stream=True,
    )
    separator = "-" * 30
    print(separator)

    for piece in completion_stream:
        text = piece.choices[0].delta.content
        if text is None:
            continue
        sys.stdout.write(text)
        sys.stdout.flush()
    print(separator)


def image_speed_test(client):
    """Time one non-streaming single-image request and report throughput.

    Sends the SGLang logo image with a description prompt (``temperature=0``,
    up to 1024 tokens), prints the reply, then prints token counts and
    tokens-per-second via ``print_speed_test_results``.

    Args:
        client: An OpenAI-compatible client exposing
            ``chat.completions.create``.
    """
    print("----------------------Image Speed Test----------------------")
    # Only the request round-trip is timed; printing happens afterwards.
    start_time = time.perf_counter()
    request = client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
                        },
                    },
                    {
                        "type": "text",
                        "text": "Please describe this image. Please list the benchmarks and the models.",
                    },
                ],
            },
        ],
        temperature=0,
        max_tokens=1024,
    )
    end_time = time.perf_counter()

    response = request.choices[0].message.content
    print(response)
    print("-" * 30)
    print_speed_test_results(request, start_time, end_time)


def video_speed_test(client, video_path):
    """Time one non-streaming video-description request and report throughput.

    Frame sampling/encoding (``prepare_video_messages``) runs before the timer
    starts, so only the request round-trip is measured. The reply is printed,
    followed by token counts and tokens-per-second via
    ``print_speed_test_results``.

    Args:
        client: An OpenAI-compatible client exposing
            ``chat.completions.create``.
        video_path: Local path of the video file to sample.
    """
    print("------------------------Video Speed Test------------------------")
    messages = prepare_video_messages(video_path)

    start_time = time.perf_counter()
    video_request = client.chat.completions.create(
        model="default",
        messages=messages,
        temperature=0,
        max_tokens=1024,
    )
    end_time = time.perf_counter()

    video_response = video_request.choices[0].message.content
    print(video_response)
    print("-" * 30)
    print_speed_test_results(video_request, start_time, end_time)


def prepare_video_messages(video_path, max_frames_num=32):
    """Build a single-turn chat message presenting a video as sampled frames.

    ``max_frames_num`` frame indices are spread evenly (``np.linspace``) from
    the first to the last frame. Each selected frame is JPEG-encoded,
    base64-encoded and attached as an ``image_url`` data-URL content part
    tagged ``"modalities": "video"``. A text prompt asking for a detailed
    description is appended last.

    Args:
        video_path: Path of a video file readable by ``decord.VideoReader``.
        max_frames_num: Number of frames to sample. Defaults to 32. When the
            video has fewer frames than this, some indices repeat.

    Returns:
        A list holding one ``{"role": "user", "content": [...]}`` dict.

    Raises:
        ValueError: If ``max_frames_num`` is less than 1 or the video has no
            frames.
    """
    if max_frames_num < 1:
        raise ValueError(f"max_frames_num must be >= 1, got {max_frames_num}")

    vr = VideoReader(video_path, ctx=cpu(0))
    total_frame_num = len(vr)
    if total_frame_num == 0:
        # Otherwise linspace(0, -1, ...) yields a -1 index and decord fails
        # with a far less obvious error.
        raise ValueError(f"Video has no frames: {video_path}")

    frame_idx = np.linspace(
        0, total_frame_num - 1, max_frames_num, dtype=int
    ).tolist()
    frames = vr.get_batch(frame_idx).asnumpy()

    content = []
    for frame in frames:
        buff = io.BytesIO()
        Image.fromarray(frame).save(buff, format="JPEG")
        base64_frame = base64.b64encode(buff.getvalue()).decode("utf-8")
        content.append(
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{base64_frame}"},
                "modalities": "video",
            }
        )

    content.append({"type": "text", "text": "Please describe the video in detail."})
    return [{"role": "user", "content": content}]


def print_speed_test_results(request, start_time, end_time):
    """Print token counts and tokens-per-second for a finished request.

    Args:
        request: A completion response exposing ``usage.total_tokens``,
            ``usage.completion_tokens`` and ``usage.prompt_tokens``.
        start_time: Timestamp taken before the request was sent (the callers
            in this file use ``time.perf_counter()``).
        end_time: Timestamp taken after the response arrived, same clock.
    """
    usage = request.usage
    total_tokens = usage.total_tokens
    completion_tokens = usage.completion_tokens
    prompt_tokens = usage.prompt_tokens
    elapsed = end_time - start_time

    counts = (
        ("Total tokens", total_tokens),
        ("Completion tokens", completion_tokens),
        ("Prompt tokens", prompt_tokens),
    )
    for label, count in counts:
        print(f"{label}: {count}")

    print(f"Time taken: {elapsed} seconds")

    rates = (
        ("Token per second", total_tokens),
        ("Completion token per second", completion_tokens),
        ("Prompt token per second", prompt_tokens),
    )
    for label, count in rates:
        print(f"{label}: {count / elapsed}")


def main():
    """Download the sample video, then run every stream and speed test in turn.

    Talks to an OpenAI-compatible server at ``http://127.0.0.1:30000/v1``.
    The video is saved under ``~/.cache``.
    """
    url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
    cache_dir = os.path.expanduser("~/.cache")
    video_path = download_video(url, cache_dir)

    client = create_openai_client("http://127.0.0.1:30000/v1")

    image_stream_request_test(client)
    multi_image_stream_request_test(client)
    video_stream_request_test(client, video_path)
    image_speed_test(client)
    video_speed_test(client, video_path)


# Run the full test sequence only when executed as a script, not on import.
if __name__ == "__main__":
    main()