utils.py 8.99 KB
Newer Older
1
2
"""Common utilities."""

Lianmin Zheng's avatar
Lianmin Zheng committed
3
4
5
6
7
8
9
import base64
import os
import random
import socket
import sys
import time
import traceback
10
from importlib.metadata import PackageNotFoundError, version
Lianmin Zheng's avatar
Lianmin Zheng committed
11
from io import BytesIO
Lianmin Zheng's avatar
Lianmin Zheng committed
12
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
13
14

import numpy as np
Lianmin Zheng's avatar
Lianmin Zheng committed
15
import pydantic
Lianmin Zheng's avatar
Lianmin Zheng committed
16
17
import requests
import torch
18
from fastapi.responses import JSONResponse
19
from packaging import version as pkg_version
Lianmin Zheng's avatar
Lianmin Zheng committed
20
21
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
Lianmin Zheng's avatar
Lianmin Zheng committed
22

Liangsheng Yin's avatar
Liangsheng Yin committed
23
24
# Global switch for the lightweight mark_start/mark_end timing utilities below.
show_time_cost = False
# Maps a timer name to its TimeInfo accumulator (populated by mark_start).
time_infos = {}
Lianmin Zheng's avatar
Lianmin Zheng committed
25
26


Liangsheng Yin's avatar
Liangsheng Yin committed
27
28
29
30
def enable_show_time_cost():
    """Turn on the mark_start/mark_end timing printouts for this process."""
    global show_time_cost
    show_time_cost = True

Lianmin Zheng's avatar
Lianmin Zheng committed
31

Liangsheng Yin's avatar
Liangsheng Yin committed
32
33
34
35
36
37
class TimeInfo:
    """Accumulates wall-clock time under a name and pretty-prints it.

    `acc_time` holds the total accumulated seconds (mark_start subtracts the
    start timestamp, mark_end adds the stop timestamp, so the sum is the
    elapsed delta). `last_acc_time` is the snapshot taken the last time
    check() fired, used to rate-limit printing to once per `interval`.
    """

    def __init__(self, name, interval=0.1, color=0, indent=0):
        self.name = name
        self.interval = interval
        self.color = color
        self.indent = indent
        self.acc_time = 0
        self.last_acc_time = 0

    def check(self):
        """Return True (taking a snapshot) once more than `interval` seconds
        have accumulated since the previous snapshot; otherwise False."""
        elapsed = self.acc_time - self.last_acc_time
        if elapsed <= self.interval:
            return False
        self.last_acc_time = self.acc_time
        return True

    def pretty_print(self):
        """Print the accumulated time, ANSI-colored and dash-indented."""
        prefix = "-" * (self.indent * 2)
        print(f"\x1b[{self.color}m{prefix}{self.name}: {self.acc_time:.3f}s\x1b[0m")
Lianmin Zheng's avatar
Lianmin Zheng committed
52
53


Liangsheng Yin's avatar
Liangsheng Yin committed
54
55
56
57
def mark_start(name, interval=0.1, color=0, indent=0):
    """Start (or resume) the named timer. No-op unless enable_show_time_cost()
    was called first."""
    global time_infos, show_time_cost
    if not show_time_cost:
        return
    # Drain pending GPU work so the wall clock reflects real compute time.
    torch.cuda.synchronize()
    if name not in time_infos:
        time_infos[name] = TimeInfo(name, interval, color, indent)
    # Subtract the start timestamp; mark_end adds the stop timestamp back,
    # leaving the elapsed delta in acc_time.
    time_infos[name].acc_time -= time.time()
Lianmin Zheng's avatar
Lianmin Zheng committed
62
63


Liangsheng Yin's avatar
Liangsheng Yin committed
64
65
66
67
def mark_end(name):
    """Stop the named timer and pretty-print it once enough time accumulated.

    Must be paired with a prior mark_start(name). No-op unless
    enable_show_time_cost() was called first.
    """
    global time_infos, show_time_cost
    if not show_time_cost:
        return
    # Drain pending GPU work so the stop timestamp is meaningful.
    torch.cuda.synchronize()
    info = time_infos[name]
    info.acc_time += time.time()
    if info.check():
        info.pretty_print()
Lianmin Zheng's avatar
Lianmin Zheng committed
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118


def calculate_time(show=False, min_cost_ms=0.0):
    """Decorator factory that prints a function's wall-clock time.

    Args:
        show: when False, the wrapped function runs without any timing work
            (the original version synchronized CUDA even when disabled).
        min_cost_ms: only print timings larger than this many milliseconds.

    Returns:
        A decorator whose wrapper preserves the wrapped function's metadata.
    """
    import functools

    def wrapper(func):
        @functools.wraps(func)  # keep func.__name__/__doc__ for introspection
        def inner_func(*args, **kwargs):
            if not show:
                return func(*args, **kwargs)
            # Synchronize so queued GPU kernels are included in the timing.
            # Guarded: torch.cuda.synchronize() raises on CPU-only installs.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.time()
            result = func(*args, **kwargs)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            cost_time = (time.time() - start_time) * 1000
            if cost_time > min_cost_ms:
                print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result

        return inner_func

    return wrapper


