# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common utilities."""

import base64
import builtins
import ctypes
import dataclasses
import importlib
import io
import ipaddress
import itertools
import json
import logging
import logging.config
import os
import pickle
import random
import re
import resource
import shutil
import signal
import socket
import subprocess
import sys
import tempfile
import threading
import time
import traceback
import warnings
from contextlib import contextmanager
from enum import Enum
from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version
from importlib.util import find_spec
from io import BytesIO
from json import JSONDecodeError
from multiprocessing.reduction import ForkingPickler
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    Protocol,
    Set,
    Tuple,
    TypeVar,
    Union,
)

import numpy as np
import psutil
import requests
import torch
import torch.distributed
import torch.distributed as dist
import triton
import zmq
from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version
from PIL import Image
from starlette.routing import Mount
from torch import nn
from torch.func import functional_call
from torch.library import Library
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils._contextlib import _DecoratorContextManager
from triton.runtime.cache import (
    FileCacheManager,
    default_cache_dir,
    default_dump_dir,
    default_override_dir,
)

logger = logging.getLogger(__name__)

show_time_cost = False
time_infos = {}

HIP_FP8_E4M3_FNUZ_MAX = 224.0

_warned_bool_env_var_keys = set()


def get_bool_env_var(name: str, default: str = "false") -> bool:
    value = os.getenv(name, default)
    value = value.lower()

    truthy_values = ("true", "1")
    falsy_values = ("false", "0")

    if (value not in truthy_values) and (value not in falsy_values):
        if value not in _warned_bool_env_var_keys:
            logger.warning(
                f"get_bool_env_var({name}) received unrecognized value={value!r}; treating it as false"
            )
        _warned_bool_env_var_keys.add(value)

    return value in truthy_values


def get_int_env_var(name: str, default: int = 0) -> int:
    value = os.getenv(name)
    if value is None or not value.strip():
        return default
    try:
        return int(value)
    except ValueError:
        return default
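
# Usage sketch for the env helpers above (the variable names here are made up
# for illustration and are not real SGLang settings):
#
#   os.environ["SGLANG_EXAMPLE_FLAG"] = "1"
#   get_bool_env_var("SGLANG_EXAMPLE_FLAG")        # -> True
#   get_bool_env_var("SGLANG_MISSING_FLAG")        # unset -> default "false" -> False
#   get_int_env_var("SGLANG_EXAMPLE_WORKERS", 4)   # unset or non-integer -> 4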


# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
def is_hip() -> bool:
    return torch.version.hip is not None


if is_hip():
    FP8_E4M3_MAX = HIP_FP8_E4M3_FNUZ_MAX
else:
    FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

FP8_E4M3_MIN = -FP8_E4M3_MAX

builtins.FP8_E4M3_MAX = FP8_E4M3_MAX
builtins.FP8_E4M3_MIN = FP8_E4M3_MIN


def is_cuda():
    return torch.cuda.is_available() and torch.version.cuda


def is_cuda_alike():
    return is_cuda() or is_hip()


def is_hpu() -> bool:
    return hasattr(torch, "hpu") and torch.hpu.is_available()


def is_xpu() -> bool:
    return hasattr(torch, "xpu") and torch.xpu.is_available()


def is_npu() -> bool:
    return hasattr(torch, "npu") and torch.npu.is_available()


def is_flashinfer_available():
    """
    Check whether flashinfer is available.
    As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
    """
    if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
        return False
    return importlib.util.find_spec("flashinfer") is not None and is_cuda()


_ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
    "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
)


class DynamicGradMode(_DecoratorContextManager):
    """
    Behave as either torch.no_grad or torch.inference_mode, selected by the
    SGLANG_ENABLE_TORCH_INFERENCE_MODE environment variable; see those two
    context managers for the exact semantics.
    """

    @staticmethod
    def set_inference_mode(mode: bool):
        if isinstance(mode, bool):
            global _ENABLE_TORCH_INFERENCE_MODE

            _ENABLE_TORCH_INFERENCE_MODE = mode
        else:
            logger.warning("mode is not a boolean object")

    def __init__(self, mode=True):
        if not torch._jit_internal.is_scripting():
            super().__init__()
        if _ENABLE_TORCH_INFERENCE_MODE:
            self.mode = mode
        else:
            self.prev = False

    def __new__(cls, mode_or_orig_func=True if _ENABLE_TORCH_INFERENCE_MODE else None):
        if mode_or_orig_func is None or isinstance(mode_or_orig_func, bool):
            return super().__new__(cls)
        return cls()(mode_or_orig_func)

    def __enter__(self) -> None:
        if _ENABLE_TORCH_INFERENCE_MODE:
            self._inference_mode_context = torch._C._InferenceMode(self.mode)
            self._inference_mode_context.__enter__()
        else:
            self.prev = torch.is_grad_enabled()
            torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        if _ENABLE_TORCH_INFERENCE_MODE:
            self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
        else:
            torch.set_grad_enabled(self.prev)

    def clone(self) -> "DynamicGradMode":
        r"""
        Create a copy of this class
        """
        if _ENABLE_TORCH_INFERENCE_MODE:
            return self.__class__(self.mode)
        else:
            return self.__class__()
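
# A minimal usage sketch of DynamicGradMode (illustrative only):
#
#   @DynamicGradMode()
#   def run_decode_step(model, batch):
#       # Runs under torch.inference_mode if SGLANG_ENABLE_TORCH_INFERENCE_MODE
#       # is truthy, otherwise under torch.no_grad.
#       return model(batch)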


def enable_show_time_cost():
    global show_time_cost
    show_time_cost = True


class TimeInfo:
    def __init__(self, name, interval=0.1, color=0, indent=0):
        self.name = name
        self.interval = interval
        self.color = color
        self.indent = indent

        self.acc_time = 0
        self.last_acc_time = 0

    def check(self):
        if self.acc_time - self.last_acc_time > self.interval:
            self.last_acc_time = self.acc_time
            return True
        return False

    def pretty_print(self):
        print(f"\x1b[{self.color}m", end="")
        print("-" * self.indent * 2, end="")
        print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m")


def mark_start(name, interval=0.1, color=0, indent=0):
    global time_infos, show_time_cost
    if not show_time_cost:
        return
    torch.cuda.synchronize()
    if time_infos.get(name, None) is None:
        time_infos[name] = TimeInfo(name, interval, color, indent)
    time_infos[name].acc_time -= time.perf_counter()


def mark_end(name):
    global time_infos, show_time_cost
    if not show_time_cost:
        return
    torch.cuda.synchronize()
    time_infos[name].acc_time += time.perf_counter()
    if time_infos[name].check():
        time_infos[name].pretty_print()


def calculate_time(show=False, min_cost_ms=0.0):
    def wrapper(func):
        def inner_func(*args, **kwargs):
            torch.cuda.synchronize()
            if show:
                start_time = time.perf_counter()
            result = func(*args, **kwargs)
            torch.cuda.synchronize()
            if show:
                cost_time = (time.perf_counter() - start_time) * 1000
                if cost_time > min_cost_ms:
                    print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result

        return inner_func

    return wrapper
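
# Usage sketch for the timing helpers above (illustrative only):
#
#   @calculate_time(show=True, min_cost_ms=1.0)
#   def heavy_kernel(x):
#       return x @ x
#
# mark_start/mark_end can wrap arbitrary regions once enable_show_time_cost()
# has been called, e.g. mark_start("prefill"); ...; mark_end("prefill").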


def get_available_gpu_memory(
    device, gpu_id, distributed=False, empty_cache=True, cpu_group=None
):
    """
    Get available memory for cuda:gpu_id device.
    When distributed is True, the available memory is the minimum available memory of all GPUs.
    """
    if device == "cuda":
        num_gpus = torch.cuda.device_count()
        assert gpu_id < num_gpus

        if torch.cuda.current_device() != gpu_id:
            print(
                f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
                "which may cause useless memory allocation for torch CUDA context.",
            )

        if empty_cache:
            torch.cuda.empty_cache()
        free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

    elif device == "xpu":
        num_gpus = torch.xpu.device_count()
        assert gpu_id < num_gpus

        if torch.xpu.current_device() != gpu_id:
            print(
                f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
                "which may cause useless memory allocation for torch XPU context.",
            )

        if empty_cache:
            torch.xpu.empty_cache()
        used_memory = torch.xpu.memory_allocated()
        total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

    elif device == "hpu":
        num_gpus = torch.hpu.device_count()
        assert gpu_id < num_gpus

        if torch.hpu.current_device() != gpu_id:
            print(
                f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
                "which may cause useless memory allocation for torch HPU context.",
            )

        free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()

    elif device == "cpu":
        # TODO: rename the variables in the current function to be not GPU specific
        free_gpu_memory = psutil.virtual_memory().available
    elif device == "npu":
        num_gpus = torch.npu.device_count()
        assert gpu_id < num_gpus

        if torch.npu.current_device() != gpu_id:
            print(
                f"WARNING: current device is not {gpu_id}, but {torch.npu.current_device()}, ",
                "which may cause useless memory allocation for torch NPU context.",
            )
        free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()

    if distributed:
        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32)
        torch.distributed.all_reduce(
            tensor, op=torch.distributed.ReduceOp.MIN, group=cpu_group
        )
        free_gpu_memory = tensor.item()

    return free_gpu_memory / (1 << 30)


def is_pin_memory_available() -> bool:
    return torch.cuda.is_available()


_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    _CPU_OFFLOAD_BYTES = 0
    _CPU_OFFLOAD_MAX_BYTES = max_bytes


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    device = next(module.parameters()).device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()
    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(
            size=p.data.size(),
            stride=p.data.stride(),
            dtype=p.data.dtype,
            layout=p.data.layout,
            device="cpu",
            pin_memory=pin_memory,
        )
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            module.forward = original_forward
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in module.state_dict().items()
            }
            output = functional_call(module, device_state, args=args, kwargs=kwargs)
            module.forward = forward
            return output

        module.forward = forward

    return module
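
# Sketch of how the CPU-offload helpers above fit together (illustrative only;
# `my_module` is a hypothetical submodule):
#
#   set_cpu_offload_max_bytes(2 * 1024**3)   # allow up to ~2 GiB of offloaded params
#   my_module = maybe_offload_to_cpu(my_module)
#   # Parameters now live in (pinned) CPU memory; each forward() call moves the
#   # needed tensors back to the original device via functional_call.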


class LayerFn(Protocol):

    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...


def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    pp_rank: Optional[int] = None,
    pp_size: Optional[int] = None,
    prefix: str = "",
    return_tuple: bool = False,
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function"""
    # Imported here to avoid circular imports.
    from sglang.srt.distributed import get_pp_indices
    from sglang.srt.layers.utils import PPMissingLayer

    assert not pp_size or num_hidden_layers >= pp_size
    start_layer, end_layer = (
        get_pp_indices(
            num_hidden_layers,
            pp_rank,
            pp_size,
        )
        if pp_rank is not None and pp_size is not None
        else (0, num_hidden_layers)
    )
    modules = torch.nn.ModuleList(
        [PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
        + [
            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
            for idx in range(start_layer, end_layer)
        ]
        + [
            PPMissingLayer(return_tuple=return_tuple)
            for _ in range(end_layer, num_hidden_layers)
        ]
    )
    if pp_rank is None or pp_size is None:
        return modules
    return modules, start_layer, end_layer
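
# Sketch of make_layers with pipeline parallelism (illustrative only; the layer
# constructor and the pp_rank/pp_size variables below are hypothetical):
#
#   layers, start, end = make_layers(
#       num_hidden_layers=32,
#       layer_fn=lambda idx, prefix: MyDecoderLayer(config, layer_id=idx, prefix=prefix),
#       pp_rank=pp_rank,
#       pp_size=pp_size,
#       prefix="model.layers",
#   )
#   # Each rank materializes only layers [start, end); the rest are PPMissingLayer stubs.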


def set_random_seed(seed: int) -> None:
    """Set the random seed for all libraries."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def is_port_available(port):
    """Return whether a port is available."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("", port))
            s.listen(1)
            return True
        except socket.error:
            return False
        except OverflowError:
            return False


def decode_video_base64(video_base64):
    from PIL import Image

    # Decode the base64 string
    video_bytes = base64.b64decode(video_base64)

    # Placeholder for the start indices of each PNG image
    img_starts = []

    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))

    assert frame_format in [
        "PNG",
        "JPEG",
    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"

    if frame_format == "PNG":
        # Find each PNG start signature to isolate images
        i = 0
        while i < len(video_bytes) - 7:  # Adjusted for the length of the PNG signature
            # Check if we found the start of a PNG file
            if (
                video_bytes[i] == 0x89
                and video_bytes[i + 1] == 0x50
                and video_bytes[i + 2] == 0x4E
                and video_bytes[i + 3] == 0x47
                and video_bytes[i + 4] == 0x0D
                and video_bytes[i + 5] == 0x0A
                and video_bytes[i + 6] == 0x1A
                and video_bytes[i + 7] == 0x0A
            ):
                img_starts.append(i)
                i += 8  # Skip the PNG signature
            else:
                i += 1
    else:
        # Find each JPEG start (0xFFD8) to isolate images
        i = 0
        while (
            i < len(video_bytes) - 1
        ):  # Adjusted for the length of the JPEG SOI signature
            # Check if we found the start of a JPEG file
            if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
                img_starts.append(i)
                # Move to the next byte to continue searching for the next image start
                i += 2
            else:
                i += 1

    frames = []
    for start_idx in img_starts:
        # Assuming each image is back-to-back, the end of one image is the start of another
        # The last image goes until the end of the byte string
        end_idx = (
            img_starts[img_starts.index(start_idx) + 1]
            if img_starts.index(start_idx) + 1 < len(img_starts)
            else len(video_bytes)
        )
        img_bytes = video_bytes[start_idx:end_idx]

        # Convert bytes to a PIL Image
        img = Image.open(BytesIO(img_bytes))

        # Convert PIL Image to a NumPy array
        frame = np.array(img)

        # Append the frame to the list of frames
        frames.append(frame)

    # Ensure there's at least one frame to avoid errors with np.stack
    if frames:
        return np.stack(frames, axis=0), img.size
    else:
        return np.array([]), (
            0,
            0,
        )  # Return an empty array and size tuple if no frames were found


def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
    # Use soundfile here, since librosa uses it under the hood,
    # and librosa will not support audio loading in the future
    import soundfile as sf
    from scipy.signal import resample

    # Load audio data
    if isinstance(audio_file, bytes):
        audio, original_sr = sf.read(BytesIO(audio_file))
    elif audio_file.startswith("data:"):
        audio_file = audio_file.split(",")[1]
        audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
    elif audio_file.startswith("http://") or audio_file.startswith("https://"):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
        response = requests.get(audio_file, stream=True, timeout=timeout)
        audio_file = BytesIO(response.content)
        response.close()
        audio, original_sr = sf.read(audio_file)
    elif isinstance(audio_file, str):
        audio, original_sr = sf.read(audio_file)
    else:
        raise ValueError(f"Invalid audio format: {audio_file}")

    # Resample audio if the original sample rate is different from the desired sample rate
    if original_sr != sr:
        num_samples = int(len(audio) * float(sr) / original_sr)
        audio = resample(audio, num_samples)

    # Convert to mono if requested and audio is stereo
    if mono and len(audio.shape) > 1:
        audio = np.mean(audio, axis=1)

    return audio
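
# Usage sketch (illustrative only; the file path is hypothetical):
#
#   waveform = load_audio("sample.wav", sr=16000, mono=True)
#   # -> 1-D float numpy array resampled to 16 kHz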


def encode_video(video_path, frame_count_limit=None):
    # Lazy import because decord is not available on some arm platforms.
    from decord import VideoReader, cpu

    if not os.path.exists(video_path):
        logger.error(f"Video {video_path} does not exist")
        return []

    if frame_count_limit == 0:
        return []

    def uniform_sample(l, n):
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    vr = VideoReader(video_path, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
    frame_indices = [i for i in range(0, len(vr), sample_fps)]
    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
        frame_indices = uniform_sample(frame_indices, frame_count_limit)

    frames = vr.get_batch(frame_indices).asnumpy()
    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
    return frames


def load_image(
    image_file: Union[Image.Image, str, bytes],
) -> tuple[Image.Image, tuple[int, int]]:
    image = image_size = None
    if isinstance(image_file, Image.Image):
        image = image_file
        image_size = (image.width, image.height)
    elif isinstance(image_file, bytes):
        image = Image.open(BytesIO(image_file))
    elif image_file.startswith("http://") or image_file.startswith("https://"):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, stream=True, timeout=timeout).raw
        image = Image.open(response)
        response.close()
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    elif image_file.startswith("data:"):
        image_file = image_file.split(",")[1]
        image = Image.open(BytesIO(base64.b64decode(image_file)))
    elif image_file.startswith("video:"):
        image_file = image_file.replace("video:", "")
        image, image_size = decode_video_base64(image_file)
    elif isinstance(image_file, str):
        image = Image.open(BytesIO(base64.b64decode(image_file)))
    else:
        raise ValueError(f"Invalid image: {image_file}")

    return image, image_size


def suppress_other_loggers():
    warnings.filterwarnings(
        "ignore", category=UserWarning, message="The given NumPy array is not writable"
    )

    try:
        from vllm.logger import logger as vllm_default_logger
    except ImportError:
        return

    vllm_default_logger.setLevel(logging.WARN)
    logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
        logging.WARN
    )
    logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel(
        logging.WARN
    )
    logging.getLogger("vllm.config").setLevel(logging.ERROR)


def assert_pkg_version(pkg: str, min_version: str, message: str):
    try:
        installed_version = version(pkg)
        if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
            raise Exception(
                f"{pkg} is installed with version {installed_version}, which "
                f"is less than the minimum required version {min_version}. " + message
            )
    except PackageNotFoundError:
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed. "
            + message
        )


def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
    """Kill the process and all its child processes."""
    # Remove sigchld handler to avoid spammy logs.
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGCHLD, signal.SIG_DFL)

    if parent_pid is None:
        parent_pid = os.getpid()
        include_parent = False

    try:
        itself = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        return

    children = itself.children(recursive=True)
    for child in children:
        if child.pid == skip_pid:
            continue
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass

    if include_parent:
        try:
            if parent_pid == os.getpid():
                itself.kill()
                sys.exit(0)

            itself.kill()

            # Sometimes processes cannot be killed with SIGKILL (e.g., PID=1 launched by kubernetes),
            # so we send an additional signal to kill them.
            itself.send_signal(signal.SIGQUIT)
        except psutil.NoSuchProcess:
            pass


def monkey_patch_p2p_access_check():
    """
    Monkey patch the slow p2p access check.
    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
    """

    import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt

    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)

    # Suppress the warnings from this delete function when using sglang.bench_one_batch
    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
        CustomAllreduce,
    )

    setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)


def monkey_patch_vllm_gguf_config():
    try:
        from vllm.model_executor.layers.quantization.gguf import (
            GGUFConfig,
            GGUFEmbeddingMethod,
            GGUFLinearMethod,
        )
    except ImportError:
        return

    from sglang.srt.layers.linear import LinearBase
    from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding

    def get_quant_method_with_embedding_replaced(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return GGUFLinearMethod(self)
        elif isinstance(layer, VocabParallelEmbedding):
            # patch to own VocabParallelEmbedding
            return GGUFEmbeddingMethod(self)
        return None

    setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)


def maybe_set_triton_cache_manager() -> None:
    """Set environment variable to tell Triton to use a
    custom cache manager"""
    cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
    if cache_manager is None:
        manager = "sglang.srt.utils:CustomCacheManager"
        logger.debug("Setting Triton cache manager to: %s", manager)
        os.environ["TRITON_CACHE_MANAGER"] = manager


class CustomCacheManager(FileCacheManager):
    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
    def __init__(self, key, override=False, dump=False):

        self.key = key
        self.lock_path = None
        if dump:
            self.cache_dir = default_dump_dir()
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            self.cache_dir = default_override_dir()
            self.cache_dir = os.path.join(self.cache_dir, self.key)
        else:
            # create cache directory if it doesn't exist
            self.cache_dir = (
                os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
            )
            if self.cache_dir:
                self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
                self.cache_dir = os.path.join(self.cache_dir, self.key)
                self.lock_path = os.path.join(self.cache_dir, "lock")
                os.makedirs(self.cache_dir, exist_ok=True)
            else:
                raise RuntimeError("Could not create or locate cache dir")


def set_ulimit(target_soft_limit=65535):
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            logger.warning(f"Failed to set RLIMIT_NOFILE: {e}")


def add_api_key_middleware(app, api_key: str):
    @app.middleware("http")
    async def authentication(request, call_next):
        if request.method == "OPTIONS":
            return await call_next(request)
        if request.url.path.startswith("/health"):
            return await call_next(request)
        if request.url.path.startswith("/metrics"):
            return await call_next(request)
        if request.headers.get("Authorization") != "Bearer " + api_key:
            return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
        return await call_next(request)


def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
    if get_bool_env_var("SGLANG_USE_MODELSCOPE"):
        if not os.path.exists(model_path):
            from modelscope import snapshot_download

            model_path = snapshot_download(model_path)
            tokenizer_path = snapshot_download(
                tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
            )
    return model_path, tokenizer_path


def configure_logger(server_args, prefix: str = ""):
    if SGLANG_LOGGING_CONFIG_PATH := os.getenv("SGLANG_LOGGING_CONFIG_PATH"):
        if not os.path.exists(SGLANG_LOGGING_CONFIG_PATH):
            raise Exception(
                "Setting SGLANG_LOGGING_CONFIG_PATH from env with "
                f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
            )
        with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
            custom_config = json.loads(file.read())
        logging.config.dictConfig(custom_config)
        return
    format = f"[%(asctime)s{prefix}] %(message)s"
    # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format=format,
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )


# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Replace a submodule in a model with a new module."""
    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
    target_name = module_name.split(".")[-1]
    setattr(parent, target_name, new_module)
    return new_module


def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: Optional[Dict[str, Any]],
):
    """Set attributes on a weight tensor.

    This method is used to set attributes on a weight tensor. This method
    will not overwrite existing attributes.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
    """
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
        setattr(weight, key, value)
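
# Usage sketch (illustrative only; `my_loader` is a hypothetical callable):
#
#   w = torch.nn.Parameter(torch.empty(16, 16))
#   set_weight_attrs(w, {"weight_loader": my_loader})
#   # w.weight_loader is now my_loader; setting the same key again would assert.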


def broadcast_pyobj(
    data: List[Any],
    rank: int,
    dist_group: Optional[torch.distributed.ProcessGroup] = None,
    src: int = 0,
    force_cpu_device: bool = True,
):
    """Broadcast inputs from src rank to all other ranks with torch.dist backend.
    The `rank` argument here is the caller's rank on the global process group
    (regardless of the dist_group argument).
    """
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
    )

    if rank == src:
        if len(data) == 0:
            tensor_size = torch.tensor([0], dtype=torch.long, device=device)
            dist.broadcast(tensor_size, src=src, group=dist_group)
        else:
            serialized_data = pickle.dumps(data)
            size = len(serialized_data)

            tensor_data = torch.ByteTensor(
                np.frombuffer(serialized_data, dtype=np.uint8)
            ).to(device)
            tensor_size = torch.tensor([size], dtype=torch.long, device=device)

            dist.broadcast(tensor_size, src=src, group=dist_group)
            dist.broadcast(tensor_data, src=src, group=dist_group)
        return data
    else:
        tensor_size = torch.tensor([0], dtype=torch.long, device=device)
        dist.broadcast(tensor_size, src=src, group=dist_group)
        size = tensor_size.item()

        if size == 0:
            return []

        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
        dist.broadcast(tensor_data, src=src, group=dist_group)

        serialized_data = bytes(tensor_data.cpu().numpy())
        data = pickle.loads(serialized_data)
        return data
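
# Usage sketch (illustrative only; `my_requests` and `cpu_group` are hypothetical):
# every rank calls this with its own global rank; only the payload passed on the
# src rank matters, and all other ranks receive that payload.
#
#   objs = broadcast_pyobj(my_requests, rank=dist.get_rank(), dist_group=cpu_group, src=0)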


def point_to_point_pyobj(
    data: List[Any],
    rank: int,
    group: Optional[torch.distributed.ProcessGroup] = None,
    src: int = 0,
    dst: int = 1,
):
    """Send data from src to dst in group."""

    if rank == src:
        if len(data) == 0:
            tensor_size = torch.tensor([0], dtype=torch.long)
            dist.send(tensor_size, dst=dst, group=group)
        else:
            serialized_data = pickle.dumps(data)
            size = len(serialized_data)
            tensor_data = torch.ByteTensor(
                np.frombuffer(serialized_data, dtype=np.uint8)
            )
            tensor_size = torch.tensor([size], dtype=torch.long)

            dist.send(tensor_size, dst=dst, group=group)
            dist.send(tensor_data, dst=dst, group=group)
        return data

    elif rank == dst:
        tensor_size = torch.tensor([0], dtype=torch.long)
        dist.recv(tensor_size, src=src, group=group)
        size = tensor_size.item()

        if size == 0:
            return []

        tensor_data = torch.empty(size, dtype=torch.uint8)
        dist.recv(tensor_data, src=src, group=group)

        serialized_data = bytes(tensor_data.cpu().numpy())
        data = pickle.loads(serialized_data)
        return data

    # Other ranks in pp_group do nothing
    return []


step_counter = 0


def pytorch_profile(name, func, *args, data_size=-1):
    """
    Args:
        name (string): the name of the recorded function.
        func: the function to be profiled.
        args: the arguments of the profiled function.
        data_size (int): some measurement of the computation complexity.
            Usually, it could be the batch size.
    """
    global step_counter
    os.makedirs("trace", exist_ok=True)
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        # on_trace_ready=tensorboard_trace_handler('./log_dir'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        with record_function(name):
            with open(f"trace/size_{step_counter}.json", "w") as f:
                json.dump({"size": data_size}, f)
            result = func(*args)
    prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
    step_counter += 1
    return result
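
# Usage sketch (illustrative only; `decode_step` and `batch` are hypothetical):
#
#   out = pytorch_profile("decode_step", decode_step, batch, data_size=len(batch))
#   # Writes trace/decode_step_<step>.json plus a matching trace/size_<step>.json.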


def get_zmq_socket(
    context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
):
    mem = psutil.virtual_memory()
    total_mem = mem.total / 1024**3
    available_mem = mem.available / 1024**3
    if total_mem > 32 and available_mem > 16:
        buf_size = int(0.5 * 1024**3)
    else:
        buf_size = -1

    socket = context.socket(socket_type)
    if endpoint.find("[") != -1:
        socket.setsockopt(zmq.IPV6, 1)

    def set_send_opt():
        socket.setsockopt(zmq.SNDHWM, 0)
        socket.setsockopt(zmq.SNDBUF, buf_size)

    def set_recv_opt():
        socket.setsockopt(zmq.RCVHWM, 0)
        socket.setsockopt(zmq.RCVBUF, buf_size)

    if socket_type == zmq.PUSH:
        set_send_opt()
    elif socket_type == zmq.PULL:
        set_recv_opt()
    elif socket_type == zmq.DEALER:
        set_send_opt()
        set_recv_opt()
    else:
        raise ValueError(f"Unsupported socket type: {socket_type}")

    if bind:
        socket.bind(endpoint)
    else:
        socket.connect(endpoint)

    return socket
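
# Usage sketch (illustrative only; the endpoint is hypothetical):
#
#   ctx = zmq.Context()
#   recv_sock = get_zmq_socket(ctx, zmq.PULL, "ipc:///tmp/sglang_example", bind=True)
#   send_sock = get_zmq_socket(ctx, zmq.PUSH, "ipc:///tmp/sglang_example", bind=False)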


def dump_to_file(dirpath, name, value):
    from sglang.srt.distributed import get_tensor_model_parallel_rank

    if get_tensor_model_parallel_rank() != 0:
        return

    os.makedirs(dirpath, exist_ok=True)
    if value.dtype is torch.bfloat16:
        value = value.float()
    value = value.cpu().numpy()
    output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
    logger.info(f"Dump a tensor to {output_filename}. Shape = {value.shape}")
    np.save(output_filename, value)


def is_triton_3():
    return triton.__version__.startswith("3.")


def maybe_torch_compile(*args, **kwargs):
    """
    torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax.
    Therefore, we disable it here.
    """

    def decorator(func):
        if is_triton_3():
            return torch.compile(*args, **kwargs)(func)
        return func

    return decorator


def delete_directory(dirpath):
    try:
        # This will remove the directory and all its contents
        shutil.rmtree(dirpath)
    except OSError as e:
        print(f"Warning: {dirpath} : {e.strerror}")


# Temporary directory for prometheus multiprocess mode
# Cleaned up automatically when this object is garbage collected
prometheus_multiproc_dir: tempfile.TemporaryDirectory


def set_prometheus_multiproc_dir():
    # Set prometheus multiprocess directory
    # sglang uses prometheus multiprocess mode
    # we need to set this before importing prometheus_client
    # https://prometheus.github.io/client_python/multiprocess/
    global prometheus_multiproc_dir

    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.debug("User set PROMETHEUS_MULTIPROC_DIR detected.")
        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
        )
    else:
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")


def add_prometheus_middleware(app):
    # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

    registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(registry)
    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)


def bind_port(port):
    """Bind to a specific port, assuming it's available."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allows address reuse
    sock.bind(("", port))
    sock.listen(1)
    return sock


def get_amdgpu_memory_capacity():
    try:
        # Run rocminfo and capture the output
        result = subprocess.run(
            [
                "rocminfo | grep 'gfx' -A 100 | grep 'Pool 1' -A 5 | grep 'Size:' | awk '{print $2}'"
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
            text=True,
        )
        if result.returncode != 0:
            raise RuntimeError(f"rocm-smi error: {result.stderr.strip()}")

        # Parse the output to extract memory values in MiB
        memory_values = [
            float(mem.split("(")[0].strip()) / 1024
            for mem in result.stdout.strip().split("\n")
        ]

        if not memory_values:
            raise ValueError("No GPU memory values found.")

        # Return the minimum memory value
        return min(memory_values)

    except FileNotFoundError:
        raise RuntimeError(
            "rocm-smi not found. Ensure AMD ROCm drivers are installed and accessible."
        )


def get_device_sm():
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return major * 10 + minor
    return 0


def get_nvgpu_memory_capacity():
    try:
        # Run nvidia-smi and capture the output
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )

        if result.returncode != 0:
            raise RuntimeError(f"nvidia-smi error: {result.stderr.strip()}")

        # Parse the output to extract memory values
        memory_values = [
            float(mem)
            for mem in result.stdout.strip().split("\n")
            if re.match(r"^\d+(\.\d+)?$", mem.strip())
        ]

        if not memory_values:
            raise ValueError("No GPU memory values found.")

        # Return the minimum memory value
        return min(memory_values)

    except FileNotFoundError:
        raise RuntimeError(
            "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
        )


def get_hpu_memory_capacity():
    try:
        # Run hl-smi and capture the output
        result = subprocess.run(
            ["hl-smi --query | grep 'Total'"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
            text=True,
        )

        if result.returncode != 0:
            raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")

        # Parse the output to extract memory values in MiB
        memory_values = [
            float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n")
        ]

        if not memory_values:
            raise ValueError("No GPU memory values found.")

        # Return the minimum memory value
        return min(memory_values)

    except FileNotFoundError:
        raise RuntimeError(
            "hl-smi not found. Ensure Habana drivers are installed and accessible."
        )


def get_device_memory_capacity(device: str = None):
    if is_cuda():
        gpu_mem = get_nvgpu_memory_capacity()
    elif is_hip():
        gpu_mem = get_amdgpu_memory_capacity()
    elif device == "hpu":
        gpu_mem = get_hpu_memory_capacity()
    else:
        # GPU memory is not known yet or no GPU is available.
        gpu_mem = None

    return gpu_mem


# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
def init_custom_process_group(
    backend=None,
    init_method=None,
    timeout=None,
    world_size=-1,
    rank=-1,
    store=None,
    group_name=None,
    pg_options=None,
):
    from torch.distributed.distributed_c10d import (
        Backend,
        PrefixStore,
        _new_process_group_helper,
        _world,
        default_pg_timeout,
        rendezvous,
    )

    assert (store is None) or (
        init_method is None
    ), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # We need to determine the appropriate parameter name based on PyTorch version
    pg_options_param_name = (
        "backend_options"
        if pkg_version.parse(str(torch.__version__)).release >= (2, 6)
        else "pg_options"
    )
    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg
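
# Usage sketch (illustrative only; the address, world size, and ranks are made up):
#
#   pg = init_custom_process_group(
#       backend="nccl",
#       init_method="tcp://127.0.0.1:29500",
#       world_size=2,
#       rank=0,
#       group_name="weight_sync",
#   )
#   # `pg` can then be passed as the `group` argument to torch.distributed collectives.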


def crash_on_warnings():
    # Crash on warning if we are running CI tests
    return get_bool_env_var("SGLANG_IS_IN_CI")


def print_warning_once(msg: str) -> None:
    # Set the stacklevel to 2 to print the caller's line info
    logger.warning(msg, stacklevel=2)


def get_device_name(device_id: int = 0) -> str:
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        return torch.cuda.get_device_name(device_id)

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.get_device_name(device_id)

    if hasattr(torch, "hpu") and torch.hpu.is_available():
        return torch.hpu.get_device_name(device_id)

    if hasattr(torch, "npu") and torch.npu.is_available():
        return torch.npu.get_device_name(device_id)


@lru_cache(maxsize=1)
def is_habana_available() -> bool:
    return find_spec("habana_frameworks") is not None


@lru_cache(maxsize=8)
def get_device(device_id: Optional[int] = None) -> str:
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        if device_id is None:
            return "cuda"
        return "cuda:{}".format(device_id)

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        if device_id is None:
            return "xpu"
        return "xpu:{}".format(device_id)

    if is_habana_available():
        try:
            import habana_frameworks.torch.hpu

            if torch.hpu.is_available():
                if device_id is None:
                    return "hpu"
                return "hpu:{}".format(device_id)
        except ImportError as e:
            raise ImportError(
                "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
            ) from e

    raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")


@lru_cache(maxsize=1)
def get_device_count() -> int:
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        try:
            return torch.cuda.device_count()
        except RuntimeError:
            return 0

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        try:
            return torch.xpu.device_count()
        except RuntimeError:
            return 0

    if is_habana_available():
        try:
            import habana_frameworks.torch.hpu

            if torch.hpu.is_available():
                return torch.hpu.device_count()
        except (ImportError, RuntimeError):
            return 0

    return 0  # No accelerators available


def get_device_core_count(device_id: int = 0) -> int:
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        return torch.cuda.get_device_properties(device_id).multi_processor_count

    return 0


def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
    major, minor = None, None
    if hasattr(torch, "cuda") and torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(device_id)

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
            "."
        )
        major, minor = int(major), int(minor)

    if hasattr(torch, "hpu") and torch.hpu.is_available():
        try:
            # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
            # Update this once the support is available.
            # major, minor = torch.hpu.get_device_capability(device_id)
            major, minor = None, None
        except Exception as e:
            raise RuntimeError(
                f"An error occurred while getting device capability of hpu: {e}."
            ) from e

    return major, minor


def get_compiler_backend() -> str:
    if hasattr(torch, "hpu") and torch.hpu.is_available():
        return "hpu_backend"

    if hasattr(torch, "npu") and torch.npu.is_available():
        import torchair

        config = torchair.CompilerConfig()
        npu_backend = torchair.get_npu_backend(compiler_config=config)
        return npu_backend

    return "inductor"


sglang_lib = Library("sglang", "FRAGMENT")  # noqa


# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def supports_custom_op() -> bool:
    return hasattr(torch.library, "custom_op")


def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: List[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the SGLang library. If you
    want to register it to a different library, you can pass the library
    object to the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.
    """
    import torch.library

    if hasattr(torch.library, "infer_schema"):
        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    else:
        # for pytorch 2.4
        import torch._custom_op.impl

        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)

    my_lib = target_lib or sglang_lib
    my_lib.define(op_name + schema_str)
    my_lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
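
# Illustrative usage sketch (hypothetical op name and tensors, shown only as
# comments; not part of the original module):
#
#     def _scale(x: torch.Tensor, factor: float) -> torch.Tensor:
#         return x * factor
#
#     def _scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
#         return torch.empty_like(x)
#
#     direct_register_custom_op("scale", _scale, mutates_args=[], fake_impl=_scale_fake)
#     y = torch.ops.sglang.scale(torch.ones(4, device="cuda"), 2.0)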


def set_gpu_proc_affinity(
    tp_size: int,
    nnodes: int,
    gpu_id: int,
):
    # current process
    pid = os.getpid()
    p = psutil.Process(pid)

    tp_size_per_node = tp_size // nnodes

    # total physical cores
    total_pcores = psutil.cpu_count(logical=False)
    # physical cores per TP (N.B. more Cores than GPUs on node)
    num_cores_bind = total_pcores // tp_size_per_node

    # able to handle multiple DP per node
    start_cpu_id = (gpu_id * num_cores_bind) % total_pcores
    end_cpu_id = start_cpu_id + num_cores_bind

    if psutil.cpu_count() != psutil.cpu_count(logical=False):
        # HT on
        lower_cpu_ids = list(range(start_cpu_id, end_cpu_id))
        upper_cpu_ids = [i + total_pcores for i in range(start_cpu_id, end_cpu_id)]
        bind_cpu_ids = list(itertools.chain(lower_cpu_ids, upper_cpu_ids))
    else:
        # HT off
        bind_cpu_ids = list(range(start_cpu_id, end_cpu_id))

    # set cpu_affinity to current process
    p.cpu_affinity(bind_cpu_ids)
    logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {p.cpu_affinity()}")


@lru_cache(maxsize=2)
def disable_request_logging() -> bool:
    return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")


def dataclass_to_string_truncated(
    data, max_length=2048, skip_names: Optional[Set[str]] = None
):
    if skip_names is None:
        skip_names = set()
    if isinstance(data, str):
        if len(data) > max_length:
            half_length = max_length // 2
            return f"{repr(data[:half_length])} ... {repr(data[-half_length:])}"
        else:
            return f"{repr(data)}"
    elif isinstance(data, (list, tuple)):
        if len(data) > max_length:
            half_length = max_length // 2
            return str(data[:half_length]) + " ... " + str(data[-half_length:])
        else:
            return str(data)
    elif isinstance(data, dict):
        return (
            "{"
            + ", ".join(
                f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
                for k, v in data.items()
                if k not in skip_names
            )
            + "}"
        )
    elif dataclasses.is_dataclass(data):
        fields = dataclasses.fields(data)
        return (
            f"{data.__class__.__name__}("
            + ", ".join(
                f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
                for f in fields
                if f.name not in skip_names
            )
            + ")"
        )
    else:
        return str(data)
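
# Illustrative usage (hypothetical values, shown only as comments):
#     dataclass_to_string_truncated("x" * 10_000, max_length=8)
#     # -> "'xxxx' ... 'xxxx'"
#     dataclass_to_string_truncated({"prompt": "hi", "key": "secret"}, skip_names={"key"})
#     # -> "{'prompt': 'hi'}"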


def permute_weight(x: torch.Tensor) -> torch.Tensor:
    b_ = x.shape[0]
    n_ = x.shape[1]
    k_ = x.shape[2]

    x_ = x
    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
    else:
        # return x_
        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 8), 2, 4)

    x_ = x_.permute(0, 1, 3, 4, 2, 5)
    x_ = x_.contiguous()
    x_ = x_.view(*x.shape)
    return x_


class MultiprocessingSerializer:
    @staticmethod
    def serialize(obj, output_str: bool = False):
        """
        Serialize a Python object using ForkingPickler.

        Args:
            obj: The object to serialize.
            output_str (bool): If True, return a base64-encoded string instead of raw bytes.

        Returns:
            bytes or str: The serialized object.
        """
        buf = io.BytesIO()
        ForkingPickler(buf).dump(obj)
        buf.seek(0)
        output = buf.read()

        if output_str:
            # Convert bytes to base64-encoded string
            output = base64.b64encode(output).decode("utf-8")

        return output

    @staticmethod
    def deserialize(data):
        """
        Deserialize a previously serialized object.

        Args:
            data (bytes or str): The serialized data, optionally base64-encoded.

        Returns:
            The deserialized Python object.
        """
        if isinstance(data, str):
            # Decode base64 string to bytes
            data = base64.b64decode(data)

        return ForkingPickler.loads(data)
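
# Illustrative round trip (hypothetical payload, shown only as comments):
#     blob = MultiprocessingSerializer.serialize({"step": 1}, output_str=True)
#     assert MultiprocessingSerializer.deserialize(blob) == {"step": 1}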


def debug_timing(func):
    # todo: replace with a more organized instrumentation
    def wrapper(*args, **kwargs):
        if logger.isEnabledFor(logging.DEBUG):
            tic = torch.cuda.Event(enable_timing=True)
            toc = torch.cuda.Event(enable_timing=True)
            tic.record()
            result = func(*args, **kwargs)
            toc.record()
            toc.synchronize()  # Wait for the function to complete without synchronizing all ops on the GPU
            elapsed = tic.elapsed_time(toc)
            indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
            num_tokens = len(indices) if indices is not None else 0
            throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
            logger.debug(
                f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
            )
            return result
        else:
            return func(*args, **kwargs)

    return wrapper


def nullable_str(val: str):
    if not val or val == "None":
        return None
    return val


def pyspy_dump_schedulers():
    """py-spy dump on all scheduler in a local node."""
    try:
        pid = psutil.Process().pid
        # Command to run py-spy with the PID
        cmd = f"py-spy dump --pid {pid}"
        result = subprocess.run(
            cmd, shell=True, capture_output=True, text=True, check=True
        )
        logger.error(f"Pyspy dump for PID {pid}:\n{result.stdout}")
    except subprocess.CalledProcessError as e:
        logger.error(f"Pyspy failed to dump PID {pid}. Error: {e.stderr}")


def kill_itself_when_parent_died():
    if sys.platform == "linux":
        # sigkill this process when parent worker manager dies
        PR_SET_PDEATHSIG = 1
        libc = ctypes.CDLL("libc.so.6")
        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
    else:
        logger.warning("kill_itself_when_parent_died is only supported in linux.")


def set_uvicorn_logging_configs():
    from uvicorn.config import LOGGING_CONFIG

    LOGGING_CONFIG["formatters"]["default"][
        "fmt"
    ] = "[%(asctime)s] %(levelprefix)s %(message)s"
    LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
    LOGGING_CONFIG["formatters"]["access"][
        "fmt"
    ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
    LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"


def get_ip() -> str:
    # The SGLANG_HOST_IP (or HOST_IP) env var takes precedence if it is set.
    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface

    # try ipv4
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
        return s.getsockname()[0]
    except Exception:
        pass

    # try ipv6
    try:
        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        # Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
        return s.getsockname()[0]
    except Exception:
        pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default."
        "The value can be set by the environment variable"
        " SGLANG_HOST_IP or HOST_IP.",
        stacklevel=2,
    )
    return "0.0.0.0"


def get_open_port() -> int:
    port = os.getenv("SGLANG_PORT")
    if port is not None:
        port = int(port)
        while True:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("", port))
                    return port
            except OSError:
                port += 1  # Increment port number if already in use
                logger.info("Port %d is already in use, trying port %d", port - 1, port)
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def is_valid_ipv6_address(address: str) -> bool:
    try:
        ipaddress.IPv6Address(address)
        return True
    except ValueError:
        return False


def configure_ipv6(dist_init_addr):
    addr = dist_init_addr
    end = addr.find("]")
    if end == -1:
        raise ValueError("invalid IPv6 address format: missing ']'")

    host = addr[: end + 1]

    # this only validates the address without brackets: we still need the below checks.
    # if it's invalid, immediately raise an error so we know it's not formatting issues.
    if not is_valid_ipv6_address(host[1:end]):
        raise ValueError(f"invalid IPv6 address: {host}")

    port_str = None
    if len(addr) > end + 1:
        if addr[end + 1] == ":":
            port_str = addr[end + 2 :]
        else:
            raise ValueError("received IPv6 address format: expected ':' after ']'")

    if not port_str:
        raise ValueError(
            "a port must be specified in IPv6 address (format: [ipv6]:port)"
        )

    try:
        port = int(port_str)
    except ValueError:
        raise ValueError(f"invalid port in IPv6 address: '{port_str}'")
    return port, host
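
# Illustrative parse (hypothetical address, shown only as a comment):
#     configure_ipv6("[2001:db8::1]:30000")  # -> (30000, "[2001:db8::1]")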


def rank0_log(msg: str):
    from sglang.srt.distributed import get_tensor_model_parallel_rank

    if get_tensor_model_parallel_rank() == 0:
        logger.info(msg)


def rank0_print(msg: str):
    from sglang.srt.distributed import get_tensor_model_parallel_rank

    if get_tensor_model_parallel_rank() == 0:
        print(msg, flush=True)


def get_cuda_version():
    if torch.version.cuda:
        return tuple(map(int, torch.version.cuda.split(".")))
    return (0, 0)


def launch_dummy_health_check_server(host, port):
    import asyncio

    import uvicorn
    from fastapi import FastAPI, Response

    app = FastAPI()

    @app.get("/health")
    async def health():
        """Check the health of the http server."""
        return Response(status_code=200)

    @app.get("/health_generate")
    async def health_generate():
        """Check the health of the http server."""
        return Response(status_code=200)

    config = uvicorn.Config(
        app,
        host=host,
        port=port,
        timeout_keep_alive=5,
        loop="auto",
        log_config=None,
        log_level="warning",
    )
    server = uvicorn.Server(config=config)

    try:
        loop = asyncio.get_running_loop()
        logger.info(
            f"Dummy health check server scheduled on existing loop at {host}:{port}"
        )
        loop.create_task(server.serve())

    except RuntimeError:
        logger.info(f"Starting dummy health check server at {host}:{port}")
        server.run()


def create_checksum(directory: str):
    raise NotImplementedError()


def set_cuda_arch():
    if is_flashinfer_available():
        capability = torch.cuda.get_device_capability()
        arch = f"{capability[0]}.{capability[1]}"
        os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"


def next_power_of_2(n: int):
    return 1 << (n - 1).bit_length() if n > 0 else 1


setattr(triton, "next_power_of_2", next_power_of_2)
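
# Illustrative values (shown only as comments): next_power_of_2(1) == 1,
# next_power_of_2(5) == 8, next_power_of_2(1024) == 1024, next_power_of_2(0) == 1.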


@contextmanager
def empty_context(*args, **kwargs):
    try:
        # Setup code goes here
        yield
    finally:
        # Cleanup code goes here
        pass


def add_prefix(name: str, prefix: str) -> str:
    """Add a weight path prefix to a module name.

    Args:
        name: base module name.
        prefix: weight prefix str to be added to the front of `name`, concatenated with `.`.

    Returns:
        The string `prefix.name` if prefix is non-empty, otherwise just `name`.
    """
    return name if not prefix else f"{prefix}.{name}"
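
# Illustrative values (hypothetical module names, shown only as comments):
#     add_prefix("qkv_proj", "model.layers.0.self_attn")  # -> "model.layers.0.self_attn.qkv_proj"
#     add_prefix("lm_head", "")  # -> "lm_head"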


def is_remote_url(url: Union[str, Path]) -> bool:
    """
    Check if the URL is a remote URL of the format:
    <connector_type>://<host>:<port>/<model_name>
    """
    if isinstance(url, Path):
        return False

    pattern = r"(.+)://(.*)"
    m = re.match(pattern, url)
    return m is not None


def parse_connector_type(url: str) -> str:
    """
    Parse the connector type from the URL of the format:
    <connector_type>://<path>
    """
    pattern = r"(.+)://(.*)"
    m = re.match(pattern, url)
    if m is None:
        return ""

    return m.group(1)
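
# Illustrative values (hypothetical URLs, shown only as comments):
#     is_remote_url("redis://localhost:6379/my-model")  # -> True
#     is_remote_url(Path("/models/my-model"))  # -> False
#     parse_connector_type("s3://bucket/model")  # -> "s3"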


def retry(
    fn,
    max_retry: int,
    initial_delay: float = 2.0,
    max_delay: float = 60.0,
    should_retry: Callable[[Any], bool] = lambda e: True,
):
    for try_index in itertools.count():
        try:
            return fn()
        except Exception as e:
            if try_index >= max_retry:
                raise Exception(f"retry() exceed maximum number of retries.")

            if not should_retry(e):
                raise Exception(f"retry() observe errors that should not be retried.")

            delay = min(initial_delay * (2**try_index), max_delay) * (
                0.75 + 0.25 * random.random()
            )

            logger.warning(
                f"retry() failed once ({try_index}th try, maximum {max_retry} retries). Will delay {delay:.2f}s and retry. Error: {e}"
            )
            traceback.print_exc()

            time.sleep(delay)
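
# Illustrative usage (hypothetical callable and URL, shown only as comments):
#     resp = retry(
#         lambda: requests.get(url, timeout=5),
#         max_retry=3,
#         should_retry=lambda e: isinstance(e, requests.RequestException),
#     )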


def flatten_nested_list(nested_list):
    if isinstance(nested_list, list):
        return [
            item for sublist in nested_list for item in flatten_nested_list(sublist)
        ]
    else:
        return [nested_list]
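
# Illustrative value (shown only as a comment):
#     flatten_nested_list([1, [2, [3, 4]], 5])  # -> [1, 2, 3, 4, 5]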


class DeepEPMode(Enum):
    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def enable_normal(self):
        return self in [DeepEPMode.normal, DeepEPMode.auto]

    def enable_low_latency(self):
        return self in [DeepEPMode.low_latency, DeepEPMode.auto]

    def resolve(self, forward_mode):
        if self != DeepEPMode.auto:
            return self

        if forward_mode.is_decode():
            return DeepEPMode.low_latency
        else:
            return DeepEPMode.normal
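
# Illustrative resolution (hypothetical forward_mode object, shown only as comments):
#     DeepEPMode.auto.resolve(forward_mode)   # -> low_latency for decode, else normal
#     DeepEPMode.normal.enable_low_latency()  # -> False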


def is_non_idle_and_non_empty(forward_mode, hidden_states):
    return (
        (forward_mode is not None)
        and not forward_mode.is_idle()
        and hidden_states.shape[0] > 0
    )


def fast_topk(values, topk, dim):
    if topk == 1:
        # Use max along the specified dimension to get both value and index
        return torch.max(values, dim=dim, keepdim=True)
    else:
        # Use topk for efficiency with larger k values
        return torch.topk(values, topk, dim=dim)
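
# Illustrative shapes (shown only as comments): for `values` of shape (B, V),
#     fast_topk(values, 1, dim=-1)  # -> (max values, indices), each of shape (B, 1)
#     fast_topk(values, 4, dim=-1)  # -> (topk values, indices), each of shape (B, 4)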


def _check(cc_major):
    if not is_cuda():
        return False
    return torch.cuda.get_device_capability()[0] == cc_major and tuple(
        map(int, torch.version.cuda.split(".")[:2])
    ) >= (12, 3)


is_ampere_with_cuda_12_3 = lambda: _check(8)
is_hopper_with_cuda_12_3 = lambda: _check(9)


def get_free_port():
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def get_local_ip_by_remote() -> str:
    # try ipv4
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
        return s.getsockname()[0]
    except Exception:
        pass

    # try ipv6
    try:
        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        # Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
        return s.getsockname()[0]
    except Exception:
        raise ValueError("Can not get local ip")


def is_page_size_one(server_args):
    return server_args.page_size == 1


# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
def is_no_spec_infer_or_topk_one(server_args):
    return server_args.speculative_eagle_topk is None or (
        server_args.speculative_eagle_topk is not None
        and server_args.speculative_eagle_topk == 1
        and is_page_size_one(server_args)
    )


def is_fa3_default_architecture(hf_config):
    architectures = getattr(hf_config, "architectures", None)
    if not isinstance(architectures, list) or not architectures:
        return False
    default_archs = {
        "Qwen2ForCausalLM",
        "Llama4ForConditionalGeneration",
        "LlamaForCausalLM",
        "Gemma2ForCausalLM",
        "Gemma3ForConditionalGeneration",
        "Qwen3ForCausalLM",
        "Qwen3MoeForCausalLM",
    }
    return architectures[0] in default_archs


# Can be more general if it is used in multiple places (keep it simple and thus not general now)
class BumpAllocator:
    def __init__(self, buffer_size: int, dtype, device):
        self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
        self._pointer = 0

    def allocate(self, size: int):
        assert self._pointer + size <= len(self._buffer)
        output = self._buffer[self._pointer : self._pointer + size]
        self._pointer += size
        return output
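
# Illustrative usage (hypothetical sizes, shown only as comments):
#     alloc = BumpAllocator(buffer_size=1024, dtype=torch.int64, device="cuda")
#     a = alloc.allocate(128)  # view over elements [0, 128)
#     b = alloc.allocate(256)  # view over elements [128, 384)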


def log_info_on_rank0(logger, msg):
    from sglang.srt.distributed import get_tensor_model_parallel_rank

    if get_tensor_model_parallel_rank() == 0:
        logger.info(msg)


def load_json_config(data: str):
    try:
        return json.loads(data)
    except JSONDecodeError:
        return json.loads(Path(data).read_text())


def dispose_tensor(x: torch.Tensor):
    x.set_(torch.empty((0,), device=x.device, dtype=x.dtype))


T = TypeVar("T")


class Withable(Generic[T]):
    def __init__(self):
        self._value: Optional[T] = None

    @property
    def value(self) -> T:
        return self._value

    @contextmanager
    def with_value(self, new_value: T):
        assert self._value is None
        self._value = new_value
        try:
            yield
        finally:
            assert self._value is new_value
            self._value = None
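
# Illustrative usage (hypothetical `batch` value, shown only as comments):
#     current_batch = Withable()
#     with current_batch.with_value(batch):
#         ...  # current_batch.value is `batch` inside the block, None outside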