utils.py 11.5 KB
Newer Older
1
2
"""Common utilities."""

Lianmin Zheng's avatar
Lianmin Zheng committed
3
import base64
4
import logging
Lianmin Zheng's avatar
Lianmin Zheng committed
5
6
7
8
import os
import random
import socket
import time
9
from importlib.metadata import PackageNotFoundError, version
Lianmin Zheng's avatar
Lianmin Zheng committed
10
from io import BytesIO
Lianmin Zheng's avatar
Lianmin Zheng committed
11
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
12
13

import numpy as np
Lianmin Zheng's avatar
Lianmin Zheng committed
14
import pydantic
Lianmin Zheng's avatar
Lianmin Zheng committed
15
16
import requests
import torch
17
from fastapi.responses import JSONResponse
18
from packaging import version as pkg_version
Lianmin Zheng's avatar
Lianmin Zheng committed
19
20
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
Lianmin Zheng's avatar
Lianmin Zheng committed
21

22
23
# Module-level logger for warnings emitted by the helpers in this file.
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
24

Liangsheng Yin's avatar
Liangsheng Yin committed
25
26
# Global switch for the mark_start / mark_end timing helpers; both are no-ops
# until enable_show_time_cost() sets this to True.
show_time_cost = False
# Accumulated section timings: maps the name passed to mark_start / mark_end
# to its TimeInfo instance.
time_infos = {}
Lianmin Zheng's avatar
Lianmin Zheng committed
27
28


Liangsheng Yin's avatar
Liangsheng Yin committed
29
30
31
32
def enable_show_time_cost():
    """Switch on the timing printouts produced via mark_start / mark_end."""
    global show_time_cost
    show_time_cost = True

Lianmin Zheng's avatar
Lianmin Zheng committed
33

Liangsheng Yin's avatar
Liangsheng Yin committed
34
35
36
37
38
39
class TimeInfo:
    """Accumulated wall-clock time for one named section, with print throttling.

    ``acc_time`` is adjusted by ``mark_start`` (subtracts the start time) and
    ``mark_end`` (adds the end time). ``check`` reports whether more than
    ``interval`` has accumulated since the last time it returned True, so the
    section is printed only periodically.
    """

    def __init__(self, name, interval=0.1, color=0, indent=0):
        # Presentation settings used by pretty_print.
        self.name = name
        self.interval = interval
        self.color = color  # ANSI SGR parameter placed in the "\x1b[<color>m" prefix
        self.indent = indent  # printed as 2 * indent dashes

        # Running total, and the total observed at the last successful check().
        self.acc_time = 0
        self.last_acc_time = 0

    def check(self):
        """Return True (recording the current total) once more than
        ``interval`` has accumulated since the previous True; else False."""
        if not (self.acc_time - self.last_acc_time > self.interval):
            return False
        self.last_acc_time = self.acc_time
        return True

    def pretty_print(self):
        """Print ``<dashes><name>: <acc_time>s`` wrapped in an ANSI color code."""
        dashes = "-" * self.indent * 2
        line = f"\x1b[{self.color}m{dashes}{self.name}: {self.acc_time:.3f}s\x1b[0m"
        print(line)
Lianmin Zheng's avatar
Lianmin Zheng committed
54
55


Liangsheng Yin's avatar
Liangsheng Yin committed
56
57
58
59
def mark_start(name, interval=0.1, color=0, indent=0):
    """Start (or resume) timing the section ``name``.

    No-op unless enable_show_time_cost() was called. The first call for a name
    creates its TimeInfo from ``interval``/``color``/``indent``; later calls
    reuse the existing entry and ignore those arguments. The current time is
    subtracted from ``acc_time`` so that mark_end, which adds the end time,
    leaves the elapsed duration accumulated.
    """
    if not show_time_cost:
        return
    # Wait for queued CUDA work so it is charged to the section that launched it.
    torch.cuda.synchronize()

    info = time_infos.get(name)
    if info is None:
        info = TimeInfo(name, interval, color, indent)
        time_infos[name] = info
    info.acc_time -= time.time()
Lianmin Zheng's avatar
Lianmin Zheng committed
64
65


Liangsheng Yin's avatar
Liangsheng Yin committed
66
67
68
69
def mark_end(name):
    """Stop timing the section ``name`` and print it if its interval elapsed.

    No-op unless enable_show_time_cost() was called. Raises KeyError if
    ``name`` was never passed to mark_start.
    """
    if not show_time_cost:
        return
    # Wait for queued CUDA work so it is included in this section's time.
    torch.cuda.synchronize()

    info = time_infos[name]
    info.acc_time += time.time()
    if info.check():
        info.pretty_print()
