utils.py 13.6 KB
Newer Older
1
2
"""Common utilities."""

Lianmin Zheng's avatar
Lianmin Zheng committed
3
import base64
4
import multiprocessing
5
import logging
Lianmin Zheng's avatar
Lianmin Zheng committed
6
7
8
9
import os
import random
import socket
import time
10
from importlib.metadata import PackageNotFoundError, version
Lianmin Zheng's avatar
Lianmin Zheng committed
11
from io import BytesIO
Lianmin Zheng's avatar
Lianmin Zheng committed
12
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
13
14
15

import numpy as np
import requests
16
import rpyc
Lianmin Zheng's avatar
Lianmin Zheng committed
17
import torch
18
import triton
19
from rpyc.utils.server import ThreadedServer
20
from fastapi.responses import JSONResponse
21
from packaging import version as pkg_version
Lianmin Zheng's avatar
Lianmin Zheng committed
22
from starlette.middleware.base import BaseHTTPMiddleware
23

Lianmin Zheng's avatar
Lianmin Zheng committed
24

25
26
# Logger for this module (used below for warnings).
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
27

Liangsheng Yin's avatar
Liangsheng Yin committed
28
29
# Global switch for the mark_start/mark_end timers; set to True by enable_show_time_cost().
show_time_cost: bool = False
# Region name -> TimeInfo record; entries are created lazily by mark_start().
time_infos: dict = {}
Lianmin Zheng's avatar
Lianmin Zheng committed
30
31


Liangsheng Yin's avatar
Liangsheng Yin committed
32
33
34
35
def enable_show_time_cost():
    """Switch on the timing reports produced by mark_start/mark_end.

    There is no matching "disable": once enabled, the module-level flag
    stays True for the rest of the process.
    """
    global show_time_cost

    show_time_cost = True

Lianmin Zheng's avatar
Lianmin Zheng committed
36

Liangsheng Yin's avatar
Liangsheng Yin committed
37
38
39
40
41
42
class TimeInfo:
    """Accumulated timing record for one named code region.

    ``acc_time`` is the running total maintained by mark_start/mark_end.
    ``check`` reports whether that total has advanced by more than
    ``interval`` since the last time it reported True.
    """

    def __init__(self, name, interval=0.1, color=0, indent=0):
        # Display / reporting settings.
        self.name = name
        self.interval = interval
        self.color = color  # ANSI SGR code used by pretty_print
        self.indent = indent  # pretty_print emits 2 * indent dashes

        # Running total, and its value at the last check() that returned True.
        self.acc_time = 0
        self.last_acc_time = 0

    def check(self):
        """Return True, and record the current total, when acc_time has grown
        by more than ``interval`` since the previous True result."""
        advanced = self.acc_time - self.last_acc_time
        if not advanced > self.interval:
            return False
        self.last_acc_time = self.acc_time
        return True

    def pretty_print(self):
        """Print ``name: <acc_time>s`` in color, indented by dashes, then reset the color."""
        lead = f"\x1b[{self.color}m" + "-" * (self.indent * 2)
        print(f"{lead}{self.name}: {self.acc_time:.3f}s\x1b[0m")
Lianmin Zheng's avatar
Lianmin Zheng committed
57
58


Liangsheng Yin's avatar
Liangsheng Yin committed
59
60
61
62
def mark_start(name, interval=0.1, color=0, indent=0):
    """Start (or resume) timing the region called ``name``.

    Does nothing unless ``show_time_cost`` is enabled. The first call for a
    given name creates its TimeInfo with the supplied display settings;
    later calls reuse that record and ignore the settings.
    """
    if not show_time_cost:
        return

    # Wait for queued CUDA work so it is not attributed to this region.
    torch.cuda.synchronize()

    record = time_infos.get(name)
    if record is None:
        record = TimeInfo(name, interval, color, indent)
        time_infos[name] = record

    # Subtract the start timestamp now; mark_end adds the end timestamp, so
    # the net change to acc_time is the elapsed time.
    record.acc_time -= time.time()
Lianmin Zheng's avatar
Lianmin Zheng committed
67
68


Liangsheng Yin's avatar
Liangsheng Yin committed
69
70
71
72
def mark_end(name):
    """Stop timing the region ``name`` and print its total when due.

    Does nothing unless ``show_time_cost`` is enabled. Expects a prior
    mark_start(name); otherwise the dict lookup raises KeyError.
    """
    if not show_time_cost:
        return

    # Make sure the region's CUDA work has finished before reading the clock.
    torch.cuda.synchronize()

    record = time_infos[name]
    record.acc_time += time.time()
    if record.check():
        record.pretty_print()