def set_random_seed(seed: int) -> None:
    """Seed Python's and PyTorch's RNGs (and every CUDA device, if present)."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def alloc_usable_network_port(num, used_list=()):
    """Find `num` currently-bindable TCP ports in the range [10000, 65535].

    Ports listed in `used_list` are skipped. Returns a list of `num` port
    numbers, or None when the range is exhausted. Each probe socket is
    closed immediately, so another process may still grab a returned port
    before the caller binds it.
    """
    found = []
    for candidate in range(10000, 65536):
        if candidate in used_list:
            continue

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            try:
                probe.bind(("", candidate))
                found.append(candidate)
            except socket.error:
                pass  # port busy; keep scanning

            if len(found) == num:
                return found
    return None


119
120
121
def check_port(port):
    """Return True if `port` can currently be bound on all interfaces."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            # SO_REUSEADDR lets the probe succeed for ports in TIME_WAIT.
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            probe.bind(("", port))
        except socket.error:
            return False
        return True


Lianmin Zheng's avatar
Lianmin Zheng committed
129
def allocate_init_ports(
    port: Optional[int] = None,
    additional_ports: Optional[List[int]] = None,
    tp_size: int = 1,
):
    """Pick a usable server port plus 4 + tp_size auxiliary ports.

    Args:
        port: preferred server port; defaults to 30000. Replaced (with a
            warning) if it is not bindable.
        additional_ports: preferred auxiliary ports. A bare int is accepted
            and treated as a one-element list. Busy or duplicate entries are
            dropped and replaced with freshly allocated ports.
        tp_size: tensor-parallel size; 4 + tp_size auxiliary ports are needed.

    Returns:
        (port, additional_ports): the chosen server port and a list of
        exactly 4 + tp_size usable auxiliary ports.

    Raises:
        RuntimeError: if not enough free ports could be found (previously
            this surfaced as a confusing TypeError on None).
    """
    port = 30000 if port is None else port
    additional_ports = [] if additional_ports is None else additional_ports
    additional_ports = (
        [additional_ports] if isinstance(additional_ports, int) else additional_ports
    )

    # First make sure the main server port itself is usable.
    if not check_port(port):
        alloc = alloc_usable_network_port(1, used_list=[port])
        if not alloc:
            raise RuntimeError(f"Port {port} is busy and no replacement is free.")
        new_port = alloc[0]
        print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
        port = new_port

    # Then validate the requested additional ports: dedupe, drop the server
    # port, and keep only those that are currently bindable.
    additional_unique_ports = set(additional_ports) - {port}
    can_use_ports = [p for p in additional_unique_ports if check_port(p)]

    num_needed = 4 + tp_size
    num_specified_ports = len(can_use_ports)
    if num_specified_ports < num_needed:
        additional_can_use_ports = alloc_usable_network_port(
            num=num_needed - num_specified_ports, used_list=can_use_ports + [port]
        )
        if additional_can_use_ports is None:
            raise RuntimeError(
                f"Could not allocate {num_needed - num_specified_ports} free ports."
            )
        can_use_ports.extend(additional_can_use_ports)

    additional_ports = can_use_ports[:num_needed]
    return port, additional_ports

Lianmin Zheng's avatar
Lianmin Zheng committed
160