Lianmin Zheng's avatar
Lianmin Zheng committed
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94


def calculate_time(show=False, min_cost_ms=0.0):
    """Decorator factory that times each call of the decorated function.

    The CUDA device is synchronized before and after the call (when CUDA is
    available) so asynchronous GPU work launched by the function is included.

    Args:
        show: if True, print the elapsed time of each call.
        min_cost_ms: only calls taking strictly longer than this many
            milliseconds are printed.

    Returns:
        A decorator. The wrapped function returns the original's result and
        keeps its ``__name__``/``__doc__`` (via ``functools.wraps``).
    """
    import functools

    def _sync():
        # torch.cuda.synchronize() raises on hosts/builds without CUDA.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def wrapper(func):
        @functools.wraps(func)
        def inner_func(*args, **kwargs):
            _sync()
            if show:
                # perf_counter is monotonic, unlike time.time().
                start_time = time.perf_counter()
            result = func(*args, **kwargs)
            _sync()
            if show:
                cost_time = (time.perf_counter() - start_time) * 1000
                if cost_time > min_cost_ms:
                    print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result

        return inner_func

    return wrapper


Lianmin Zheng's avatar
Lianmin Zheng committed
95
96
97
98
99
100
101
102
103
104
105
106
107
108
def get_available_gpu_memory(gpu_id, distributed=True):
    """
    Get available memory for cuda:gpu_id device.
    When distributed is True, the available memory is the minimum available memory of all GPUs.

    Args:
        gpu_id: index of the CUDA device to query; must be < torch.cuda.device_count().
        distributed: if True, all-reduce the value with MIN across the
            torch.distributed process group (which must already be initialized).

    Returns:
        Free memory in GiB (bytes / 2**30) as a float.
    """
    num_gpus = torch.cuda.device_count()
    assert gpu_id < num_gpus

    current_device = torch.cuda.current_device()
    if current_device != gpu_id:
        # Routed through the module logger (was print, which also emitted a
        # stray double space between the two message parts).
        logger.warning(
            f"WARNING: current device is not {gpu_id}, but {current_device}, "
            "which may cause useless memory allocation for torch CUDA context."
        )

    # Release blocks cached by torch's allocator so they are reported as free.
    torch.cuda.empty_cache()
    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

    if distributed:
        tensor = torch.tensor(
            free_gpu_memory,
            dtype=torch.float32,
            device=torch.device("cuda", gpu_id),
        )
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
        free_gpu_memory = tensor.item()

    return free_gpu_memory / (1 << 30)


Lianmin Zheng's avatar
Lianmin Zheng committed
122
123
124
125
126
127
128
129
def set_random_seed(seed: int) -> None:
    """Seed Python's ``random`` module and torch with ``seed``.

    When CUDA is available, the generators of all CUDA devices are seeded too.
    """
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