Lianmin Zheng's avatar
Lianmin Zheng committed
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97


def calculate_time(show=False, min_cost_ms=0.0):
    """Decorator factory that optionally reports a function's run time.

    Args:
        show: when True, each call is timed and, if it took longer than
            ``min_cost_ms``, a line is printed to stdout. When False the
            wrapped function is called directly with no synchronization or
            timing overhead.
        min_cost_ms: reporting threshold in milliseconds.

    Returns:
        A decorator. The wrapped function returns whatever the original
        returns and keeps its ``__name__``/``__doc__`` (via functools.wraps).
    """
    from functools import wraps

    def _sync():
        # CUDA launches are asynchronous; synchronize so the measured interval
        # covers the GPU work issued by the function. Skipped on CPU-only hosts.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def wrapper(func):
        @wraps(func)
        def inner_func(*args, **kwargs):
            if not show:
                # Nothing is measured, so synchronizing would only stall the GPU.
                return func(*args, **kwargs)

            _sync()
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            _sync()
            cost_time = (time.perf_counter() - start_time) * 1000
            if cost_time > min_cost_ms:
                print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result

        return inner_func

    return wrapper


98
def get_available_gpu_memory(gpu_id, distributed=False):
    """
    Get available memory for cuda:gpu_id device.
    When distributed is True, the available memory is the minimum available memory of all GPUs.

    Args:
        gpu_id: index of the CUDA device to query; must be < device_count().
        distributed: if True, the value is all-reduced with MIN over the
            default torch.distributed process group (assumed to be already
            initialized by the caller).

    Returns:
        Free memory in GiB (bytes / 2**30) as a float.
    """
    num_gpus = torch.cuda.device_count()
    assert gpu_id < num_gpus, (
        f"gpu_id {gpu_id} is out of range: {num_gpus} CUDA device(s) visible"
    )

    if torch.cuda.current_device() != gpu_id:
        logger.warning(
            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, "
            "which may cause useless memory allocation for torch CUDA context."
        )

    # Release cached blocks so the free-memory figure is not reduced by them.
    torch.cuda.empty_cache()
    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

    if distributed:
        # Allocate directly on the target device (no CPU tensor + copy).
        tensor = torch.tensor(
            free_gpu_memory,
            dtype=torch.float32,
            device=torch.device("cuda", gpu_id),
        )
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
        free_gpu_memory = tensor.item()

    return free_gpu_memory / (1 << 30)


Lianmin Zheng's avatar
Lianmin Zheng committed
125
def set_random_seed(seed: int) -> None:
    """Seed every RNG this project relies on with the same ``seed``.

    Covers Python's ``random``, NumPy's global generator and torch's CPU
    generator, plus all CUDA devices when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


134
def is_port_available(port):
    """Return whether a port is available.

    Probes by binding a TCP socket to ``port`` on all interfaces (with
    SO_REUSEADDR set) and listening on it. Any ``socket.error`` raised while
    doing so means the port is reported as unavailable. The probe socket is
    always closed before returning.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            probe.bind(("", port))
            probe.listen(1)
        except socket.error:
            return False
        return True


Lianmin Zheng's avatar
Lianmin Zheng committed
146
def allocate_init_ports(
    port: Optional[int] = None,
    additional_ports: Optional[List[int]] = None,
    tp_size: int = 1,
    dp_size: int = 1,
):
    """Allocate ports for all connections.

    The requested ``port`` (if free) becomes the main port, followed by the
    free entries of ``additional_ports`` in the order given. The remaining
    slots are filled with free ports found by scanning upward from the last
    collected port, or from 10000 when none of the requested ports is free.

    Args:
        port: preferred main port; ``None`` means no preference.
        additional_ports: preferred ports for the other connections.
        tp_size: tensor-parallel size.
        dp_size: data-parallel size.

    Returns:
        ``(main_port, other_ports)`` where ``other_ports`` holds
        ``num_ports_needed - 1`` ports.
    """
    candidates = [port] + list(additional_ports or [])
    # Deduplicate while keeping order. A plain set() scrambles the order and
    # can push the requested main port out of position 0. None means
    # "no preference" and cannot be bound, so it is skipped.
    ret_ports = [
        p
        for p in dict.fromkeys(candidates)
        if p is not None and is_port_available(p)
    ]
    cur_port = ret_ports[-1] + 1 if ret_ports else 10000

    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * (nccl + tp_size)
    num_ports_needed = 4 + dp_size * (1 + tp_size)
    while len(ret_ports) < num_ports_needed:
        if cur_port not in ret_ports and is_port_available(cur_port):
            ret_ports.append(cur_port)
        cur_port += 1

    if port is not None and ret_ports[0] != port:
        logger.warning(
            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
        )

    return ret_ports[0], ret_ports[1:num_ports_needed]