Lianmin Zheng's avatar
Lianmin Zheng committed
161
162
163
164
165
166
167
def get_exception_traceback():
    """Return the currently-handled exception's traceback as one string.

    Must be called from inside an `except` block (uses sys.exc_info()).
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    return "".join(traceback.format_exception(exc_type, exc_value, exc_tb))


def get_int_token_logit_bias(tokenizer, vocab_size):
    """Build a logit-bias vector that suppresses non-integer tokens.

    Tokens whose decoded text (stripped) is all digits or empty, and the EOS
    token, keep bias 0; every other token gets a large negative bias.

    Note: the `vocab_size` argument is intentionally overridden — the model's
    vocab can be larger than the tokenizer's, so the tokenizer's size is used.
    """
    # a bug when model's vocab size > tokenizer.vocab_size
    vocab_size = tokenizer.vocab_size
    logit_bias = np.zeros(vocab_size, dtype=np.float32)
    for token_id in range(vocab_size):
        text = tokenizer.decode([token_id]).strip()
        is_allowed = (
            text.isdigit() or len(text) == 0 or token_id == tokenizer.eos_token_id
        )
        if not is_allowed:
            logit_bias[token_id] = -1e5
    return logit_bias


def wrap_kernel_launcher(kernel):
    """A faster launcher for triton kernels.

    Pulls the already-compiled kernel for this process's distributed rank out
    of the JIT cache and returns a ``ret_func(grid, num_warps, *args)``
    closure that calls the low-level launcher directly, skipping Triton's
    Python-side dispatch overhead on every call.
    """
    import torch.distributed as dist

    # The JIT cache is indexed by rank; fall back to 0 when not distributed.
    if dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    # Take the first compiled variant from this rank's cache.
    kernels = kernel.cache[rank].values()
    kernel = next(iter(kernels))

    # Different triton versions use different low-level names
    if hasattr(kernel, "cu_function"):
        kfunction = kernel.cu_function
    else:
        kfunction = kernel.function

    if hasattr(kernel, "c_wrapper"):
        run = kernel.c_wrapper
    else:
        run = kernel.run

    # Newer launchers appear to take four extra (cluster-dim) arguments; we
    # probe the calling convention at the first call and remember what worked.
    add_cluster_dim = True

    def ret_func(grid, num_warps, *args):
        nonlocal add_cluster_dim

        try:
            if add_cluster_dim:
                run(
                    grid[0],
                    grid[1],
                    grid[2],
                    num_warps,
                    1,
                    1,
                    1,
                    1,
                    kernel.shared,
                    0,
                    kfunction,
                    None,
                    None,
                    kernel,
                    *args,
                )
            else:
                run(
                    grid[0],
                    grid[1],
                    grid[2],
                    num_warps,
                    kernel.shared,
                    0,
                    kfunction,
                    None,
                    None,
                    kernel,
                    *args,
                )
        except TypeError:
            # Wrong signature guess: flip the convention and retry.
            add_cluster_dim = not add_cluster_dim
            ret_func(grid, num_warps, *args)

    return ret_func


def is_multimodal_model(model):
    """Return True if `model` names a multimodal model (LLaVA or Yi-VL).

    Args:
        model: either a model path/name string or a ModelConfig instance.

    Raises:
        Exception: if `model` is neither a str nor a ModelConfig.
    """
    if isinstance(model, str):
        # Lowercase for consistency with the ModelConfig branch below
        # (previously "LLaVA-7B" was misclassified as not multimodal).
        model = model.lower()
        return "llava" in model or "yi-vl" in model

    from sglang.srt.model_config import ModelConfig

    if isinstance(model, ModelConfig):
        model_path = model.path.lower()
        return "llava" in model_path or "yi-vl" in model_path

    raise Exception("unrecognized type")


def load_image(image_file):
    """Load a PIL image from a URL, a local file path, or base64 data.

    Accepted inputs:
      * http(s) URLs, fetched with a timeout of REQUEST_TIMEOUT seconds
        (env var, default "3");
      * local paths ending in png/jpg/jpeg/webp/gif (case-insensitive);
      * "data:<mime>;base64,..." URIs;
      * raw base64-encoded strings (the fallback).
    """
    from PIL import Image

    if image_file.startswith(("http://", "https://")):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, timeout=timeout)
        return Image.open(BytesIO(response.content))

    if image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        return Image.open(image_file)

    if image_file.startswith("data:"):
        # Strip the "data:<mime>;base64," prefix before decoding.
        image_file = image_file.split(",")[1]
    return Image.open(BytesIO(base64.b64decode(image_file)))
276
277
278
279
280
281
282
283
284
285
286
287


def assert_pkg_version(pkg: str, min_version: str):
    """Assert that `pkg` is installed with at least `min_version`.

    Args:
        pkg: distribution name as known to importlib.metadata.
        min_version: minimum acceptable version string.

    Raises:
        Exception: if the package is missing or older than `min_version`.
    """
    try:
        installed_version = version(pkg)
        if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
            raise Exception(
                f"{pkg} is installed with version {installed_version} which "
                f"is less than the minimum required version {min_version}"
            )
    except PackageNotFoundError as e:
        # Chain the original error so the missing-package cause stays visible.
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed"
        ) from e
Lianmin Zheng's avatar
Lianmin Zheng committed
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308


# Request header that clients must send; checked by APIKeyValidatorMiddleware.
API_KEY_HEADER_NAME = "X-API-Key"


class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that rejects requests lacking the correct API key.

    Requests must carry the key in the `X-API-Key` header; otherwise a 403
    JSON response is returned and the downstream app is never called.
    """

    def __init__(self, app, api_key: str):
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request, call_next):
        # extract API key from the request headers
        provided = request.headers.get(API_KEY_HEADER_NAME)
        is_valid = bool(provided) and provided == self.api_key
        if not is_valid:
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )
        return await call_next(request)

309

Lianmin Zheng's avatar
Lianmin Zheng committed
310
311
312
313
314
315
316
# FIXME: Remove this once we drop support for pydantic 1.x
# True when the installed pydantic major version is 1 (selects the JSON
# serialization API used in jsonify_pydantic_model).
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1


def jsonify_pydantic_model(obj: BaseModel):
    """Serialize a pydantic model to JSON, bridging the v1/v2 API split."""
    if IS_PYDANTIC_1:
        # pydantic 1.x uses .json(); keep non-ASCII characters unescaped.
        return obj.json(ensure_ascii=False)
    # pydantic 2.x renamed the serializer to model_dump_json().
    return obj.model_dump_json()