130
def is_port_available(port):
    """Return True if a TCP socket can bind to and listen on ``port``.

    The probe socket sets SO_REUSEADDR, binds on all interfaces, listens, and
    is closed on exit; any OSError (``socket.error``) means "unavailable".
    The port may still be taken by another process after this returns.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            probe.bind(("", port))
            probe.listen(1)
        except OSError:
            return False
        return True


Lianmin Zheng's avatar
Lianmin Zheng committed
141
def allocate_init_ports(
    port: Optional[int] = None,
    additional_ports: Optional[List[int]] = None,
    tp_size: int = 1,
):
    """Choose a main port plus enough auxiliary ports for the server.

    Requested ports (``port`` first, then ``additional_ports``) are kept in
    their given order if free. Remaining slots are filled by probing upward
    from the last kept port (or from 10000 if none was usable) until
    ``5 + tp_size`` ports are collected in total.

    Returns:
        ``(main_port, other_ports)``: the first collected port and a list of
        the rest. A warning is logged when ``port`` was given but the main port
        differs from it.
    """
    requested = [port] + list(additional_ports or [])
    # dict.fromkeys de-duplicates while preserving order, so a free `port`
    # stays at index 0 (a set does not keep insertion order). None means "no
    # port requested" and is skipped: passing it to bind() raises TypeError,
    # which is_port_available does not catch.
    ret_ports = [
        p for p in dict.fromkeys(requested) if p is not None and is_port_available(p)
    ]
    cur_port = ret_ports[-1] + 1 if ret_ports else 10000

    while len(ret_ports) < 5 + tp_size:
        if cur_port not in ret_ports and is_port_available(cur_port):
            ret_ports.append(cur_port)
        cur_port += 1

    if port and ret_ports[0] != port:
        logger.warning(
            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
        )

    return ret_ports[0], ret_ports[1:]
165

Lianmin Zheng's avatar
Lianmin Zheng committed
166

Lianmin Zheng's avatar
Lianmin Zheng committed
167
def get_int_token_logit_bias(tokenizer, vocab_size):
    """Build a logit-bias vector that suppresses all tokens except integers.

    A token keeps bias 0 if its decoded, whitespace-stripped text is all
    digits or empty, or if it is ``tokenizer.eos_token_id``. Every other
    token gets -1e5.

    Note: the ``vocab_size`` argument is overridden by ``tokenizer.vocab_size``,
    so the result has length ``tokenizer.vocab_size``.

    Returns:
        A float32 ``np.ndarray``.
    """
    # NOTE: caller's vocab_size is replaced. Original note:
    # "a bug when model's vocab size > tokenizer.vocab_size".
    vocab_size = tokenizer.vocab_size

    bias = np.zeros(vocab_size, dtype=np.float32)
    for token_id in range(vocab_size):
        piece = tokenizer.decode([token_id]).strip()
        keep = piece.isdigit() or len(piece) == 0 or token_id == tokenizer.eos_token_id
        if keep:
            continue
        bias[token_id] = -1e5

    return bias


def wrap_kernel_launcher(kernel):
    """A faster launcher for triton kernels.

    Takes the first compiled variant cached for this process's rank
    (``kernel.cache[rank]``; rank 0 when torch.distributed is not initialized).
    Returns ``ret_func(grid, num_warps, *args)``, which calls that variant's
    low-level launcher directly.

    The launcher's positional signature differs across Triton versions: one
    form takes four extra ``1`` arguments (presumably cluster dimensions).
    The remembered form is tried first. On ``TypeError`` the other form is
    tried once and, if it works, kept for later calls. If both forms fail,
    the second ``TypeError`` propagates instead of retrying forever.
    """
    import torch.distributed as dist

    if dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    kernels = kernel.cache[rank].values()
    kernel = next(iter(kernels))

    # Different trition versions use different low-level names
    if hasattr(kernel, "cu_function"):
        kfunction = kernel.cu_function
    else:
        kfunction = kernel.function

    if hasattr(kernel, "c_wrapper"):
        run = kernel.c_wrapper
    else:
        run = kernel.run

    add_cluster_dim = True

    def _launch(with_cluster_dim, grid, num_warps, args):
        # Invoke `run` with or without the four extra (cluster-dim) arguments.
        if with_cluster_dim:
            run(
                grid[0],
                grid[1],
                grid[2],
                num_warps,
                1,
                1,
                1,
                1,
                kernel.shared,
                0,
                kfunction,
                None,
                None,
                kernel,
                *args,
            )
        else:
            run(
                grid[0],
                grid[1],
                grid[2],
                num_warps,
                kernel.shared,
                0,
                kfunction,
                None,
                None,
                kernel,
                *args,
            )

    def ret_func(grid, num_warps, *args):
        nonlocal add_cluster_dim

        try:
            _launch(add_cluster_dim, grid, num_warps, args)
        except TypeError:
            # Likely the wrong signature for this Triton version: switch forms
            # and retry exactly once, so a TypeError with another cause
            # surfaces rather than recursing without bound.
            add_cluster_dim = not add_cluster_dim
            _launch(add_cluster_dim, grid, num_warps, args)

    return ret_func


def is_multimodal_model(model):
    """Heuristically decide whether ``model`` is a multimodal (vision) model.

    Accepts a model name/path string or a ``ModelConfig`` (whose ``path`` is
    used). The decision is a case-insensitive substring match against known
    multimodal model families.

    Raises:
        ValueError: if ``model`` is neither a ``str`` nor a ``ModelConfig``.
    """
    from sglang.srt.model_config import ModelConfig

    if isinstance(model, str):
        model_name = model
    elif isinstance(model, ModelConfig):
        model_name = model.path
    else:
        raise ValueError("unrecognized type")

    lowered = model_name.lower()
    # "llava-next" is already implied by "llava"; listed to mirror the known families.
    markers = ("llava", "yi-vl", "llava-next")
    return any(marker in lowered for marker in markers)


def decode_video_base64(video_base64):
    """Decode a base64 blob of back-to-back encoded images into stacked frames.

    The decoded payload is split at every occurrence of the frame format's
    file signature (PNG by default). Each frame runs until the next signature,
    and the last runs to the end. A signature byte sequence that happens to
    occur inside image data would therefore cause a spurious split.

    Args:
        video_base64: base64-encoded ``str``/``bytes`` of the concatenated images.

    Returns:
        ``(frames, size)``: ``np.stack`` of the decoded frames along a new
        axis 0, and the PIL ``size`` of the last frame. Returns
        ``(np.array([]), (0, 0))`` when no signature is found.
    """
    # Decode the base64 string
    video_bytes = base64.b64decode(video_base64)

    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))

    assert frame_format in [
        "PNG",
        "JPEG",
    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"

    if frame_format == "PNG":
        signature = b"\x89PNG\r\n\x1a\n"  # 89 50 4E 47 0D 0A 1A 0A
    else:
        signature = b"\xff\xd8"  # JPEG start-of-image marker

    # bytes.find scans in C. Resuming after each hit yields the same starts as
    # the byte-by-byte scan that skips past each matched signature.
    img_starts = []
    pos = video_bytes.find(signature)
    while pos != -1:
        img_starts.append(pos)
        pos = video_bytes.find(signature, pos + len(signature))

    if not img_starts:
        # Return an empty array and size tuple if no frames were found
        return np.array([]), (0, 0)

    # Deferred: PIL is only needed once there is something to decode.
    from PIL import Image

    # Each image ends where the next begins; the last ends at the end of the data.
    ends = img_starts[1:] + [len(video_bytes)]
    frames = []
    for start, end in zip(img_starts, ends):
        img = Image.open(BytesIO(video_bytes[start:end]))
        frames.append(np.array(img))

    return np.stack(frames, axis=0), img.size
Lianmin Zheng's avatar
Lianmin Zheng committed
340
341
342
343
344


def load_image(image_file):
    """Load an image (or a video's frames) from a URL, data URL, path, or base64.

    Sources are recognised in this order:
      * ``http://`` / ``https://`` URL: fetched with ``requests``, timeout from
        the ``REQUEST_TIMEOUT`` environment variable (default "3");
      * ``data:`` URL: the payload after the first comma is base64-decoded;
      * ``video:`` prefix: the remainder goes to ``decode_video_base64``;
      * a string ending (case-insensitively) in png/jpg/jpeg/webp/gif: opened
        as a file path;
      * anything else: treated as raw base64-encoded image bytes.

    Returns:
        ``(image, image_size)``. ``image_size`` is set only for ``video:``
        input (the size from ``decode_video_base64``); otherwise it is None.
    """
    from PIL import Image

    image = image_size = None

    if image_file.startswith(("http://", "https://")):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, timeout=timeout)
        image = Image.open(BytesIO(response.content))
    # The prefix checks must come before the extension check: a base64 payload
    # inside a data:/video: string can itself end in "png", "gif", etc.
    elif image_file.startswith("data:"):
        image_file = image_file.split(",")[1]
        image = Image.open(BytesIO(base64.b64decode(image_file)))
    elif image_file.startswith("video:"):
        image_file = image_file.replace("video:", "")
        image, image_size = decode_video_base64(image_file)
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    else:
        image = Image.open(BytesIO(base64.b64decode(image_file)))

    return image, image_size
363
364
365
366
367
368
369
370
371
372
373


def assert_pkg_version(pkg: str, min_version: str):
    """Ensure package ``pkg`` is installed at version ``min_version`` or newer.

    Raises:
        Exception: if ``pkg`` is not installed (per ``importlib.metadata``),
            or its installed version compares lower than ``min_version``
            under ``packaging.version`` ordering.
    """
    try:
        installed_version = version(pkg)
    except PackageNotFoundError:
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed"
        )

    too_old = pkg_version.parse(installed_version) < pkg_version.parse(min_version)
    if too_old:
        raise Exception(
            f"{pkg} is installed with version {installed_version} which "
            f"is less than the minimum required version {min_version}"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395


# Request header that APIKeyValidatorMiddleware reads the client's API key from.
API_KEY_HEADER_NAME = "X-API-Key"


class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """Reject requests whose API-key header does not match the configured key.

    Requests with a missing/empty header or a different value receive a 403
    JSON response and are not passed on to ``call_next``.
    """

    def __init__(self, app, api_key: str):
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request, call_next):
        import hmac

        # extract API key from the request headers
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
        # Constant-time comparison so response timing does not reveal how much
        # of a guessed key is correct. Compare UTF-8 bytes because
        # compare_digest raises TypeError for non-ASCII str arguments.
        if not api_key_header or not hmac.compare_digest(
            api_key_header.encode("utf-8"), self.api_key.encode("utf-8")
        ):
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )
        response = await call_next(request)
        return response