174

Lianmin Zheng's avatar
Lianmin Zheng committed
175

Lianmin Zheng's avatar
Lianmin Zheng committed
176
def get_int_token_logit_bias(tokenizer, vocab_size):
    """Get the logit bias for integer-only tokens.

    Returns a float32 array with one entry per tokenizer token. The entry is
    0 for tokens that decode (after stripping whitespace) to digits only, to
    an empty string, or that are the EOS token. Every other token gets -1e5.

    NOTE: the ``vocab_size`` argument is ignored and ``tokenizer.vocab_size``
    is used instead. This is a workaround for a bug when the model's vocab
    size exceeds ``tokenizer.vocab_size``, so the result's length follows the
    tokenizer.
    """
    num_tokens = tokenizer.vocab_size

    bias = np.zeros(num_tokens, dtype=np.float32)
    for token_id in range(num_tokens):
        text = tokenizer.decode([token_id]).strip()
        keep = text.isdigit() or len(text) == 0 or token_id == tokenizer.eos_token_id
        if keep:
            continue
        bias[token_id] = -1e5

    return bias


def wrap_kernel_launcher(kernel):
    """A faster launcher for triton kernels.

    Picks the first compiled variant of ``kernel`` cached for the current
    CUDA device and returns a function that calls its low-level runner
    directly.

    Only supported for triton < 3. For triton >= 3 this returns None, and
    callers are expected to launch the kernel the normal way.

    Args:
        kernel: a triton JIT function that has already been compiled on the
            current device (``kernel.cache[gpu_id]`` must be non-empty).

    Returns:
        ``ret_func(grid, num_warps, *args)``, or None on triton >= 3.
    """
    if int(triton.__version__.split(".")[0]) >= 3:
        return None

    gpu_id = torch.cuda.current_device()
    # NOTE(review): uses the first cached specialization; this presumably
    # assumes a single compiled variant per device — confirm at call sites.
    compiled = next(iter(kernel.cache[gpu_id].values()))

    # Different triton versions use different low-level names
    if hasattr(compiled, "cu_function"):
        kfunction = compiled.cu_function
    else:
        kfunction = compiled.function

    if hasattr(compiled, "c_wrapper"):
        run = compiled.c_wrapper
    else:
        run = compiled.run

    def _launch(with_cluster_dim, grid, num_warps, args):
        # Some runners expect four extra ints (1, 1, 1, 1) right after
        # num_warps — presumably num_ctas and cluster dims; TODO confirm.
        extra = (1, 1, 1, 1) if with_cluster_dim else ()
        run(
            grid[0],
            grid[1],
            grid[2],
            num_warps,
            *extra,
            compiled.shared,
            0,
            kfunction,
            None,
            None,
            compiled,
            *args,
        )

    # Which call shape the runner accepts is discovered on first use and
    # remembered across calls.
    add_cluster_dim = True

    def ret_func(grid, num_warps, *args):
        nonlocal add_cluster_dim

        try:
            _launch(add_cluster_dim, grid, num_warps, args)
        except TypeError:
            # Wrong call shape: switch to the other one and retry once.
            add_cluster_dim = not add_cluster_dim
            try:
                _launch(add_cluster_dim, grid, num_warps, args)
            except TypeError:
                # Neither shape works, so the error is not about the shape
                # (the previous unbounded retry recursed forever here).
                # Restore the earlier choice and surface the error.
                add_cluster_dim = not add_cluster_dim
                raise

    return ret_func


