"vscode:/vscode.git/clone" did not exist on "619bb6ddda39cada67f75426979b77e7b42bb15e"
utils.py 46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""Common utilities."""

Lianmin Zheng's avatar
Lianmin Zheng committed
16
import base64
17
import dataclasses
18
import io
19
import ipaddress
20
import itertools
21
import json
22
import logging
Lianmin Zheng's avatar
Lianmin Zheng committed
23
import os
24
import pickle
Lianmin Zheng's avatar
Lianmin Zheng committed
25
import random
Lianmin Zheng's avatar
Lianmin Zheng committed
26
import re
27
import resource
28
29
import shutil
import signal
Lianmin Zheng's avatar
Lianmin Zheng committed
30
import socket
31
import subprocess
Lianmin Zheng's avatar
Lianmin Zheng committed
32
import tempfile
Lianmin Zheng's avatar
Lianmin Zheng committed
33
import time
34
import warnings
35
from functools import lru_cache
36
from importlib.metadata import PackageNotFoundError, version
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from io import BytesIO
38
from multiprocessing.reduction import ForkingPickler
39
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
40
41

import numpy as np
42
import psutil
Lianmin Zheng's avatar
Lianmin Zheng committed
43
44
import requests
import torch
45
import torch.distributed
46
import torch.distributed as dist
47
import triton
48
import zmq
49
from fastapi.responses import ORJSONResponse
50
from packaging import version as pkg_version
Lianmin Zheng's avatar
Lianmin Zheng committed
51
from starlette.routing import Mount
52
from torch import nn
53
from torch.func import functional_call
54
from torch.library import Library
55
from torch.profiler import ProfilerActivity, profile, record_function
56
57
58
59
60
61
from triton.runtime.cache import (
    FileCacheManager,
    default_cache_dir,
    default_dump_dir,
    default_override_dir,
)
62
from uvicorn.config import LOGGING_CONFIG
63

64
65
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

Liangsheng Yin's avatar
Liangsheng Yin committed
66
67
# Global switch for the mark_start/mark_end timing printouts; flipped on by
# enable_show_time_cost().
show_time_cost = False
# Region name -> TimeInfo accumulator, populated lazily by mark_start().
time_infos = {}
Lianmin Zheng's avatar
Lianmin Zheng committed
68
69


70
def is_hip() -> bool:
    """Report whether torch was built for the AMD ROCm (HIP) platform.

    ``torch.version.hip`` is ``None`` on builds without HIP support.
    """
    hip_version = torch.version.hip
    return hip_version is not None


75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
def is_cuda():
    """Return True when torch exposes a ``cuda`` backend that reports a usable device."""
    if not hasattr(torch, "cuda"):
        return False
    return torch.cuda.is_available()


def is_cuda_alike():
    """Return True on a CUDA platform or on an AMD ROCm (HIP) build."""
    if is_cuda():
        return True
    return is_hip()


def is_hpu() -> bool:
    """Return True when torch has an ``hpu`` backend that reports a usable device."""
    if not hasattr(torch, "hpu"):
        return False
    return torch.hpu.is_available()


def is_xpu() -> bool:
    """Return True when torch has an ``xpu`` backend that reports a usable device."""
    if not hasattr(torch, "xpu"):
        return False
    return torch.xpu.is_available()


91
92
93
94
95
def is_flashinfer_available():
    """
    Check whether flashinfer is available.
    As of Oct. 6, 2024, it is only available on NVIDIA GPUs.

    Returns False when ``get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE")``
    reports the override as false. Otherwise returns True only if a GPU is
    visible to ``torch.cuda`` and torch is a CUDA build (``torch.version.cuda``
    set).
    """
    if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
        return False
    # torch.version.cuda is a version string (or None); coerce so this
    # predicate returns a real bool instead of leaking the string.
    return bool(torch.cuda.is_available() and torch.version.cuda)
99
100


101
102
103
104
def is_cuda_available():
    """Return True when a GPU is visible and torch is a CUDA (not HIP) build.

    ``torch.version.cuda`` is checked in addition to
    ``torch.cuda.is_available()``. The result is coerced to ``bool`` so the
    version string is not returned to callers.
    """
    return bool(torch.cuda.is_available() and torch.version.cuda)


Liangsheng Yin's avatar
Liangsheng Yin committed
105
106
107
108
def enable_show_time_cost():
    """Turn on the timing printouts emitted by ``mark_start`` / ``mark_end``."""
    global show_time_cost
    show_time_cost = True

Lianmin Zheng's avatar
Lianmin Zheng committed
109

Liangsheng Yin's avatar
Liangsheng Yin committed
110
111
112
113
114
115
class TimeInfo:
    """Accumulated wall-clock seconds for one named timing region.

    ``acc_time`` is decremented by ``mark_start`` and incremented by
    ``mark_end``. ``check`` reports when more than ``interval`` seconds have
    accumulated since the last report.
    """

    def __init__(self, name, interval=0.1, color=0, indent=0):
        self.name, self.interval = name, interval
        # ANSI SGR colour code and nesting depth used by pretty_print.
        self.color, self.indent = color, indent

        # Running total, and its value at the last printed report.
        self.acc_time = 0
        self.last_acc_time = 0

    def check(self):
        """Return True (and snapshot the total) when a new report is due."""
        due = self.acc_time - self.last_acc_time > self.interval
        if due:
            self.last_acc_time = self.acc_time
        return due

    def pretty_print(self):
        """Print ``name: <seconds>s`` in colour, prefixed by ``2 * indent`` dashes."""
        prefix = f"\x1b[{self.color}m" + "-" * (self.indent * 2)
        print(f"{prefix}{self.name}: {self.acc_time:.3f}s\x1b[0m")
Lianmin Zheng's avatar
Lianmin Zheng committed
130
131


Liangsheng Yin's avatar
Liangsheng Yin committed
132
133
134
135
def mark_start(name, interval=0.1, color=0, indent=0):
    """Open timing region ``name``; a no-op unless show_time_cost is enabled.

    CUDA is synchronized first, then the current time is subtracted from the
    region's accumulator (created on first use). The matching ``mark_end``
    adds the end time back.
    """
    global time_infos, show_time_cost
    if show_time_cost:
        torch.cuda.synchronize()
        info = time_infos.get(name)
        if info is None:
            info = time_infos[name] = TimeInfo(name, interval, color, indent)
        info.acc_time -= time.time()
Lianmin Zheng's avatar
Lianmin Zheng committed
140
141


Liangsheng Yin's avatar
Liangsheng Yin committed
142
143
144
145
def mark_end(name):
    """Close timing region ``name`` opened by ``mark_start`` and maybe print it.

    A no-op unless show_time_cost is enabled. The region must already exist
    in ``time_infos`` (a KeyError is raised otherwise).
    """
    global time_infos, show_time_cost
    if not show_time_cost:
        return
    torch.cuda.synchronize()
    info = time_infos[name]
    info.acc_time += time.time()
    if info.check():
        info.pretty_print()
Lianmin Zheng's avatar
Lianmin Zheng committed
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170


def calculate_time(show=False, min_cost_ms=0.0):
    """Decorator factory that optionally reports a function's wall-clock time.

    Args:
        show: When True, print the elapsed time of each call.
        min_cost_ms: Only report calls that took longer than this many ms.

    The wrapped call is bracketed by CUDA synchronizations so queued kernels
    are included in the measurement. Synchronization is skipped when CUDA is
    unavailable, so decorated functions also run on CPU-only hosts.
    """
    import functools

    def _sync():
        # torch.cuda.synchronize() raises when CUDA is not available.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def wrapper(func):
        @functools.wraps(func)  # preserve __name__/__doc__ of the wrapped function
        def inner_func(*args, **kwargs):
            _sync()
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            _sync()
            if show:
                cost_time = (time.perf_counter() - start_time) * 1000
                if cost_time > min_cost_ms:
                    print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result

        return inner_func

    return wrapper


171
def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
    """
    Get available memory for cuda:gpu_id device.
    When distributed is True, the available memory is the minimum available memory of all GPUs.

    Args:
        device: One of "cuda", "xpu", "hpu" or "cpu".
        gpu_id: Index of the device to query. For "cpu" it is only used to
            build the tensor for the distributed reduction.
        distributed: If True, all-reduce the value with MIN across the default
            process group.
        empty_cache: If True, release cached allocator blocks first (cuda/xpu).

    Returns:
        Available memory in GiB.

    Raises:
        ValueError: If ``device`` is not one of the supported backends.
    """
    if device == "cuda":
        num_gpus = torch.cuda.device_count()
        assert gpu_id < num_gpus

        if torch.cuda.current_device() != gpu_id:
            print(
                f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
                "which may cause useless memory allocation for torch CUDA context.",
            )

        if empty_cache:
            torch.cuda.empty_cache()
        free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

    elif device == "xpu":
        num_gpus = torch.xpu.device_count()
        assert gpu_id < num_gpus

        if torch.xpu.current_device() != gpu_id:
            print(
                f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
                "which may cause useless memory allocation for torch XPU context.",
            )

        if empty_cache:
            torch.xpu.empty_cache()
        # Query gpu_id explicitly: without an argument memory_allocated()
        # reports the *current* device, which may differ from gpu_id (see the
        # warning above) and would not match the total taken from gpu_id.
        used_memory = torch.xpu.memory_allocated(gpu_id)
        total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

    elif device == "hpu":
        num_gpus = torch.hpu.device_count()
        assert gpu_id < num_gpus

        if torch.hpu.current_device() != gpu_id:
            print(
                f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
                "which may cause useless memory allocation for torch HPU context.",
            )

        free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()

    elif device == "cpu":
        # TODO: rename the variables in the current function to be not GPU specific
        free_gpu_memory = psutil.virtual_memory().available

    else:
        # Previously an unknown device fell through and failed later with an
        # UnboundLocalError on free_gpu_memory.
        raise ValueError(f"Unsupported device for memory query: {device!r}")

    if distributed:
        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
            torch.device(device, gpu_id)
        )
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
        free_gpu_memory = tensor.item()

    return free_gpu_memory / (1 << 30)


232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
def is_pin_memory_available() -> bool:
    """Return whether pinned host memory should be used (only when CUDA is available)."""
    cuda_ready = torch.cuda.is_available()
    return cuda_ready


# Running total of parameter bytes moved to CPU by maybe_offload_to_cpu, and
# the budget that caps it. With the default budget of 0 nothing is offloaded.
_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU offload budget in bytes and reset the running total."""
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_MAX_BYTES = max_bytes
    _CPU_OFFLOAD_BYTES = 0


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """Move ``module``'s parameters to CPU while the global offload budget allows.

    Parameters are moved one at a time until ``_CPU_OFFLOAD_BYTES`` reaches
    ``_CPU_OFFLOAD_MAX_BYTES`` (see ``set_cpu_offload_max_bytes``), so a module
    may end up partially offloaded. If anything was moved, ``module.forward``
    is wrapped so that each call copies the state dict back to the original
    device and runs the module through ``torch.func.functional_call``.

    Modules without parameters, or whose first parameter is already on CPU,
    are returned unchanged.
    """
    first_param = next(module.parameters(), None)
    if first_param is None:
        # Nothing to offload (e.g. activation-only modules); calling next()
        # without a default would raise StopIteration here.
        return module
    device = first_param.device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()
    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(
            size=p.data.size(),
            stride=p.data.stride(),
            dtype=p.data.dtype,
            layout=p.data.layout,
            device="cpu",
            pin_memory=pin_memory,
        )
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Put the real forward back so functional_call below runs it rather
            # than re-entering this wrapper.
            module.forward = original_forward
            try:
                device_state = {
                    # here we blindly call `to(device)`
                    # if the parameter is already on the device, it will be a no-op
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }
                return functional_call(module, device_state, args=args, kwargs=kwargs)
            finally:
                # Re-install the wrapper even if the call raised; previously an
                # exception left the module permanently un-wrapped.
                module.forward = forward

        module.forward = forward

    return module