def is_multimodal_model(model):
    """Return whether ``model`` names a supported multimodal model.

    ``model`` may be a model name/path string or a ModelConfig (whose
    ``path`` is used). The check is a case-insensitive substring match.

    Raises:
        ValueError: if ``model`` is neither a str nor a ModelConfig.
    """
    from sglang.srt.model_config import ModelConfig

    if isinstance(model, str):
        model_name = model
    elif isinstance(model, ModelConfig):
        model_name = model.path
    else:
        raise ValueError("unrecognized type")

    lowered = model_name.lower()
    # "llava-next" is already covered by "llava"; it is listed only for clarity.
    markers = ("llava", "yi-vl", "llava-next")
    return any(marker in lowered for marker in markers)


def _find_signature_offsets(data, signature):
    """Return start offsets of all non-overlapping occurrences of ``signature`` in ``data``."""
    offsets = []
    pos = data.find(signature)
    while pos != -1:
        offsets.append(pos)
        pos = data.find(signature, pos + len(signature))
    return offsets


def decode_video_base64(video_base64, frame_format="PNG"):
    """Decode a base64 string holding back-to-back encoded images into frames.

    The decoded bytes are split at every PNG signature (or JPEG SOI marker
    0xFFD8 when ``frame_format == "JPEG"``). Each image runs until the next
    start marker, and the last one runs to the end of the data.

    Args:
        video_base64: base64-encoded concatenation of image files.
        frame_format: "PNG" (default) or "JPEG".

    Returns:
        ``(frames, size)``: the frames stacked along a new axis 0 together
        with the PIL ``size`` of the last frame. If no images are found,
        returns ``(np.array([]), (0, 0))``.

    Raises:
        ValueError: if ``frame_format`` is not "PNG" or "JPEG".
    """
    video_bytes = base64.b64decode(video_base64)

    if frame_format == "PNG":
        signature = b"\x89PNG\r\n\x1a\n"
    elif frame_format == "JPEG":
        signature = b"\xff\xd8"
    else:
        raise ValueError("FRAME_FORMAT must be either 'PNG' or 'JPEG'")

    # bytes.find scans in C; the old per-byte Python loop plus
    # list.index() lookups was O(n^2) in the number of frames.
    img_starts = _find_signature_offsets(video_bytes, signature)
    if not img_starts:
        # Return an empty array and size tuple if no frames were found
        return np.array([]), (0, 0)

    # Imported lazily: only needed once there is something to decode.
    from PIL import Image

    img_ends = img_starts[1:] + [len(video_bytes)]
    frames = []
    img = None
    for start_idx, end_idx in zip(img_starts, img_ends):
        img = Image.open(BytesIO(video_bytes[start_idx:end_idx]))
        frames.append(np.array(img))

    return np.stack(frames, axis=0), img.size
Lianmin Zheng's avatar
Lianmin Zheng committed
347
348
349
350
351


def load_image(image_file):
    """Load an image (or video frames) from a URL, data URL, path or base64 string.

    Forms are recognized in this order:
      * ``http://`` / ``https://`` URL, downloaded with a timeout of
        ``REQUEST_TIMEOUT`` seconds (env var, default 3);
      * ``data:...,<base64>``, where the payload after the first comma is decoded;
      * ``video:<base64>``, decoded with ``decode_video_base64``;
      * a path ending in png/jpg/jpeg/webp/gif (case-insensitive), opened from disk;
      * anything else, treated as a raw base64-encoded image.

    Returns:
        ``(image, image_size)``. For images, ``image`` is a PIL image and
        ``image_size`` is None. For ``video:`` input both values come from
        ``decode_video_base64``.

    Raises:
        requests.HTTPError: if a URL download returns an error status.
    """
    from PIL import Image

    image = image_size = None

    if image_file.startswith(("http://", "https://")):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, timeout=timeout)
        # Report the HTTP failure instead of letting PIL choke on an error page.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    elif image_file.startswith("data:"):
        # Prefix checks run before the extension check: a base64 payload may
        # end in letters like "png"/"gif" and must not be opened as a path.
        image_file = image_file.split(",", 1)[1]
        image = Image.open(BytesIO(base64.b64decode(image_file)))
    elif image_file.startswith("video:"):
        image_file = image_file[len("video:"):]
        image, image_size = decode_video_base64(image_file)
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    else:
        image = Image.open(BytesIO(base64.b64decode(image_file)))

    return image, image_size
370
371


372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
def init_rpyc_service(service: rpyc.Service, port: int):
    """Serve ``service`` over rpyc on ``port`` with a ThreadedServer.

    The server's own logger is limited to WARN. ``start()`` is presumably
    blocking, which is why start_rpyc_process runs this in a child process.

    NOTE(review): "allow_pickle" and "allow_public_attrs" let the peer
    exchange pickled objects and reach arbitrary public attributes. This is
    only safe if the port is not reachable by untrusted clients — confirm
    which interface ThreadedServer binds to by default.
    """
    protocol_config = {
        "allow_public_attrs": True,
        "allow_pickle": True,
        "sync_request_timeout": 3600,
    }
    server = ThreadedServer(
        service=service,
        port=port,
        protocol_config=protocol_config,
    )
    server.logger.setLevel(logging.WARN)
    server.start()


def connect_to_rpyc_service(port, host="localhost"):
    """Connect to the rpyc server at ``host:port`` and return its root object.

    Sleeps 1 s up front, then makes at most 20 connection attempts, sleeping
    1 s after each refused one. Errors other than ConnectionRefusedError
    propagate immediately.

    Raises:
        RuntimeError: if all 20 attempts are refused.
    """
    time.sleep(1)

    for _attempt in range(20):
        try:
            con = rpyc.connect(
                host,
                port,
                config={
                    "allow_public_attrs": True,
                    "allow_pickle": True,
                    "sync_request_timeout": 3600,
                },
            )
        except ConnectionRefusedError:
            time.sleep(1)
        else:
            return con.root

    raise RuntimeError("init rpc env error!")


def start_rpyc_process(service: rpyc.Service, port: int):
    """Run an rpyc server for ``service`` in a child process and connect to it.

    Returns:
        ``(proxy, proc)``: the connection's root object and the child process.

    Raises:
        RuntimeError: if the server process is not alive after connecting.
        Any error from connect_to_rpyc_service (e.g. RuntimeError after
        repeated refusals) is re-raised after the child has been terminated.
    """
    proc = multiprocessing.Process(target=init_rpyc_service, args=(service, port))
    proc.start()

    try:
        proxy = connect_to_rpyc_service(port)
    except BaseException:
        # Don't leak the server process when we cannot connect (or are
        # interrupted while waiting).
        proc.terminate()
        proc.join()
        raise

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not proc.is_alive():
        raise RuntimeError(
            f"rpyc server process for port {port} exited unexpectedly "
            f"(exitcode={proc.exitcode})"
        )
    return proxy, proc


def suppress_other_loggers():
    """Quiet vllm's loggers.

    vllm's default logger and a few of its sub-loggers are raised to WARN;
    "vllm.config" is raised to ERROR.
    """
    from vllm.logger import logger as vllm_default_logger

    vllm_default_logger.setLevel(logging.WARN)

    level_by_name = {
        "vllm.config": logging.ERROR,
        "vllm.distributed.device_communicators.pynccl": logging.WARN,
        "vllm.selector": logging.WARN,
        "vllm.utils": logging.WARN,
    }
    for logger_name, level in level_by_name.items():
        logging.getLogger(logger_name).setLevel(level)
428
429


430
431
432
433
434
435
436
437
438
def assert_pkg_version(pkg: str, min_version: str):
    """Ensure distribution ``pkg`` is installed at version ``min_version`` or newer.

    Raises:
        Exception: if ``pkg`` is not installed, or its installed version is
            lower than ``min_version`` (compared with packaging's version parser).
    """
    try:
        installed_version = version(pkg)
    except PackageNotFoundError:
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed"
        )

    too_old = pkg_version.parse(installed_version) < pkg_version.parse(min_version)
    if too_old:
        raise Exception(
            f"{pkg} is installed with version {installed_version} which "
            f"is less than the minimum required version {min_version}"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460


# Request header that APIKeyValidatorMiddleware reads the client's API key from.
API_KEY_HEADER_NAME: str = "X-API-Key"


class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """HTTP middleware that rejects requests lacking the configured API key.

    The key is read from the ``X-API-Key`` request header. Requests whose
    header is missing, empty, or different from ``api_key`` receive a 403
    JSON response; all other requests are passed to the next handler.
    """

    def __init__(self, app, api_key: str):
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request, call_next):
        import hmac

        # extract API key from the request headers
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
        # Constant-time comparison so response timing does not reveal how much
        # of a guessed key matched. Compare bytes, because compare_digest
        # raises TypeError on non-ASCII str input.
        expected = (self.api_key or "").encode("utf-8")
        if not api_key_header or not hmac.compare_digest(
            api_key_header.encode("utf-8"), expected
        ):
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )
        response = await call_next(request)
        return response