class LayerFn(Protocol):
    """Callable that builds the layer at a given index.

    ``make_layers`` invokes it with keyword arguments ``idx`` and ``prefix``,
    so the parameter names declared here must match those keywords.
    """

    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...


def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str = "",
) -> torch.nn.ModuleList:
    """Make a list of layers with the given layer function.

    Args:
        num_hidden_layers: Number of layers to build.
        layer_fn: Called as ``layer_fn(idx=i, prefix=f"{prefix}.{i}")`` for
            each ``i`` in ``range(num_hidden_layers)``.
        prefix: Name prefix; layer ``i`` receives ``f"{prefix}.{i}"``.

    Returns:
        A ``torch.nn.ModuleList`` of the built layers, each passed through
        ``maybe_offload_to_cpu``.
    """
    modules = torch.nn.ModuleList(
        [
            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
            for idx in range(num_hidden_layers)
        ]
    )
    return modules


Lianmin Zheng's avatar
Lianmin Zheng committed
320
def set_random_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and torch (plus every CUDA device if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


329
def is_port_available(port):
    """Return whether a port is available.

    Probes by binding a TCP socket (SO_REUSEADDR set) to ``port`` on all
    interfaces and listening on it; the probe socket is closed on exit. Socket
    errors and out-of-range ports (OverflowError from ``bind``) yield False.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            probe.bind(("", port))
            probe.listen(1)
        except (socket.error, OverflowError):
            return False
        return True
341
342


Yuanhan Zhang's avatar
Yuanhan Zhang committed
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
def decode_video_base64(video_base64):
    from PIL import Image

    # Decode the base64 string
    video_bytes = base64.b64decode(video_base64)

    # Placeholder for the start indices of each PNG image
    img_starts = []

    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))

    assert frame_format in [
        "PNG",
        "JPEG",
    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"

    if frame_format == "PNG":
        # Find each PNG start signature to isolate images
        i = 0
        while i < len(video_bytes) - 7:  # Adjusted for the length of the PNG signature
            # Check if we found the start of a PNG file
            if (
                video_bytes[i] == 0x89
                and video_bytes[i + 1] == 0x50
                and video_bytes[i + 2] == 0x4E
                and video_bytes[i + 3] == 0x47
                and video_bytes[i + 4] == 0x0D
                and video_bytes[i + 5] == 0x0A
                and video_bytes[i + 6] == 0x1A
                and video_bytes[i + 7] == 0x0A
            ):
                img_starts.append(i)
                i += 8  # Skip the PNG signature
            else:
                i += 1
    else:
        # Find each JPEG start (0xFFD8) to isolate images
        i = 0
        while (
            i < len(video_bytes) - 1
        ):  # Adjusted for the length of the JPEG SOI signature
            # Check if we found the start of a JPEG file
            if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
                img_starts.append(i)
                # Move to the next byte to continue searching for the next image start
                i += 2
            else:
                i += 1

    frames = []
    for start_idx in img_starts:
        # Assuming each image is back-to-back, the end of one image is the start of another
        # The last image goes until the end of the byte string
        end_idx = (
            img_starts[img_starts.index(start_idx) + 1]
            if img_starts.index(start_idx) + 1 < len(img_starts)
            else len(video_bytes)
        )
        img_bytes = video_bytes[start_idx:end_idx]

        # Convert bytes to a PIL Image
        img = Image.open(BytesIO(img_bytes))

        # Convert PIL Image to a NumPy array
        frame = np.array(img)

        # Append the frame to the list of frames
        frames.append(frame)

    # Ensure there's at least one frame to avoid errors with np.stack
    if frames:
        return np.stack(frames, axis=0), img.size
    else:
        return np.array([]), (
            0,
            0,
        )  # Return an empty array and size tuple if no frames were found
Lianmin Zheng's avatar
Lianmin Zheng committed
420
421


422
def load_image(image_file: Union[str, bytes]):
    """Load an image from bytes, a URL, a file path, a data URL or base64 text.

    Forms are tried in this order:
      * ``bytes``: encoded image bytes.
      * ``"http://..."`` / ``"https://..."``: fetched with ``requests``; the
        timeout comes from the ``REQUEST_TIMEOUT`` env var (default "3").
      * a string ending in png/jpg/jpeg/webp/gif (case-insensitive): opened
        as a file path.
      * ``"data:...,<base64>"``: the text after the first comma is decoded.
      * ``"video:<base64>"``: passed to ``decode_video_base64``.
      * any other ``str``: decoded as raw base64.

    Returns:
        ``(image, image_size)``. ``image_size`` is only set for ``video:``
        input and is ``None`` otherwise.

    Raises:
        ValueError: If ``image_file`` is neither ``str`` nor ``bytes``.
    """
    if not isinstance(image_file, (str, bytes)):
        # Validate up front: other types previously failed with AttributeError
        # on .startswith, and the old message printed ``image`` (always None).
        raise ValueError(f"Invalid image: {image_file}")

    from PIL import Image

    image = image_size = None

    if isinstance(image_file, bytes):
        image = Image.open(BytesIO(image_file))
    elif image_file.startswith(("http://", "https://")):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, timeout=timeout)
        image = Image.open(BytesIO(response.content))
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    elif image_file.startswith("data:"):
        image_file = image_file.split(",")[1]
        image = Image.open(BytesIO(base64.b64decode(image_file)))
    elif image_file.startswith("video:"):
        image_file = image_file[len("video:") :]
        image, image_size = decode_video_base64(image_file)
    else:
        image = Image.open(BytesIO(base64.b64decode(image_file)))

    return image, image_size
449
450


451
452
453
454
def suppress_other_loggers():
    """Quiet noisy vllm loggers and one NumPy writability warning.

    Sets vllm's default logger and two vllm device-communicator loggers to
    WARN, and ignores the "The given NumPy array is not writable" UserWarning.
    """
    from vllm.logger import logger as vllm_default_logger

    vllm_default_logger.setLevel(logging.WARN)
    for noisy_logger in (
        "vllm.distributed.device_communicators.pynccl",
        "vllm.distributed.device_communicators.shm_broadcast",
    ):
        logging.getLogger(noisy_logger).setLevel(logging.WARN)

    warnings.filterwarnings(
        "ignore", category=UserWarning, message="The given NumPy array is not writable"
    )

466

467
def assert_pkg_version(pkg: str, min_version: str, message: str):
    """Ensure ``pkg`` is installed with a version of at least ``min_version``.

    Raises:
        Exception: If the package is not installed or is older than
            ``min_version``; ``message`` is appended to the error text.
    """
    try:
        installed_version = version(pkg)
    except PackageNotFoundError:
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed. "
            + message
        )

    too_old = pkg_version.parse(installed_version) < pkg_version.parse(min_version)
    if too_old:
        raise Exception(
            f"{pkg} is installed with version {installed_version}, which "
            f"is less than the minimum required version {min_version}. " + message
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
480
481


482
483
484
485
486
def kill_process_tree(
    parent_pid, include_parent: bool = True, skip_pid: Optional[int] = None
):
    """Kill the process and all its child processes.

    Args:
        parent_pid: PID whose descendants (recursively) are killed. ``None``
            means the current process; in that case the current process
            itself is never killed, regardless of ``include_parent``.
        include_parent: Also kill ``parent_pid`` itself.
        skip_pid: PID of one descendant that is not sent a kill (its own
            descendants are still killed, since they appear in the recursive
            child list).

    Processes that disappear before they are killed are ignored.
    """
    if parent_pid is None:
        parent_pid = os.getpid()
        include_parent = False

    try:
        itself = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        return

    for child in itself.children(recursive=True):
        if child.pid == skip_pid:
            continue
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass

    if include_parent:
        try:
            itself.kill()

            # Sometimes processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
            # so we send an additional signal to kill them.
            itself.send_signal(signal.SIGQUIT)
        except psutil.NoSuchProcess:
            pass


513
def monkey_patch_p2p_access_check():
    """
    Monkey patch the slow p2p access check.

    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
    """
    import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt

    # Report every device pair as p2p-capable without running the real check.
    # Patched before custom_all_reduce is imported below, matching the
    # original order.
    tgt.gpu_p2p_access_check = lambda *arg, **kwargs: True

    # Suppress the warnings from this delete function when using sglang.bench_one_batch
    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
        CustomAllreduce,
    )

    CustomAllreduce.__del__ = lambda *args, **kwargs: None

530

531
532
533
534
535
536
537
def monkey_patch_vllm_gguf_config():
    """Make vllm's ``GGUFConfig.get_quant_method`` recognise sglang layer classes.

    After patching, sglang ``LinearBase`` layers get ``GGUFLinearMethod``,
    sglang ``VocabParallelEmbedding`` layers get ``GGUFEmbeddingMethod``, and
    any other layer gets ``None``.
    """
    from vllm.model_executor.layers.quantization.gguf import (
        GGUFConfig,
        GGUFEmbeddingMethod,
        GGUFLinearMethod,
    )

    from sglang.srt.layers.linear import LinearBase
    from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding

    def get_quant_method_with_embedding_replaced(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        # Checked in this order; the embedding entry targets sglang's own
        # VocabParallelEmbedding rather than vllm's.
        for layer_cls, method_cls in (
            (LinearBase, GGUFLinearMethod),
            (VocabParallelEmbedding, GGUFEmbeddingMethod),
        ):
            if isinstance(layer, layer_cls):
                return method_cls(self)
        return None

    GGUFConfig.get_quant_method = get_quant_method_with_embedding_replaced


554
555
556
557
558
559
def maybe_set_triton_cache_manager() -> None:
    """Point Triton at sglang's CustomCacheManager unless one is already configured.

    Only sets ``TRITON_CACHE_MANAGER`` when it is absent from the environment.
    """
    if os.environ.get("TRITON_CACHE_MANAGER", None) is not None:
        return
    manager = "sglang.srt.utils:CustomCacheManager"
    logger.debug("Setting Triton cache manager to: %s", manager)
    os.environ["TRITON_CACHE_MANAGER"] = manager


class CustomCacheManager(FileCacheManager):
    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
    """Triton file cache manager whose default cache directory is per process.

    The default (non-dump, non-override) directory is ``TRITON_CACHE_DIR`` (or
    Triton's default) suffixed with ``_<pid>``, so each process writes to its
    own cache tree.
    """

    def __init__(self, key, override=False, dump=False):
        # FileCacheManager.__init__ is not called; all state is set here.
        self.key = key
        self.lock_path = None

        if dump:
            self.cache_dir = os.path.join(default_dump_dir(), self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
            return

        if override:
            self.cache_dir = os.path.join(default_override_dir(), self.key)
            return

        # create cache directory if it doesn't exist
        base_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
        if not base_dir:
            raise RuntimeError("Could not create or locate cache dir")
        per_process_dir = f"{base_dir}_{os.getpid()}"
        self.cache_dir = os.path.join(per_process_dir, self.key)
        self.lock_path = os.path.join(self.cache_dir, "lock")
        os.makedirs(self.cache_dir, exist_ok=True)


592
593
594
595
596
597
598
599
def set_ulimit(target_soft_limit=65535):
    """Raise the soft RLIMIT_NOFILE (max open file descriptors) to ``target_soft_limit``.

    Nothing is changed when the current soft limit is already at least the
    target. The hard limit is left as is. Failures are logged as warnings
    instead of raised.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except (ValueError, OSError) as e:
            # ValueError: invalid limits (e.g. soft above the hard limit);
            # OSError: the underlying setrlimit system call failed. Both were
            # meant to be best-effort; previously OSError escaped.
            logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")
601
602


603
def add_api_key_middleware(app, api_key: str):
    """Require ``Authorization: Bearer <api_key>`` on HTTP requests to ``app``.

    ``OPTIONS`` requests (e.g. CORS preflight) and paths starting with
    ``/health`` pass through without a key. Any other request whose header is
    missing or wrong receives a 401 JSON response.
    """
    import hmac

    expected = ("Bearer " + api_key).encode("utf-8")

    @app.middleware("http")
    async def authentication(request, call_next):
        if request.method == "OPTIONS":
            return await call_next(request)
        if request.url.path.startswith("/health"):
            return await call_next(request)
        provided = request.headers.get("Authorization")
        # Constant-time comparison so response timing does not reveal how much
        # of the key matched. Compare bytes: compare_digest rejects non-ASCII str.
        if provided is None or not hmac.compare_digest(
            provided.encode("utf-8"), expected
        ):
            return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
        return await call_next(request)
613
614


615
def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
    """Resolve model/tokenizer locations, downloading from ModelScope if enabled.

    When ``SGLANG_USE_MODELSCOPE`` is enabled and ``model_path`` does not exist
    locally, both ids are fetched with ``modelscope.snapshot_download``. For the
    tokenizer, ``*.bin`` and ``*.safetensors`` files are skipped. The local
    snapshot paths are returned. Otherwise the inputs are returned unchanged.
    """
    use_modelscope = get_bool_env_var("SGLANG_USE_MODELSCOPE")
    if not use_modelscope or os.path.exists(model_path):
        return model_path, tokenizer_path

    from modelscope import snapshot_download

    model_path = snapshot_download(model_path)
    tokenizer_path = snapshot_download(
        tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
    )
    return model_path, tokenizer_path
625
626
627
628


def configure_logger(server_args, prefix: str = ""):
    """(Re)configure the root logger.

    Args:
        server_args: Object whose ``log_level`` attribute names a ``logging``
            level (matched case-insensitively, e.g. "info").
        prefix: Text inserted right after the timestamp in every record.

    ``force=True`` removes and closes any handlers already attached to the
    root logger before installing the new one.
    """
    # Named log_format so the ``format`` builtin is not shadowed.
    log_format = f"[%(asctime)s{prefix}] %(message)s"
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format=log_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
636
637
638
639
640
641
642
643
644
645
646


# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Replace a submodule in a model with a new module.

    ``module_name`` is a dotted path (e.g. ``"layers.0.mlp"``). Its last
    component is set on the parent module, which is the model itself when
    there is no dot. Returns ``new_module``.
    """
    parent_name, _, target_name = module_name.rpartition(".")
    parent = model.get_submodule(parent_name)
    setattr(parent, target_name, new_module)
    return new_module
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666


def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: Optional[Dict[str, Any]],
):
    """Attach extra attributes to a weight tensor.

    Every key/value pair in ``weight_attrs`` is set on ``weight`` with
    ``setattr``. Existing attributes are never overwritten: an
    ``AssertionError`` is raised if ``weight`` already has one of the keys.
    Passing ``None`` does nothing.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
    """
    if weight_attrs is not None:
        for attr_name, attr_value in weight_attrs.items():
            assert not hasattr(
                weight, attr_name
            ), f"Overwriting existing tensor attribute: {attr_name}"
            setattr(weight, attr_name, attr_value)
667
668
669


def broadcast_pyobj(
    data: List[Any],
    rank: int,
    dist_group: Optional[torch.distributed.ProcessGroup] = None,
    src: int = 0,
):
    """Broadcast a picklable list from the sending caller to the rest of ``dist_group``.

    The caller passing ``rank == 0`` is the sender. It broadcasts the pickled
    payload length and then, if the list is non-empty, the payload bytes. It
    returns ``data`` unchanged. ``src`` is the rank given to
    ``torch.distributed.broadcast`` as the source.

    Every other caller receives the length, allocates a matching uint8 buffer,
    receives the payload, and returns the unpickled list. It returns ``[]``
    when the length is 0.

    The broadcast calls must be issued in the same order on every rank.

    NOTE(review): receivers ``pickle.loads`` whatever the sender broadcasts, so
    all ranks in the group must be trusted.
    """
    size_tensor = torch.tensor([0], dtype=torch.long)

    if rank == 0:
        if len(data) == 0:
            dist.broadcast(size_tensor, src=src, group=dist_group)
            return data
        payload = pickle.dumps(data)
        payload_tensor = torch.ByteTensor(np.frombuffer(payload, dtype=np.uint8))
        size_tensor = torch.tensor([len(payload)], dtype=torch.long)
        dist.broadcast(size_tensor, src=src, group=dist_group)
        dist.broadcast(payload_tensor, src=src, group=dist_group)
        return data

    # Receiving side: learn the payload length first.
    dist.broadcast(size_tensor, src=src, group=dist_group)
    num_bytes = size_tensor.item()
    if num_bytes == 0:
        return []

    payload_tensor = torch.empty(num_bytes, dtype=torch.uint8)
    dist.broadcast(payload_tensor, src=src, group=dist_group)
    return pickle.loads(bytes(payload_tensor.cpu().numpy()))
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736


# Number of completed pytorch_profile calls; used as the index in trace file names.
step_counter = 0


def pytorch_profile(name, func, *args, data_size=-1):
    """
    Run ``func(*args)`` under the torch profiler and export a Chrome trace.

    Writes ``trace/size_<step>.json`` (containing ``{"size": data_size}``) and
    ``trace/<name>_<step>.json``, relative to the working directory, then
    increments the global step counter.

    Args:
        name (string): the name of recorded function.
        func: the function to be profiled.
        args: the arguments of the profiled function.
        data_size (int): some measurement of the computation complexity.
            Usually, it could be the batch size.

    Returns:
        The value returned by ``func(*args)``.
    """
    global step_counter
    step = step_counter
    os.makedirs("trace", exist_ok=True)

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    with profile(
        activities=activities,
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        with record_function(name):
            with open(f"trace/size_{step}.json", "w") as size_file:
                json.dump({"size": data_size}, size_file)
            result = func(*args)

    prof.export_chrome_trace(f"trace/{name}_{step}.json")
    step_counter = step + 1
    return result
737
738
739
740
741
742
743


def first_rank_print(*args, **kwargs):
    """``print`` only in the process whose current CUDA device index is 0."""
    if torch.cuda.current_device() != 0:
        return
    print(*args, **kwargs)
744
745


Lianmin Zheng's avatar
Lianmin Zheng committed
746
747
748
def get_zmq_socket(
    context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
):
    """Create a PUSH or PULL ZeroMQ socket and bind or connect it to ``endpoint``.

    The socket's high-water mark is set to 0. Its kernel buffer size is set to
    0.5 GiB when the host has more than 32 GiB total and more than 16 GiB
    available RAM, and to -1 otherwise. (Per the ZeroMQ docs, HWM 0 means "no
    limit" and buffer -1 means "OS default".)

    Raises:
        ValueError: For socket types other than ``zmq.PUSH`` / ``zmq.PULL``.
    """
    if socket_type == zmq.PUSH:
        hwm_opt, buf_opt = zmq.SNDHWM, zmq.SNDBUF
    elif socket_type == zmq.PULL:
        hwm_opt, buf_opt = zmq.RCVHWM, zmq.RCVBUF
    else:
        # Validate before creating the socket; the old code raised after
        # context.socket() and leaked the socket.
        raise ValueError(f"Unsupported socket type: {socket_type}")

    mem = psutil.virtual_memory()
    total_mem = mem.total / 1024**3
    available_mem = mem.available / 1024**3
    if total_mem > 32 and available_mem > 16:
        buf_size = int(0.5 * 1024**3)
    else:
        buf_size = -1

    # Named ``sock`` so the stdlib ``socket`` module is not shadowed.
    sock = context.socket(socket_type)
    try:
        sock.setsockopt(hwm_opt, 0)
        sock.setsockopt(buf_opt, buf_size)
        if bind:
            sock.bind(endpoint)
        else:
            sock.connect(endpoint)
    except Exception:
        # Do not leak the socket if configuration or bind/connect fails.
        sock.close(linger=0)
        raise

    return sock
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813


def dump_to_file(dirpath, name, value):
    """Save tensor ``value`` to ``<dirpath>/pytorch_dump_<name>.npy`` on TP rank 0 only.

    bfloat16 tensors are converted with ``.float()`` before ``.numpy()``.
    """
    from vllm.distributed import get_tensor_model_parallel_rank

    if get_tensor_model_parallel_rank() == 0:
        os.makedirs(dirpath, exist_ok=True)
        if value.dtype is torch.bfloat16:
            value = value.float()
        array = value.cpu().numpy()
        output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
        logger.info(f"Dump a tensor to {output_filename}. Shape = {array.shape}")
        np.save(output_filename, array)


def is_triton_3():
    """Return True when the installed Triton's version string starts with "3."."""
    version_str = triton.__version__
    return version_str[:2] == "3."


def maybe_torch_compile(*args, **kwargs):
    """
    Decorator factory: wrap a function with ``torch.compile(*args, **kwargs)``
    only when Triton 3.x is installed.

    torch.compile does not work with triton 2.2.0, which is required by xlm1's
    jax, so on any non-3.x Triton the function is returned unchanged.
    """

    def decorator(func):
        if not is_triton_3():
            return func
        return torch.compile(*args, **kwargs)(func)

    return decorator


def delete_directory(dirpath):
    """Recursively remove ``dirpath``; on OSError print a warning instead of raising."""
    try:
        shutil.rmtree(dirpath)  # removes the directory and everything inside it
    except OSError as err:
        print(f"Warning: {dirpath} : {err.strerror}")
Lianmin Zheng's avatar
Lianmin Zheng committed
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839


# Temporary directory backing prometheus_client's multiprocess mode.
# This line is only a type annotation: the name is not bound until
# set_prometheus_multiproc_dir() assigns it via `global`. Holding the
# TemporaryDirectory at module level keeps it alive; the directory is cleaned
# up automatically when this object is garbage collected.
prometheus_multiproc_dir: tempfile.TemporaryDirectory


def set_prometheus_multiproc_dir():
    """Create the temporary directory used by prometheus_client multiprocess mode.

    sglang runs prometheus_client in multiprocess mode, which requires
    PROMETHEUS_MULTIPROC_DIR to be set before prometheus_client is imported.
    See https://prometheus.github.io/client_python/multiprocess/

    If PROMETHEUS_MULTIPROC_DIR is unset, a fresh temporary directory is
    created and exported through that variable. If it is already set, a
    temporary directory is created *inside* it and the variable is left as is.
    Either way the object is kept in the module-level
    ``prometheus_multiproc_dir`` so it stays alive.
    """
    global prometheus_multiproc_dir

    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    else:
        logger.debug("User set PROMETHEUS_MULTIPROC_DIR detected.")
        # NOTE(review): the env var keeps pointing at the user's parent
        # directory, not this new subdirectory; confirm that is intended.
        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
        )
    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")


def add_prometheus_middleware(app):
    """Mount a multiprocess-aware Prometheus ``/metrics`` ASGI app on ``app``."""
    # prometheus_client may only be imported after PROMETHEUS_MULTIPROC_DIR is
    # set, hence the function-local import.
    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

    registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(registry)
    metrics_app = make_asgi_app(registry=registry)
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics: accept "/metrics" with or
    # without anything after it.
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
850
851


852
853
854
855
856
857
858
859
860
def bind_port(port):
    """Bind a TCP socket to ``port`` on all IPv4 interfaces and start listening.

    SO_REUSEADDR is enabled to allow address reuse. The caller owns the
    returned socket and must close it. If binding or listening fails, the
    socket is closed before the error propagates, so nothing leaks.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allows address reuse
        sock.bind(("", port))
        sock.listen(1)
    except BaseException:
        sock.close()
        raise
    return sock


HAI's avatar
HAI committed
861
862
863
864
def get_amdgpu_memory_capacity():
    """Return the smallest GPU memory pool size reported by ``rocminfo``, in MiB.

    A shell pipeline extracts the "Size:" value of "Pool 1" for each gfx agent
    listed by ``rocminfo``. Each value has any "(...)" suffix removed and is
    divided by 1024 to give MiB.

    Returns:
        The minimum of the parsed values.

    Raises:
        RuntimeError: if the pipeline exits with a non-zero status, or the
            shell itself cannot be launched.
        ValueError: if no memory values are found. This is also the outcome
            when rocminfo is not installed: a shell pipeline's exit status is
            that of its last command (awk), so the failure surfaces as empty
            output rather than a non-zero return code.
    """
    try:
        # Run rocminfo and capture the output
        result = subprocess.run(
            [
                "rocminfo | grep 'gfx' -A 100 | grep 'Pool 1' -A 5 | grep 'Size:' | awk '{print $2}'"
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
            text=True,
        )
        if result.returncode != 0:
            raise RuntimeError(f"rocminfo error: {result.stderr.strip()}")

        # Parse the output to extract memory values in MiB. Blank lines are
        # skipped; otherwise empty output would hit float("") and raise an
        # unhelpful "could not convert" error instead of the message below.
        memory_values = [
            float(line.split("(")[0].strip()) / 1024
            for line in result.stdout.strip().split("\n")
            if line.strip()
        ]

        if not memory_values:
            raise ValueError("No GPU memory values found.")

        # Return the minimum memory value
        return min(memory_values)

    except FileNotFoundError:
        raise RuntimeError(
            "rocminfo not found. Ensure AMD ROCm drivers are installed and accessible."
        )


def get_nvgpu_memory_capacity():
    """Return the smallest ``memory.total`` reported by nvidia-smi across visible GPUs.

    nvidia-smi is queried in CSV form without header or units. Only lines that
    are purely numeric (optionally with a decimal part) are kept.

    Raises:
        RuntimeError: if nvidia-smi is missing or exits with a non-zero status.
        ValueError: if no numeric memory values are found.
    """
    command = [
        "nvidia-smi",
        "--query-gpu=memory.total",
        "--format=csv,noheader,nounits",
    ]
    numeric_line = re.compile(r"^\d+(\.\d+)?$")
    try:
        result = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )

        if result.returncode != 0:
            raise RuntimeError(f"nvidia-smi error: {result.stderr.strip()}")

        lines = result.stdout.strip().split("\n")
        memory_values = [
            float(line) for line in lines if numeric_line.match(line.strip())
        ]

        if not memory_values:
            raise ValueError("No GPU memory values found.")

        return min(memory_values)

    except FileNotFoundError:
        raise RuntimeError(
            "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
        )
924
925


926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
def get_hpu_memory_capacity():
    """Return the smallest "Total" memory value reported by ``hl-smi``, in MiB.

    Runs ``hl-smi --query | grep 'Total'`` through the shell. From each line
    the second-to-last whitespace-separated token is parsed as the amount;
    presumably lines look like "Total : 98304 MiB" (NOTE(review): confirm
    against your hl-smi version).

    Raises:
        RuntimeError: if the pipeline exits with a non-zero status (for
            example grep finding no "Total" line), or the shell cannot start.
        ValueError: if no memory values are found.
    """
    try:
        # Run hl-smi and capture the output
        result = subprocess.run(
            ["hl-smi --query | grep 'Total'"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
            text=True,
        )

        if result.returncode != 0:
            raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")

        # Parse the output to extract memory values in MiB. str.split() (not
        # split(" ")) tolerates tabs, repeated and trailing whitespace; blank
        # lines are skipped.
        memory_values = [
            float(line.split()[-2])
            for line in result.stdout.strip().split("\n")
            if line.strip()
        ]

        if not memory_values:
            raise ValueError("No GPU memory values found.")

        # Return the minimum memory value
        return min(memory_values)

    except FileNotFoundError:
        raise RuntimeError(
            "hl-smi not found. Ensure Habana drivers are installed and accessible."
        )


957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
def init_custom_process_group(
    backend=None,
    init_method=None,
    timeout=None,
    world_size=-1,
    rank=-1,
    store=None,
    group_name=None,
    pg_options=None,
):
    from torch.distributed.distributed_c10d import (
        Backend,
        PrefixStore,
        _new_process_group_helper,
        _world,
        default_pg_timeout,
        rendezvous,
    )

    assert (store is None) or (
        init_method is None
    ), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # We need to determine the appropriate parameter name based on PyTorch version
    pg_options_param_name = (
        "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
    )
    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg


1029
1030
def crash_on_warnings():
    """Return True when warnings should be fatal, i.e. when running CI tests.

    CI is detected through the SGLANG_IS_IN_CI environment variable.
    """
    is_ci = get_bool_env_var("SGLANG_IS_IN_CI")
    return is_ci
1032
1033


1034
1035
1036
1037
1038
@lru_cache(maxsize=None)
def print_warning_once(msg: str) -> None:
    """Log ``msg`` at WARNING level the first time it is seen; repeats are no-ops.

    Deduplication is keyed on ``msg`` through ``functools.lru_cache``, so
    ``msg`` must be hashable. Distinct messages are remembered for the life
    of the process.
    """
    # stacklevel=2 attributes the record to the caller. In CPython the
    # lru_cache wrapper is implemented in C and adds no Python frame.
    logger.warning(msg, stacklevel=2)


1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
def get_device_name(device_id: int = 0) -> Optional[str]:
    """Return the name of accelerator ``device_id``, or None if none is available.

    Backends are probed in the order CUDA, XPU, HPU; the first available one
    answers.
    """
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        return torch.cuda.get_device_name(device_id)

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.get_device_name(device_id)

    if hasattr(torch, "hpu") and torch.hpu.is_available():
        return torch.hpu.get_device_name(device_id)

    # No supported accelerator: the return type is Optional[str], and None
    # is returned explicitly rather than by falling off the end.
    return None


1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
def get_device_capability(
    device_id: int = 0,
) -> Tuple[Optional[int], Optional[int]]:
    """Return ``(major, minor)`` compute capability of ``device_id``.

    CUDA, then XPU, then HPU are checked; a later available backend overwrites
    an earlier result. XPU's dotted "version" string is parsed into ints.
    Returns ``(None, None)`` when no supported accelerator is available.

    Raises:
        RuntimeError: if querying an available HPU fails.
    """
    major, minor = None, None
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(device_id)

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
            "."
        )
        major, minor = int(major), int(minor)

    # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
    # Update this once the support is available.
    if hasattr(torch, "hpu") and torch.hpu.is_available():
        try:
            major, minor = torch.hpu.get_device_capability(device_id)
        except Exception as e:
            raise RuntimeError(
                f"An error occurred while getting device capability of hpu: {e}."
            ) from e

    return major, minor


1074
1075
1076
1077
1078
1079
1080
def get_compiler_backend() -> str:
    """Return the torch.compile backend: "hpu_backend" on an available HPU, else "inductor"."""
    hpu_ready = hasattr(torch, "hpu") and torch.hpu.is_available()
    return "hpu_backend" if hpu_ready else "inductor"


1081
1082
1083
# Torch library fragment for the "sglang" operator namespace.
# direct_register_custom_op() registers ops here when no target_lib is given.
# An operator's lifetime is tied to this object, so it must stay alive while
# those ops are in use.
sglang_lib = Library("sglang", "FRAGMENT")  # noqa


1084
1085
1086
1087
1088
1089
# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def supports_custom_op() -> bool:
    """Return True if this torch build exposes ``torch.library.custom_op``."""
    try:
        torch.library.custom_op
    except AttributeError:
        return False
    return True


1090
1091
1092
1093
1094
1095
1096
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: List[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
):
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the vLLM library. If you
    want to register it to a different library, you can pass the library
    object to the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.
    """
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
    import torch.library

    if hasattr(torch.library, "infer_schema"):
        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    else:
        # for pytorch 2.4
        import torch._custom_op.impl

        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)

    my_lib = target_lib or sglang_lib
    my_lib.define(op_name + schema_str)
    my_lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
1127
1128


1129
def set_gpu_proc_affinity(
    tp_size: int,
    nnodes: int,
    gpu_id: int,
):
    """Pin the current process to a contiguous slice of CPU cores for ``gpu_id``.

    Physical cores are split evenly across the ``tp_size // nnodes`` GPUs of
    this node (more cores than GPUs is assumed). The slice start wraps modulo
    the physical core count, which lets several DP replicas share a node.
    When logical and physical core counts differ (hyper-threading on), CPU
    ``core + physical_core_count`` is bound for each selected core as well.
    NOTE(review): that assumes sibling threads are numbered this way, which
    is typical on Linux x86 but should be confirmed per host.
    """
    proc_id = os.getpid()
    proc = psutil.Process(proc_id)

    gpus_per_node = tp_size // nnodes
    physical_cores = psutil.cpu_count(logical=False)
    cores_per_gpu = physical_cores // gpus_per_node

    first_core = (gpu_id * cores_per_gpu) % physical_cores
    core_ids = list(range(first_core, first_core + cores_per_gpu))

    if psutil.cpu_count() != psutil.cpu_count(logical=False):
        # HT on: also bind each core's sibling logical CPU.
        core_ids = core_ids + [core + physical_cores for core in core_ids]
    # HT off: physical cores only (core_ids unchanged).

    proc.cpu_affinity(core_ids)
    logger.info(
        f"Process {proc_id} gpu_id {gpu_id} is running on CPUs: {proc.cpu_affinity()}"
    )
1161
1162
1163
1164
1165


def get_bool_env_var(name: str, default: str = "false") -> bool:
    """Interpret environment variable ``name`` as a boolean.

    Only "true" and "1" (case-insensitive) count as True. ``default`` is used
    when the variable is unset.
    """
    raw_value = os.getenv(name, default)
    return raw_value.lower() in {"true", "1"}
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206


@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
    """Count CUDA/ROCm devices via NVML (or amdsmi on ROCm) when possible.

    ``cuda_visible_devices`` is not read; it is a parameter only so the LRU
    cache keys on the CUDA_VISIBLE_DEVICES value the caller passes in.
    Returns 0 when torch was built without CUDA. If the stateless query
    reports a negative count, ``torch._C._cuda_getDeviceCount()`` is used.
    """
    # Based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.version

    if not torch.cuda._is_compiled():
        return 0

    if not is_hip():
        raw_count = torch.cuda._device_count_nvml()
    elif hasattr(torch.cuda, "_device_count_amdsmi"):
        # ROCm uses amdsmi instead of nvml for the stateless count; this
        # requires a sufficiently modern Torch (2.4.0).
        raw_count = torch.cuda._device_count_amdsmi()
    else:
        raw_count = -1

    if raw_count < 0:
        return torch._C._cuda_getDeviceCount()
    return raw_count


# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/utils.py
def cuda_device_count_stateless() -> int:
    """Return the number of CUDA devices, cached per CUDA_VISIBLE_DEVICES value.

    Use this instead of torch.cuda.device_count() unless CUDA_VISIBLE_DEVICES
    has already been set to the desired value.
    """
    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    return _cuda_device_count_stateless(visible_devices)
1207
1208


1209
1210
1211
1212
def dataclass_to_string_truncated(data, max_length=2048):
    """Render ``data`` as a string, eliding the middle of long strings and sequences.

    - ``str`` longer than ``max_length``: repr of the first and last
      ``max_length // 2`` characters joined by " ... ".
    - ``list``/``tuple`` longer than ``max_length``: str of the first and last
      ``max_length // 2`` items joined by " ... " (items are not recursed into).
    - ``dict``: values rendered recursively; keys are always shown in quotes.
    - dataclass instances: ``ClassName(field=value, ...)`` with values rendered
      recursively.
    - anything else (including dataclass classes): ``str(data)``.
    """
    if isinstance(data, str):
        if len(data) > max_length:
            half_length = max_length // 2
            # Explicit start index: data[-0:] would be the whole string when
            # half_length is 0 (max_length of 0 or 1).
            tail = data[len(data) - half_length :]
            return f"{repr(data[:half_length])} ... {repr(tail)}"
        else:
            return f"{repr(data)}"
    elif isinstance(data, (list, tuple)):
        if len(data) > max_length:
            half_length = max_length // 2
            tail = data[len(data) - half_length :]
            return str(data[:half_length]) + " ... " + str(tail)
        else:
            return str(data)
    elif isinstance(data, dict):
        return (
            "{"
            + ", ".join(
                f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
                for k, v in data.items()
            )
            + "}"
        )
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        # is_dataclass() is also True for dataclass *classes*, which have no
        # per-instance field values; those fall through to str() below.
        fields = dataclasses.fields(data)
        return (
            f"{data.__class__.__name__}("
            + ", ".join(
                f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
                for f in fields
            )
            + ")"
        )
    else:
        return str(data)
Tanjiro's avatar
Tanjiro committed
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304


# Tool-call markers, one per response format handled by parse_tool_response():
# internlm2 ("<|plugin|>"), llama3.1 ("<function="), qwen2.5 ("<tool_call>")
# and llama3.2 ("<|python_tag|>").
TOOLS_TAG_LIST = ["<|plugin|>", "<function=", "<tool_call>", "<|python_tag|>"]


def parse_tool_response(text, tools, **kwargs):
    """Parse model response containing tool information.

    Formats are detected in this order:
      - internlm2: ``<|action_start|><|plugin|>{json}<|action_end|>``
      - llama3.1: ``<function=NAME>{json}</function>``
      - qwen2.5: one or more ``<tool_call>{json}</tool_call>`` blocks
      - llama3.2: ``<|python_tag|>{json}``

    Args:
        text(str): model response in string format
        tools(List): tools from user request; each must expose
            ``tool.function.name``
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        ``(text, call_info_list)``. For internlm2 and qwen2.5, ``text`` is the
        text outside the tool markup. For llama3.1/3.2, the input text is
        returned unchanged. ``call_info_list`` contains
        ``(tool_index, tool_name, arguments_json)`` tuples, where
        ``tool_index`` is the first position of ``tool_name`` in ``tools``.

    Raises:
        RuntimeError: if no supported tool-call marker is found.
        ValueError: if a called tool is not in ``tools``, or the markup does
            not split as expected.
    """
    if "<|plugin|>" in text:  # internlm2
        text, action = text.split("<|action_start|><|plugin|>")
        action = action.split("<|action_end|>")[0]
        action = action[action.find("{") :]
        action = json.loads(action)
        name, parameters = action["name"], json.dumps(
            action.get("parameters", action.get("arguments", {})), ensure_ascii=False
        )
        call_info_list = [(name, parameters)]
    elif "<function=" in text:  # llama3.1
        action, _ = text.split("</function>")
        parameters = action[action.find("{") :]
        name = action.split("<function=")[1].split(">{")[0]
        call_info_list = [(name, parameters)]
    elif "<tool_call>" in text and "</tool_call>" in text:  # qwen2.5
        # get tool_call in text
        pattern = r"<tool_call>(.*?)</tool_call>"
        match_result_list = re.findall(pattern, text, re.DOTALL)
        call_info_list = []
        for match_result in match_result_list:
            action = json.loads(match_result)
            call_info_list.append(
                (action["name"], json.dumps(action["arguments"], ensure_ascii=False))
            )
        # get text outside of tags
        if not text.startswith("<tool_call>"):
            text = text[: text.find("<tool_call>")]
        elif not text.endswith("</tool_call>"):
            text = text[text.rfind("</tool_call>") + len("</tool_call>") :]
        else:
            text = ""
    elif "<|python_tag|>" in text:  # llama3.2
        _, action = text.split("<|python_tag|>")
        action = json.loads(action)
        name, parameters = action["name"], json.dumps(
            action.get("parameters", action.get("arguments", {})), ensure_ascii=False
        )
        call_info_list = [(name, parameters)]
    else:
        raise RuntimeError(f"Unexpected model response: {text}")

    # Materialise the tool names once instead of once per call
    # (O(calls * tools)). This also keeps working when `tools` is a one-shot
    # iterable. `tools` is not touched at all when there are no calls.
    tool_names = [tool.function.name for tool in tools] if call_info_list else []
    call_info_list = [
        (tool_names.index(name), name, parameters)
        for name, parameters in call_info_list
    ]
    return text, call_info_list
1305
1306


1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
def permute_weight(x: torch.Tensor) -> torch.Tensor:
    b_ = x.shape[0]
    n_ = x.shape[1]
    k_ = x.shape[2]

    x_ = x
    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
    else:
        return x_

    x_ = x_.permute(0, 1, 3, 4, 2, 5)
    x_ = x_.contiguous()
    x_ = x_.view(*x.shape)
    return x_


1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
class MultiprocessingSerializer:
    """Serialize objects to bytes with ``multiprocessing``'s ForkingPickler.

    ForkingPickler applies any reducers registered through
    ``multiprocessing.reduction``.

    SECURITY: ``deserialize`` unpickles, which can execute arbitrary code.
    Only feed it bytes produced by trusted processes.
    """

    @staticmethod
    def serialize(obj):
        with io.BytesIO() as stream:
            ForkingPickler(stream).dump(obj)
            return stream.getvalue()

    @staticmethod
    def deserialize(data):
        return ForkingPickler.loads(data)
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360


def debug_timing(func):
    """Decorator that logs CUDA-event timing and token throughput at DEBUG level.

    When the module logger is enabled for DEBUG, the call is bracketed by CUDA
    events and followed by ``torch.cuda.synchronize()``. The elapsed
    milliseconds and a tokens/s throughput are then logged; ``len(indices)``
    counts the tokens, with ``indices`` taken from the keyword argument or,
    failing that, the second positional argument. Otherwise ``func`` is
    called directly. The wrapper preserves ``func``'s name and docstring.
    """
    from functools import wraps

    # todo: replace with a more organized instrumentation
    @wraps(func)
    def wrapper(*args, **kwargs):
        if logger.isEnabledFor(logging.DEBUG):
            tic = torch.cuda.Event(enable_timing=True)
            toc = torch.cuda.Event(enable_timing=True)
            tic.record()
            result = func(*args, **kwargs)
            toc.record()
            torch.cuda.synchronize()  # Ensure all CUDA operations are complete
            elapsed = tic.elapsed_time(toc)
            indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
            num_tokens = len(indices) if indices is not None else 0
            throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
            logger.debug(
                f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
            )
            return result
        else:
            return func(*args, **kwargs)

    return wrapper
bjmsong's avatar
bjmsong committed
1361
1362
1363
1364
1365
1366


def nullable_str(val: str):
    """Return None for empty/None values or the literal string "None"; else ``val``."""
    is_null = (not val) or val == "None"
    return None if is_null else val
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377


def set_uvicorn_logging_configs():
    """Prefix uvicorn's "default" and "access" log formats with a bracketed timestamp.

    Mutates the module-level ``LOGGING_CONFIG`` dict in place (presumably
    uvicorn's); both formatters get the date format ``%Y-%m-%d %H:%M:%S``.
    """
    timestamp_format = "%Y-%m-%d %H:%M:%S"
    formatters = LOGGING_CONFIG["formatters"]

    default_formatter = formatters["default"]
    default_formatter["fmt"] = "[%(asctime)s] %(levelprefix)s %(message)s"
    default_formatter["datefmt"] = timestamp_format

    access_formatter = formatters["access"]
    access_formatter["fmt"] = (
        '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
    )
    access_formatter["datefmt"] = timestamp_format
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444


def get_ip() -> str:
    """Best-effort guess of this host's IP address.

    Order of precedence:
      1. ``SGLANG_HOST_IP``, then ``HOST_IP`` (first non-empty value);
      2. the local address the OS picks for a UDP socket "connected" to a
         public IPv4 address, then to a public IPv6 address (the target does
         not need to be reachable);
      3. ``"0.0.0.0"``, with a warning.
    """
    # An explicit environment override wins; set one if detection is wrong.
    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface: IPv4 first,
    # then IPv6 (Google's public DNS server, see
    # https://developers.google.com/speed/public-dns/docs/using#addresses).
    probes = (
        (socket.AF_INET, ("8.8.8.8", 80)),
        (socket.AF_INET6, ("2001:4860:4860::8888", 80)),
    )
    for family, target in probes:
        try:
            # `with` guarantees the probe socket is closed on every path.
            with socket.socket(family, socket.SOCK_DGRAM) as probe:
                probe.connect(target)  # Doesn't need to be reachable
                return probe.getsockname()[0]
        except Exception:
            # Best effort: fall through to the next probe / default.
            pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable"
        " SGLANG_HOST_IP or HOST_IP.",
        stacklevel=2,
    )
    return "0.0.0.0"


def get_open_port() -> int:
    """Return a TCP port number that is currently free to bind on this host.

    If ``SGLANG_PORT`` is set, ports are tried upward from that value and the
    first that binds on IPv4 is returned. If none up to 65535 is free, or the
    variable is unset, an OS-assigned ephemeral port is returned (IPv4 first,
    then IPv6).

    The probe socket is closed before returning, so another process could
    still take the port before the caller binds it.
    """
    port = os.getenv("SGLANG_PORT")
    if port is not None:
        # Environment values are strings; socket.bind() needs an int port.
        port = int(port)
        while port <= 65535:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("", port))
                    return port
            except OSError:
                port += 1  # Increment port number if already in use
                logger.info("Port %d is already in use, trying port %d", port - 1, port)
        logger.warning(
            "No free port found at or above SGLANG_PORT; using an OS-assigned port."
        )
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def is_valid_ipv6_address(address: str) -> bool:
    """Return True if ``address`` parses as an IPv6 address, else False."""
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        return False
    return True