utils.py 63.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
"""Common utilities."""
15

Lianmin Zheng's avatar
Lianmin Zheng committed
16
import base64
17
import builtins
18
import ctypes
19
import dataclasses
20
import importlib
21
import io
22
import ipaddress
23
import itertools
24
import json
25
import logging
Lianmin Zheng's avatar
Lianmin Zheng committed
26
import os
27
import pickle
Lianmin Zheng's avatar
Lianmin Zheng committed
28
import random
Lianmin Zheng's avatar
Lianmin Zheng committed
29
import re
30
import resource
31
32
import shutil
import signal
Lianmin Zheng's avatar
Lianmin Zheng committed
33
import socket
34
import subprocess
35
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
36
import tempfile
37
import threading
Lianmin Zheng's avatar
Lianmin Zheng committed
38
import time
39
import traceback
40
import warnings
41
from contextlib import contextmanager
42
from enum import Enum
43
from functools import lru_cache
44
from importlib.metadata import PackageNotFoundError, version
45
from importlib.util import find_spec
Lianmin Zheng's avatar
Lianmin Zheng committed
46
from io import BytesIO
47
from multiprocessing.reduction import ForkingPickler
48
from pathlib import Path
49
from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
50
51

import numpy as np
52
import psutil
Lianmin Zheng's avatar
Lianmin Zheng committed
53
54
import requests
import torch
55
import torch.distributed
56
import torch.distributed as dist
57
import triton
58
import zmq
59
from fastapi.responses import ORJSONResponse
60
from packaging import version as pkg_version
Mick's avatar
Mick committed
61
from PIL import Image
Lianmin Zheng's avatar
Lianmin Zheng committed
62
from starlette.routing import Mount
63
from torch import nn
64
from torch.func import functional_call
65
from torch.library import Library
66
from torch.profiler import ProfilerActivity, profile, record_function
67
from torch.utils._contextlib import _DecoratorContextManager
68
69
70
71
72
73
from triton.runtime.cache import (
    FileCacheManager,
    default_cache_dir,
    default_dump_dir,
    default_override_dir,
)
74

75
76
logger = logging.getLogger(__name__)

# Global switch for the mark_start/mark_end timing helpers; they do nothing
# while this is False (see enable_show_time_cost).
show_time_cost = False
# Maps a region name to its TimeInfo accumulator.
time_infos = {}

# Largest finite magnitude of the FP8 E4M3 "fnuz" format used on HIP/ROCm.
HIP_FP8_E4M3_FNUZ_MAX = 224.0

# Keys already reported by get_bool_env_var, so each warning is emitted once.
_warned_bool_env_var_keys = set()

Lianmin Zheng's avatar
Lianmin Zheng committed
84

85
86
def get_bool_env_var(name: str, default: str = "false") -> bool:
    """Interpret environment variable ``name`` as a boolean.

    ``"true"`` / ``"1"`` (case-insensitive) are truthy and ``"false"`` /
    ``"0"`` are falsy. Any other value is treated as false, and a warning is
    logged once per variable name.

    Args:
        name: Environment variable to read.
        default: String used when the variable is unset.

    Returns:
        True if the lower-cased value is one of the truthy spellings.
    """
    value = os.getenv(name, default).lower()

    truthy_values = ("true", "1")
    falsy_values = ("false", "0")

    if value not in truthy_values and value not in falsy_values:
        # De-duplicate by variable *name* so that every misconfigured
        # variable gets reported, each exactly once (keying by value would
        # hide a second variable that happens to share the same bad value).
        if name not in _warned_bool_env_var_keys:
            logger.warning(
                f"get_bool_env_var({name}) see non-understandable value={value} and treat as false"
            )
            _warned_bool_env_var_keys.add(name)

    return value in truthy_values
100
101


102
103
104
105
106
107
108
109
110
111
def get_int_env_var(name: str, default: int = 0) -> int:
    """Return environment variable ``name`` parsed as an int.

    ``default`` is returned when the variable is unset, empty or
    whitespace-only, or not a valid integer literal.
    """
    raw = os.getenv(name)
    if raw is not None and raw.strip():
        try:
            return int(raw)
        except ValueError:
            pass
    return default


112
# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
113
114
115
116
def is_hip() -> bool:
    """Report whether this PyTorch build targets ROCm/HIP."""
    hip_version = torch.version.hip
    return hip_version is not None


117
118
119
120
121
122
123
124
125
126
127
# Largest finite magnitude of FP8 E4M3 for this build: the HIP "fnuz" limit
# on ROCm, otherwise torch's float8_e4m3fn maximum.
FP8_E4M3_MAX = (
    HIP_FP8_E4M3_FNUZ_MAX if is_hip() else torch.finfo(torch.float8_e4m3fn).max
)
FP8_E4M3_MIN = -FP8_E4M3_MAX

# Published on `builtins` so the bounds resolve as bare names in any module.
builtins.FP8_E4M3_MAX, builtins.FP8_E4M3_MIN = FP8_E4M3_MAX, FP8_E4M3_MIN


128
129
130
131
def is_rocm() -> bool:
    """Return True if a GPU is available and PyTorch was built for ROCm/HIP.

    Always returns a real ``bool``, as annotated, rather than leaking the HIP
    version string or ``None``.
    """
    return torch.cuda.is_available() and torch.version.hip is not None


132
def is_cuda() -> bool:
    """Return True if a GPU is available and PyTorch was built for CUDA.

    Always returns a real ``bool`` rather than leaking the CUDA version
    string or ``None``.
    """
    return torch.cuda.is_available() and torch.version.cuda is not None
134
135
136
137
138
139
140
141
142
143
144
145
146
147


def is_cuda_alike():
    """Tell whether the build targets NVIDIA CUDA or AMD ROCm/HIP.

    Returns a truthy value for either backend; HIP is only consulted when
    the CUDA check is falsy.
    """
    cuda_or_hip = is_cuda() or is_hip()
    return cuda_or_hip


def is_hpu() -> bool:
    """Return True when torch exposes an ``hpu`` backend with an available device."""
    hpu_backend = getattr(torch, "hpu", None)
    return hpu_backend is not None and hpu_backend.is_available()


def is_xpu() -> bool:
    """Return True when torch exposes an ``xpu`` backend with an available device."""
    xpu_backend = getattr(torch, "xpu", None)
    return xpu_backend is not None and xpu_backend.is_available()


148
149
150
151
152
def is_flashinfer_available():
    """
    Check whether flashinfer can be used.

    Returns False when SGLANG_IS_FLASHINFER_AVAILABLE is set to anything other
    than "true"/"1" (it defaults to "true"). Otherwise requires the
    ``flashinfer`` package to be importable and a CUDA backend; as of
    Oct. 6, 2024, flashinfer is only available on NVIDIA GPUs.
    """
    disabled = not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true")
    if disabled:
        return False
    has_package = importlib.util.find_spec("flashinfer") is not None
    return has_package and is_cuda()
156
157


158
# Selects DynamicGradMode's behavior: True -> torch inference mode,
# False -> no_grad-style. Read once at import; DynamicGradMode.set_inference_mode
# can change it later.
_ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
    name="SGLANG_ENABLE_TORCH_INFERENCE_MODE", default="false"
)
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214


class DynamicGradMode(_DecoratorContextManager):
    """
    A combination of torch.no_grad and torch.inference_mode,
    with their behavior controlled by an environment variable. Just refer to them.

    The module-level flag ``_ENABLE_TORCH_INFERENCE_MODE`` (initialised from
    SGLANG_ENABLE_TORCH_INFERENCE_MODE at import time, changeable through
    ``set_inference_mode``) selects the behavior:

    * True: ``__enter__`` enters ``torch._C._InferenceMode(self.mode)``.
    * False: ``__enter__`` saves ``torch.is_grad_enabled()`` and disables
      grad; ``__exit__`` restores the saved value (no_grad-like).

    Usable as a context manager, as ``@DynamicGradMode()`` or as a bare
    ``@DynamicGradMode`` decorator (see ``__new__``).

    NOTE(review): the flag is re-read in ``__init__``, ``__enter__``,
    ``__exit__`` and ``clone``. Toggling it while an instance exists or is
    entered can leave that instance inconsistent (e.g. ``self.mode`` or
    ``self._inference_mode_context`` never set) - confirm callers only
    toggle it at startup.
    """

    @staticmethod
    def set_inference_mode(mode: bool):
        """Set the module-level flag choosing inference_mode (True) or
        no_grad-style (False) behavior; non-bool values are ignored with a
        warning."""
        if isinstance(mode, bool):
            global _ENABLE_TORCH_INFERENCE_MODE

            _ENABLE_TORCH_INFERENCE_MODE = mode
        else:
            logger.warning("mode is not a boolean object")

    def __init__(self, mode=True):
        """Store ``mode`` for the inference-mode path, or initialise the
        saved grad state for the no_grad path."""
        # The base-class initialiser is skipped while TorchScript is scripting.
        if not torch._jit_internal.is_scripting():
            super().__init__()
        if _ENABLE_TORCH_INFERENCE_MODE:
            self.mode = mode
        else:
            self.prev = False

    def __new__(cls, mode_or_orig_func=True if _ENABLE_TORCH_INFERENCE_MODE else None):
        """Support both ``DynamicGradMode(...)`` and bare ``@DynamicGradMode``.

        A bool or None argument creates an instance normally. Any other
        argument is taken to be the function being decorated: a default
        instance is created and applied to it, and the decorated function is
        returned instead of an instance.

        NOTE(review): the default argument value is computed once, when the
        class body executes, from the flag's value at import time.
        """
        if mode_or_orig_func is None or isinstance(mode_or_orig_func, bool):
            return super().__new__(cls)
        return cls()(mode_or_orig_func)

    def __enter__(self) -> None:
        """Enter torch inference mode, or save the grad mode and disable grad."""
        if _ENABLE_TORCH_INFERENCE_MODE:
            self._inference_mode_context = torch._C._InferenceMode(self.mode)
            self._inference_mode_context.__enter__()
        else:
            self.prev = torch.is_grad_enabled()
            torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Leave inference mode, or restore the grad mode saved by ``__enter__``."""
        if _ENABLE_TORCH_INFERENCE_MODE:
            self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
        else:
            torch.set_grad_enabled(self.prev)

    def clone(self) -> "DynamicGradMode":
        r"""
        Create a fresh instance with the same settings.

        Presumably called by torch's ``_DecoratorContextManager`` so that each
        decorated call gets its own context object - confirm against the
        installed torch version.
        """
        if _ENABLE_TORCH_INFERENCE_MODE:
            return self.__class__(self.mode)
        else:
            return self.__class__()


Liangsheng Yin's avatar
Liangsheng Yin committed
215
216
217
218
def enable_show_time_cost():
    """Switch on the timing instrumentation used by mark_start / mark_end."""
    global show_time_cost
    show_time_cost = True

Lianmin Zheng's avatar
Lianmin Zheng committed
219

Liangsheng Yin's avatar
Liangsheng Yin committed
220
221
222
223
224
225
class TimeInfo:
    """Accumulated wall-clock time for one named code region.

    Args:
        name: Label printed by ``pretty_print``.
        interval: Seconds of additional accumulated time required before
            ``check`` reports True again.
        color: ANSI SGR color code used by ``pretty_print``.
        indent: Nesting depth, rendered as ``2 * indent`` dashes.
    """

    def __init__(self, name, interval=0.1, color=0, indent=0):
        self.name, self.interval = name, interval
        self.color, self.indent = color, indent
        # acc_time is the running total; last_acc_time is its value at the
        # most recent successful check().
        self.acc_time = 0
        self.last_acc_time = 0

    def check(self):
        """Return True (and checkpoint) once acc_time grew by more than interval."""
        if self.acc_time - self.last_acc_time <= self.interval:
            return False
        self.last_acc_time = self.acc_time
        return True

    def pretty_print(self):
        """Print ``name: seconds`` in color, prefixed with indentation dashes."""
        dashes = "-" * (self.indent * 2)
        print(f"\x1b[{self.color}m{dashes}{self.name}: {self.acc_time:.3f}s\x1b[0m")
Lianmin Zheng's avatar
Lianmin Zheng committed
240
241


Liangsheng Yin's avatar
Liangsheng Yin committed
242
243
244
245
def mark_start(name, interval=0.1, color=0, indent=0):
    """Begin (or resume) timing the region ``name``.

    Does nothing unless enable_show_time_cost() was called. Waits for pending
    CUDA work, creates the region's TimeInfo on first use, then subtracts
    the current time from its accumulator (mark_end adds it back).
    """
    global time_infos, show_time_cost
    if show_time_cost:
        torch.cuda.synchronize()
        info = time_infos.get(name, None)
        if info is None:
            info = time_infos[name] = TimeInfo(name, interval, color, indent)
        info.acc_time -= time.time()
Lianmin Zheng's avatar
Lianmin Zheng committed
250
251


Liangsheng Yin's avatar
Liangsheng Yin committed
252
253
254
255
def mark_end(name):
    """Stop timing region ``name`` and print its total when its check allows.

    Does nothing unless enable_show_time_cost() was called; expects a prior
    mark_start(name) to have created the region.
    """
    global time_infos, show_time_cost
    if not show_time_cost:
        return
    torch.cuda.synchronize()
    info = time_infos[name]
    info.acc_time += time.time()
    if info.check():
        info.pretty_print()
Lianmin Zheng's avatar
Lianmin Zheng committed
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280


def calculate_time(show=False, min_cost_ms=0.0):
    """Decorator factory that times calls to the wrapped function.

    When CUDA is available, the device is synchronized before and after the
    call so asynchronous kernels are included in the measurement. On hosts
    without CUDA the synchronization is skipped instead of raising.

    Args:
        show: If True, print the elapsed time of each call.
        min_cost_ms: Only print when the elapsed time exceeds this many ms.

    Returns:
        A decorator. The wrapped function keeps the original's name/docstring
        and returns the original's result.
    """
    import functools

    def _sync():
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def wrapper(func):
        @functools.wraps(func)
        def inner_func(*args, **kwargs):
            _sync()
            # perf_counter is monotonic and high resolution, unlike time.time.
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            _sync()
            if show:
                cost_time = (time.perf_counter() - start_time) * 1000
                if cost_time > min_cost_ms:
                    print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result

        return inner_func

    return wrapper


281
def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
    """
    Get available memory for device ``gpu_id`` of type ``device``.
    When distributed is True, the available memory is the minimum available memory of all GPUs.

    Args:
        device: One of "cuda", "xpu", "hpu" or "cpu".
        gpu_id: Device index to query.
        distributed: If True, all-reduce the value with ReduceOp.MIN over the
            default torch.distributed process group.
        empty_cache: If True, release cached allocator memory first
            (cuda and xpu only).

    Returns:
        Free memory in GiB.

    Raises:
        ValueError: If ``device`` is not one of the supported values.
    """

    def _warn_if_not_current(backend_name, current_device):
        # Querying a device other than the current one may allocate an extra
        # runtime context on it.
        if current_device != gpu_id:
            logger.warning(
                f"current device is not {gpu_id}, but {current_device}, "
                f"which may cause useless memory allocation for torch {backend_name} context."
            )

    if device == "cuda":
        num_gpus = torch.cuda.device_count()
        assert gpu_id < num_gpus
        _warn_if_not_current("CUDA", torch.cuda.current_device())

        if empty_cache:
            torch.cuda.empty_cache()
        free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

    elif device == "xpu":
        num_gpus = torch.xpu.device_count()
        assert gpu_id < num_gpus
        _warn_if_not_current("XPU", torch.xpu.current_device())

        if empty_cache:
            torch.xpu.empty_cache()
        # Query the requested device explicitly, not whichever is current.
        used_memory = torch.xpu.memory_allocated(gpu_id)
        total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

    elif device == "hpu":
        num_gpus = torch.hpu.device_count()
        assert gpu_id < num_gpus
        _warn_if_not_current("HPU", torch.hpu.current_device())

        free_gpu_memory, _ = torch.hpu.mem_get_info()

    elif device == "cpu":
        # TODO: rename the variables in the current function to be not GPU specific
        free_gpu_memory = psutil.virtual_memory().available

    else:
        raise ValueError(f"Unsupported device for memory query: {device!r}")

    if distributed:
        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
            torch.device(device, gpu_id)
        )
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
        free_gpu_memory = tensor.item()

    return free_gpu_memory / (1 << 30)


342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
def is_pin_memory_available() -> bool:
    """Pinned (page-locked) host memory is used only when CUDA is available."""
    cuda_ok = torch.cuda.is_available()
    return cuda_ok


# Bytes of parameters moved to CPU so far by maybe_offload_to_cpu, and the
# budget it may not exceed (set via set_cpu_offload_max_bytes; 0 disables it).
_CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES = 0, 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Set the CPU-offload budget in bytes and reset the running total."""
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES = max_bytes, 0


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """Move some of ``module``'s parameters to CPU memory, within the global budget.

    Parameters are offloaded one at a time until the budget set by
    set_cpu_offload_max_bytes is used up, so a module may end up partially
    offloaded. If anything was offloaded, ``module.forward`` is wrapped: each
    call moves every ``state_dict`` entry to the original device and runs the
    original forward through ``functional_call``.

    Modules without parameters, modules already on CPU, and calls made after
    the budget is exhausted return the module unchanged.
    """
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES

    first_param = next(module.parameters(), None)
    if first_param is None:
        # Parameter-free modules (activations, pooling, ...) have nothing to offload.
        return module
    device = first_param.device

    if device == torch.device("cpu"):
        return module

    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    # Offload parameters to CPU; pinned memory (when available) helps
    # cudagraph capture speed.
    pin_memory = is_pin_memory_available()
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # Per-parameter budget: one module may have some parameters
            # offloaded and some not.
            break

        # `torch.empty_like` does not support the `pin_memory` argument.
        cpu_data = torch.empty_strided(
            size=p.data.size(),
            stride=p.data.stride(),
            dtype=p.data.dtype,
            layout=p.data.layout,
            device="cpu",
            pin_memory=pin_memory,
        )
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            # Swap the original forward back in so the functional call below
            # runs it instead of re-entering this wrapper.
            module.forward = original_forward
            try:
                # Blindly move every entry to `device`; entries already there
                # make `.to` a no-op.
                device_state = {
                    k: v.to(device, non_blocking=True)
                    for k, v in module.state_dict().items()
                }
                return functional_call(
                    module, device_state, args=args, kwargs=kwargs
                )
            finally:
                # Re-install the wrapper even if the forward pass raised, so
                # later calls still move the offloaded weights to `device`.
                module.forward = forward

        module.forward = forward

    return module


class LayerFn(Protocol):
    """Callable that builds the layer with global index ``idx``.

    make_layers invokes it with keyword arguments,
    ``layer_fn(idx=..., prefix=...)``, so the index parameter is named
    ``idx`` to match that call.
    """

    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...


def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    pp_rank: Optional[int] = None,
    pp_size: Optional[int] = None,
    prefix: str = "",
    return_tuple: bool = False,
) -> Union[torch.nn.ModuleList, Tuple[torch.nn.ModuleList, int, int]]:
    """Make a list of layers with the given layer function.

    When both ``pp_rank`` and ``pp_size`` are given, only layers
    ``[start_layer, end_layer)`` from ``get_pp_indices`` are built. The other
    slots hold ``PPMissingLayer`` placeholders, so list positions still equal
    global layer indices. Every built layer is passed through
    ``maybe_offload_to_cpu``.

    Returns:
        The ``ModuleList`` alone when pipeline parallelism is not configured,
        otherwise ``(modules, start_layer, end_layer)``.
    """
    # Local imports to avoid circular imports.
    from sglang.srt.distributed import get_pp_indices
    from sglang.srt.layers.utils import PPMissingLayer

    assert not pp_size or num_hidden_layers >= pp_size
    if pp_rank is not None and pp_size is not None:
        start_layer, end_layer = get_pp_indices(num_hidden_layers, pp_rank, pp_size)
    else:
        start_layer, end_layer = 0, num_hidden_layers

    modules = torch.nn.ModuleList(
        [PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
        + [
            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
            for idx in range(start_layer, end_layer)
        ]
        + [
            PPMissingLayer(return_tuple=return_tuple)
            for _ in range(end_layer, num_hidden_layers)
        ]
    )
    if pp_rank is None or pp_size is None:
        return modules
    return modules, start_layer, end_layer
452
453


Lianmin Zheng's avatar
Lianmin Zheng committed
454
def set_random_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (plus every CUDA device, if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


463
def is_port_available(port):
    """Return whether a TCP socket can bind and listen on ``port`` on all interfaces.

    Socket errors and out-of-range port numbers both yield False.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(("", port))
            sock.listen(1)
        except (socket.error, OverflowError):
            return False
        return True
475
476


Yuanhan Zhang's avatar
Yuanhan Zhang committed
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
def decode_video_base64(video_base64):
    """Split a base64 blob of back-to-back still images into stacked frames.

    Each frame is assumed to start at a format signature (PNG is the format
    currently selected) and to run until the next signature or the end of
    the data.

    Returns:
        ``(frames, size)``. ``frames`` is ``np.stack`` of the decoded images
        along a new leading axis, and ``size`` is the PIL ``(width, height)``
        of the last frame. Returns ``(np.array([]), (0, 0))`` when no frame
        start is found.
    """
    # Decode the base64 string
    video_bytes = base64.b64decode(video_base64)

    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))

    assert frame_format in [
        "PNG",
        "JPEG",
    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"

    if frame_format == "PNG":
        # 8-byte PNG file signature: 89 50 4E 47 0D 0A 1A 0A.
        signature = b"\x89PNG\r\n\x1a\n"
    else:
        # JPEG start-of-image marker (0xFFD8).
        signature = b"\xff\xd8"

    # Find every non-overlapping signature with bytes.find, which scans in C
    # rather than byte-by-byte in Python.
    img_starts = []
    pos = video_bytes.find(signature)
    while pos != -1:
        img_starts.append(pos)
        pos = video_bytes.find(signature, pos + len(signature))

    # Each image ends where the next one starts; the last runs to the end.
    img_ends = img_starts[1:] + [len(video_bytes)]

    frames = []
    img = None
    for start_idx, end_idx in zip(img_starts, img_ends):
        img = Image.open(BytesIO(video_bytes[start_idx:end_idx]))
        frames.append(np.array(img))

    # Ensure there's at least one frame to avoid errors with np.stack
    if frames:
        return np.stack(frames, axis=0), img.size
    # Return an empty array and size tuple if no frames were found
    return np.array([]), (0, 0)
Lianmin Zheng's avatar
Lianmin Zheng committed
554
555


Mick's avatar
Mick committed
556
557
558
559
560
561
562
563
564
565
566
567
def load_audio(
    audio_file: Union[str, bytes, os.PathLike], sr: int = 16000, mono: bool = True
) -> np.ndarray:
    """Load audio into a NumPy array at sample rate ``sr``.

    ``audio_file`` may be raw encoded bytes, a ``data:`` URI with base64
    content, an ``http://``/``https://`` URL (timeout from REQUEST_TIMEOUT,
    default 5 s) or a filesystem path.

    Args:
        audio_file: Audio source, see above.
        sr: Target sample rate; the signal is resampled with
            ``scipy.signal.resample`` when the source rate differs.
        mono: If True, multi-channel audio is averaged over axis 1.

    Returns:
        The decoded samples.

    Raises:
        ValueError: If ``audio_file`` is not bytes, str or path-like.
        requests.HTTPError: If a URL fetch returns an error status.
    """
    # Reject unsupported types up front, so they raise ValueError instead of
    # an AttributeError from `.startswith`.
    if not isinstance(audio_file, (bytes, bytearray, str, os.PathLike)):
        raise ValueError(f"Invalid audio format: {audio_file}")

    # Use soundfile here, since librosa uses it under the hood,
    # and librosa will not support audio loading in the future
    import soundfile as sf
    from scipy.signal import resample

    # Load audio data
    if isinstance(audio_file, (bytes, bytearray)):
        audio, original_sr = sf.read(BytesIO(audio_file))
    elif isinstance(audio_file, os.PathLike):
        audio, original_sr = sf.read(os.fspath(audio_file))
    elif audio_file.startswith("data:"):
        payload = audio_file.split(",")[1]
        audio, original_sr = sf.read(BytesIO(base64.b64decode(payload)))
    elif audio_file.startswith(("http://", "https://")):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
        # The context manager closes the connection even if reading fails.
        with requests.get(audio_file, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            content = response.content
        audio, original_sr = sf.read(BytesIO(content))
    else:
        audio, original_sr = sf.read(audio_file)

    # Resample audio if the original sample rate is different from the desired sample rate
    if original_sr != sr:
        num_samples = int(len(audio) * float(sr) / original_sr)
        audio = resample(audio, num_samples)

    # Convert to mono if requested and audio is stereo
    if mono and len(audio.shape) > 1:
        audio = np.mean(audio, axis=1)

    return audio

Lianmin Zheng's avatar
Lianmin Zheng committed
590

Mick's avatar
Mick committed
591
def encode_video(video_path, frame_count_limit=None):
    """Decode a video file into PIL frames sampled at roughly one per second.

    Args:
        video_path: Path of the video file.
        frame_count_limit: Optional cap on the number of frames. When the
            ~1 fps sample exceeds it, frames are picked uniformly from that
            sample; ``0`` returns ``[]``.

    Returns:
        List of ``PIL.Image.Image`` frames; ``[]`` if the file does not exist
        or the video yields no frames.
    """
    # Lazy import because decord is not available on some arm platforms.
    from decord import VideoReader, cpu

    if not os.path.exists(video_path):
        logger.error(f"Video {video_path} does not exist")
        return []

    if frame_count_limit == 0:
        return []

    def uniform_sample(items, n):
        # Take the middle element of each of n equal-width buckets.
        gap = len(items) / n
        return [items[int(i * gap + gap / 2)] for i in range(n)]

    vr = VideoReader(video_path, ctx=cpu(0))
    # Step of ~1 second in frames. Clamp to 1 so very low frame rates
    # (average fps below 0.5 rounds to 0) do not give `range` a zero step.
    sample_step = max(1, round(vr.get_avg_fps()))
    frame_indices = list(range(0, len(vr), sample_step))
    if not frame_indices:
        return []
    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
        frame_indices = uniform_sample(frame_indices, frame_count_limit)

    frames = vr.get_batch(frame_indices).asnumpy()
    return [Image.fromarray(v.astype("uint8")) for v in frames]


618
def load_image(
    image_file: Union[Image.Image, str, bytes],
) -> tuple[Image.Image, tuple[int, int]]:
    """Load an image from one of several input encodings.

    Accepted forms, checked in this order:
      * a ``PIL.Image.Image``, returned as-is together with its size;
      * raw ``bytes`` of an encoded image;
      * an ``http://``/``https://`` URL (timeout from REQUEST_TIMEOUT, default 3 s);
      * a ``data:`` URI carrying base64 data;
      * ``video:<base64>``, decoded by decode_video_base64 (the first return
        value is then a NumPy array of stacked frames);
      * a filesystem path ending in png/jpg/jpeg/webp/gif;
      * any other string, treated as bare base64.

    Returns:
        ``(image, image_size)``. ``image_size`` is only set for PIL and
        ``video:`` input; otherwise it is ``None``.

    Raises:
        ValueError: If ``image_file`` is not a PIL image, bytes or str.
        requests.HTTPError: If a URL fetch returns an error status.
    """
    image = image_size = None

    if isinstance(image_file, Image.Image):
        image = image_file
        image_size = (image.width, image.height)
    elif isinstance(image_file, bytes):
        image = Image.open(BytesIO(image_file))
    elif not isinstance(image_file, str):
        raise ValueError(f"Invalid image: {image_file}")
    elif image_file.startswith(("http://", "https://")):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        # `.content` applies Content-Encoding decoding (unlike `.raw`), and the
        # context manager closes the connection even on error.
        with requests.get(image_file, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
    # The prefix checks run before the extension check because base64
    # payloads can happen to end in letters such as "png" or "gif".
    elif image_file.startswith("data:"):
        image_file = image_file.split(",")[1]
        image = Image.open(BytesIO(base64.b64decode(image_file)))
    elif image_file.startswith("video:"):
        image_file = image_file.replace("video:", "")
        image, image_size = decode_video_base64(image_file)
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    else:
        image = Image.open(BytesIO(base64.b64decode(image_file)))

    return image, image_size
646
647


648
def suppress_other_loggers():
    """Silence noisy warnings and logs from third-party libraries.

    Always ignores the "The given NumPy array is not writable" UserWarning.
    If vLLM is importable, its default logger and the pynccl/shm_broadcast
    loggers are raised to WARN and ``vllm.config`` to ERROR.
    """
    warnings.filterwarnings(
        "ignore", category=UserWarning, message="The given NumPy array is not writable"
    )

    try:
        from vllm.logger import logger as vllm_default_logger
    except ImportError:
        return

    vllm_default_logger.setLevel(logging.WARN)
    quieter_levels = {
        "vllm.distributed.device_communicators.pynccl": logging.WARN,
        "vllm.distributed.device_communicators.shm_broadcast": logging.WARN,
        "vllm.config": logging.ERROR,
    }
    for logger_name, level in quieter_levels.items():
        logging.getLogger(logger_name).setLevel(level)
666
667


668
def assert_pkg_version(pkg: str, min_version: str, message: str):
    """Raise unless ``pkg`` is installed at version ``min_version`` or newer.

    ``message`` is appended to the error text (e.g. an install hint).

    Raises:
        Exception: If the package is missing or older than ``min_version``.
    """
    try:
        installed_version = version(pkg)
    except PackageNotFoundError:
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed. "
            + message
        )

    if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
        raise Exception(
            f"{pkg} is installed with version {installed_version}, which "
            f"is less than the minimum required version {min_version}. " + message
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
681
682


683
684
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
    """Kill a process and all of its descendants.

    Args:
        parent_pid: Root PID. ``None`` means the current process; in that
            case ``include_parent`` is forced to False, so only descendants die.
        include_parent: Also kill the root. If the root is the current
            process, it SIGKILLs itself, with ``sys.exit(0)`` as a fallback.
        skip_pid: PID of one descendant to leave running.
    """
    # Restore the default SIGCHLD handler to avoid spammy logs; signal
    # handlers can only be changed from the main thread.
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGCHLD, signal.SIG_DFL)

    if parent_pid is None:
        parent_pid, include_parent = os.getpid(), False

    try:
        root = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        return

    for proc in root.children(recursive=True):
        if proc.pid == skip_pid:
            continue
        try:
            proc.kill()
        except psutil.NoSuchProcess:
            pass

    if not include_parent:
        return

    try:
        if parent_pid == os.getpid():
            root.kill()
            sys.exit(0)

        root.kill()
        # Some processes survive SIGKILL (e.g. PID 1 launched by kubernetes),
        # so an extra SIGQUIT is sent as well.
        root.send_signal(signal.SIGQUIT)
    except psutil.NoSuchProcess:
        pass
720
721


722
def monkey_patch_p2p_access_check():
    """
    Monkey patch the slow p2p access check so it always reports success.

    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
    Also replaces ``CustomAllreduce.__del__`` with a no-op to silence the
    warnings it emits (e.g. when using sglang.bench_one_batch).
    """
    import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt

    def _always_allowed(*args, **kwargs):
        return True

    tgt.gpu_p2p_access_check = _always_allowed

    # Suppress the warnings from this delete function when using sglang.bench_one_batch
    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
        CustomAllreduce,
    )

    def _silent_del(*args, **kwargs):
        return None

    CustomAllreduce.__del__ = _silent_del

739

740
def monkey_patch_vllm_gguf_config():
    """Make vLLM's GGUFConfig return quant methods for sglang's own layer types.

    Does nothing if vLLM's GGUF module cannot be imported. Afterwards,
    ``GGUFConfig.get_quant_method`` returns GGUFLinearMethod for sglang
    ``LinearBase`` layers, GGUFEmbeddingMethod for sglang
    ``VocabParallelEmbedding`` layers, and None for anything else.
    """
    try:
        from vllm.model_executor.layers.quantization.gguf import (
            GGUFConfig,
            GGUFEmbeddingMethod,
            GGUFLinearMethod,
        )
    except ImportError:
        return

    from sglang.srt.layers.linear import LinearBase
    from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding

    def get_quant_method_with_embedding_replaced(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        # First matching sglang layer type wins; order mirrors the original checks.
        for layer_type, method_cls in (
            (LinearBase, GGUFLinearMethod),
            (VocabParallelEmbedding, GGUFEmbeddingMethod),
        ):
            if isinstance(layer, layer_type):
                return method_cls(self)
        return None

    GGUFConfig.get_quant_method = get_quant_method_with_embedding_replaced


766
767
768
769
770
771
def maybe_set_triton_cache_manager() -> None:
    """Point Triton at sglang's CustomCacheManager unless one is already configured.

    Sets TRITON_CACHE_MANAGER only when it is absent from the environment.
    """
    if "TRITON_CACHE_MANAGER" in os.environ:
        return
    manager = "sglang.srt.utils:CustomCacheManager"
    logger.debug("Setting Triton cache manager to: %s", manager)
    os.environ["TRITON_CACHE_MANAGER"] = manager


class CustomCacheManager(FileCacheManager):
    """Triton ``FileCacheManager`` whose default cache directory is per-process.

    Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py

    Directory chosen per mode:
      * ``dump=True``: ``<default_dump_dir()>/<key>``. The directory is created
        and ``lock_path`` is set.
      * ``override=True``: ``<default_override_dir()>/<key>``. It is not
        created and ``lock_path`` stays None.
      * otherwise: ``<TRITON_CACHE_DIR or default_cache_dir()>_<pid>/<key>``.
        The directory is created and ``lock_path`` is set. The PID suffix
        presumably keeps concurrent processes from sharing one cache dir.
    """

    def __init__(self, key, override=False, dump=False):
        # NOTE: FileCacheManager.__init__ is not called; key, lock_path and
        # cache_dir are all assigned here instead.
        self.key = key
        self.lock_path = None

        if dump:
            self.cache_dir = os.path.join(default_dump_dir(), self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            self.cache_dir = os.path.join(default_override_dir(), self.key)
        else:
            base_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
            self.cache_dir = base_dir
            if not base_dir:
                raise RuntimeError("Could not create or locate cache dir")
            self.cache_dir = os.path.join(f"{base_dir}_{os.getpid()}", self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)


804
805
806
807
808
809
810
811
def set_ulimit(target_soft_limit=65535):
    """Raise the soft RLIMIT_NOFILE (max open file descriptors) if it is low.

    The soft limit is raised towards ``target_soft_limit`` but clamped to the
    current hard limit, because ``setrlimit`` rejects a soft limit above the
    hard limit for unprivileged processes. The limit is never lowered, and an
    unlimited soft limit is left alone. Failures are logged, not raised.

    Args:
        target_soft_limit: Desired minimum soft limit on open file descriptors.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    # An unlimited soft limit cannot be raised further.
    if current_soft == resource.RLIM_INFINITY:
        return

    new_soft = target_soft_limit
    if current_hard != resource.RLIM_INFINITY:
        # Go as high as the hard limit allows instead of failing outright.
        new_soft = min(new_soft, current_hard)

    if current_soft < new_soft:
        try:
            resource.setrlimit(resource_type, (new_soft, current_hard))
        except (ValueError, OSError) as e:
            logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")
813
814


815
def add_api_key_middleware(app, api_key: str):
    """Require ``Authorization: Bearer <api_key>`` on HTTP requests to ``app``.

    OPTIONS requests and paths starting with ``/health`` or ``/metrics`` are
    passed through unauthenticated. Any other request whose Authorization
    header does not exactly equal ``"Bearer " + api_key`` gets a 401 JSON
    response ``{"error": "Unauthorized"}``.
    """
    import hmac

    expected = ("Bearer " + api_key).encode("utf-8")

    @app.middleware("http")
    async def authentication(request, call_next):
        if request.method == "OPTIONS":
            return await call_next(request)
        if request.url.path.startswith(("/health", "/metrics")):
            return await call_next(request)
        provided = request.headers.get("Authorization")
        # Constant-time comparison: response timing must not reveal how much
        # of the key a client guessed correctly.
        if provided is None or not hmac.compare_digest(
            provided.encode("utf-8"), expected
        ):
            return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
        return await call_next(request)
827
828


829
def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
    """Resolve model/tokenizer paths, downloading from ModelScope if enabled.

    When ``SGLANG_USE_MODELSCOPE`` is truthy and ``model_path`` is not an
    existing local path, both ids are fetched with ModelScope's
    ``snapshot_download``. The tokenizer download skips ``*.bin`` and
    ``*.safetensors`` files, and the local snapshot paths are returned.
    Otherwise both inputs are returned unchanged.
    """
    use_modelscope = get_bool_env_var("SGLANG_USE_MODELSCOPE")
    if not use_modelscope or os.path.exists(model_path):
        return model_path, tokenizer_path

    from modelscope import snapshot_download

    local_model_path = snapshot_download(model_path)
    local_tokenizer_path = snapshot_download(
        tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
    )
    return local_model_path, local_tokenizer_path
839
840
841


def configure_logger(server_args, prefix: str = ""):
    """Configure the root logger for an SGLang process.

    If ``SGLANG_LOGGING_CONFIG_PATH`` is set, it must name a JSON file holding
    a ``logging.config.dictConfig`` dictionary, which is applied as-is.
    Otherwise ``logging.basicConfig`` is forced (replacing existing root
    handlers). It uses the level from ``server_args.log_level`` and the format
    ``[<time><prefix>] <message>``.

    Args:
        server_args: Object with a ``log_level`` string attribute (e.g. "info").
        prefix: Text inserted right after the timestamp in every log line.

    Raises:
        FileNotFoundError: If ``SGLANG_LOGGING_CONFIG_PATH`` names a missing file.
    """
    # `import logging` alone does not load the `logging.config` submodule.
    import logging.config

    config_path = os.getenv("SGLANG_LOGGING_CONFIG_PATH")
    if config_path:
        if not os.path.exists(config_path):
            raise FileNotFoundError(
                "Setting SGLANG_LOGGING_CONFIG_PATH from env with "
                f"{config_path} but it does not exist!"
            )
        with open(config_path, encoding="utf-8") as file:
            custom_config = json.load(file)
        logging.config.dictConfig(custom_config)
        return

    log_format = f"[%(asctime)s{prefix}] %(message)s"
    logging.basicConfig(
        level=getattr(logging, server_args.log_level.upper()),
        format=log_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
860
861
862
863
864
865
866
867
868
869
870


# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Swap the submodule at dotted path ``module_name`` for ``new_module``.

    The parent is resolved with ``model.get_submodule`` (an empty parent path
    resolves to ``model`` itself), and the last path component is rebound on
    it. Returns ``new_module``.
    """
    parent_path, _, attr_name = module_name.rpartition(".")
    owner = model.get_submodule(parent_path)
    setattr(owner, attr_name, new_module)
    return new_module
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890


def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: Optional[Dict[str, Any]],
):
    """Attach every ``weight_attrs`` entry to ``weight`` as an attribute.

    Existing attributes are never overwritten. An ``AssertionError`` is raised
    as soon as ``weight`` already has an attribute with one of the given
    names. Passing ``None`` does nothing.

    Args:
        weight: The weight tensor to annotate.
        weight_attrs: Mapping of attribute name to value, or None.
    """
    if weight_attrs is not None:
        for attr_name, attr_value in weight_attrs.items():
            assert not hasattr(
                weight, attr_name
            ), f"Overwriting existing tensor attribute: {attr_name}"
            setattr(weight, attr_name, attr_value)
891
892
893


def broadcast_pyobj(
    data: List[Any],
    rank: int,
    dist_group: Optional[torch.distributed.ProcessGroup] = None,
    src: int = 0,
    force_cpu_device: bool = True,
):
    """Broadcast a picklable list from rank ``src`` to all ranks of the group.

    The transfer uses two broadcasts. The first is a one-element int64 tensor
    holding the pickled byte count. Unless that count is 0 (``data`` is
    empty), a uint8 tensor with the pickled bytes follows. ``rank`` is the
    caller's rank on the global process group, regardless of ``dist_group``.
    Tensors are placed on CUDA only when CUDA is available and
    ``force_cpu_device`` is False; otherwise on CPU.

    Returns:
        On ``src``: ``data`` itself. Elsewhere: the unpickled list, or ``[]``
        when the source sent an empty list.
    """
    use_cuda = torch.cuda.is_available() and not force_cpu_device
    device = torch.device("cuda" if use_cuda else "cpu")

    if rank != src:
        # Receiver: learn the payload size first, then fetch the bytes.
        size_buf = torch.tensor([0], dtype=torch.long, device=device)
        dist.broadcast(size_buf, src=src, group=dist_group)
        num_bytes = size_buf.item()
        if num_bytes == 0:
            return []
        payload = torch.empty(num_bytes, dtype=torch.uint8, device=device)
        dist.broadcast(payload, src=src, group=dist_group)
        # NOTE(review): unpickling received bytes; only safe when every
        # process in the group is trusted.
        return pickle.loads(bytes(payload.cpu().numpy()))

    if len(data) == 0:
        # Sender with nothing to send: a size of 0 tells receivers to stop.
        size_buf = torch.tensor([0], dtype=torch.long, device=device)
        dist.broadcast(size_buf, src=src, group=dist_group)
        return data

    blob = pickle.dumps(data)
    payload = torch.ByteTensor(np.frombuffer(blob, dtype=np.uint8)).to(device)
    size_buf = torch.tensor([len(blob)], dtype=torch.long, device=device)
    dist.broadcast(size_buf, src=src, group=dist_group)
    dist.broadcast(payload, src=src, group=dist_group)
    return data
938
939


940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
def point_to_point_pyobj(
    data: List[Any],
    rank: int,
    group: Optional[torch.distributed.ProcessGroup] = None,
    src: int = 0,
    dst: int = 1,
):
    """Send data from src to dst in group.

    The list is pickled and moved with ``dist.send`` / ``dist.recv`` of CPU
    tensors. First a one-element int64 tensor with the byte count is sent,
    then, unless that count is 0, a uint8 tensor holding the pickled bytes.

    Args:
        data: List to send (only read on ``src``).
        rank: Rank of the calling process, compared against ``src``/``dst``.
        group: Process group used for send/recv (default group if None).
        src: Sending rank.
        dst: Receiving rank.

    Returns:
        ``data`` on ``src``; the received list (``[]`` for an empty transfer)
        on ``dst``; ``[]`` on every other rank.
    """

    if rank == src:
        # Sender: always send the byte count first so the receiver knows
        # whether (and how much) payload follows. The receiver below mirrors
        # this exact send order.
        if len(data) == 0:
            tensor_size = torch.tensor([0], dtype=torch.long)
            dist.send(tensor_size, dst=dst, group=group)
        else:
            serialized_data = pickle.dumps(data)
            size = len(serialized_data)
            tensor_data = torch.ByteTensor(
                np.frombuffer(serialized_data, dtype=np.uint8)
            )
            tensor_size = torch.tensor([size], dtype=torch.long)

            dist.send(tensor_size, dst=dst, group=group)
            dist.send(tensor_data, dst=dst, group=group)
        return data

    elif rank == dst:
        # Receiver: size first; a size of 0 means no payload follows.
        tensor_size = torch.tensor([0], dtype=torch.long)
        dist.recv(tensor_size, src=src, group=group)
        size = tensor_size.item()

        if size == 0:
            return []

        tensor_data = torch.empty(size, dtype=torch.uint8)
        dist.recv(tensor_data, src=src, group=group)

        serialized_data = bytes(tensor_data.cpu().numpy())
        # NOTE(review): unpickling received bytes; only safe when every
        # process in `group` is trusted.
        data = pickle.loads(serialized_data)
        return data

    # Other ranks in pp_group do nothing
    return []


984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
step_counter = 0


def pytorch_profile(name, func, *args, data_size=-1):
    """Run ``func(*args)`` under the PyTorch profiler and export a Chrome trace.

    Two files are written into ``./trace`` (created if needed), both suffixed
    with the current global ``step_counter``, which is then incremented:
    ``size_<step>.json`` holding ``{"size": data_size}``, and
    ``<name>_<step>.json`` with the exported Chrome trace.

    Args:
        name (string): the name of recorded function.
        func: the function to be profiled.
        args: the arguments of the profiled function.
        data_size (int): some measurement of the computation complexity.
            Usually, it could be the batch size.

    Returns:
        Whatever ``func(*args)`` returns.
    """
    global step_counter
    os.makedirs("trace", exist_ok=True)
    # Write the size metadata before profiling so this file I/O does not show
    # up inside the `name` region of the trace.
    with open(f"trace/size_{step_counter}.json", "w", encoding="utf-8") as f:
        json.dump({"size": data_size}, f)
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    ) as prof:
        with record_function(name):
            result = func(*args)
    prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
    step_counter += 1
    return result
1013
1014


Lianmin Zheng's avatar
Lianmin Zheng committed
1015
1016
1017
def get_zmq_socket(
    context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
):
    """Create a ZeroMQ socket on ``context`` and bind or connect it.

    Only PUSH, PULL and DEALER sockets are supported. High-water marks are set
    to 0 on the directions the socket uses: send for PUSH, receive for PULL,
    both for DEALER. The matching kernel buffer sizes are set to 0.5 GiB when
    the host has more than 32 GiB total and more than 16 GiB available memory,
    else to -1 (libzmq: keep the OS default). IPv6 is enabled when
    ``endpoint`` contains ``[``.

    Raises:
        ValueError: For any other socket type.
    """
    vm = psutil.virtual_memory()
    gib = 1024**3
    has_spare_memory = vm.total / gib > 32 and vm.available / gib > 16
    buf_size = int(0.5 * gib) if has_spare_memory else -1

    sock = context.socket(socket_type)
    if "[" in endpoint:
        sock.setsockopt(zmq.IPV6, 1)

    def configure_send():
        sock.setsockopt(zmq.SNDHWM, 0)
        sock.setsockopt(zmq.SNDBUF, buf_size)

    def configure_recv():
        sock.setsockopt(zmq.RCVHWM, 0)
        sock.setsockopt(zmq.RCVBUF, buf_size)

    if socket_type == zmq.PUSH:
        configure_send()
    elif socket_type == zmq.PULL:
        configure_recv()
    elif socket_type == zmq.DEALER:
        configure_send()
        configure_recv()
    else:
        raise ValueError(f"Unsupported socket type: {socket_type}")

    attach = sock.bind if bind else sock.connect
    attach(endpoint)
    return sock
1054
1055
1056


def dump_to_file(dirpath, name, value):
    """Save tensor ``value`` as ``<dirpath>/pytorch_dump_<name>.npy`` (debug aid).

    Only the tensor-parallel rank 0 process writes; other ranks return
    immediately. The tensor is detached first so that tensors requiring grad
    can be dumped too (``.numpy()`` refuses them). bfloat16 is upcast to
    float32 because NumPy has no bfloat16 dtype.
    """
    from sglang.srt.distributed import get_tensor_model_parallel_rank

    if get_tensor_model_parallel_rank() != 0:
        return

    os.makedirs(dirpath, exist_ok=True)
    value = value.detach()
    if value.dtype is torch.bfloat16:
        value = value.float()
    array = value.cpu().numpy()
    output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy")
    logger.info(f"Dump a tensor to {output_filename}. Shape = {array.shape}")
    np.save(output_filename, array)


def is_triton_3():
    """Return True when the installed Triton version string starts with "3."."""
    version_str = triton.__version__
    return version_str[:2] == "3."


def maybe_torch_compile(*args, **kwargs):
    """Decorator factory that applies ``torch.compile(*args, **kwargs)`` on Triton 3.x only.

    torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax,
    so on any Triton version other than 3.x the decorated function is returned
    unchanged. The version check runs when the decorator is applied.
    """

    def decorator(func):
        if not is_triton_3():
            return func
        compiler = torch.compile(*args, **kwargs)
        return compiler(func)

    return decorator


def delete_directory(dirpath):
    """Recursively delete ``dirpath``; print a warning instead of raising on OSError."""
    try:
        shutil.rmtree(dirpath)
    except OSError as err:
        print(f"Warning: {dirpath} : {err.strerror}")
Lianmin Zheng's avatar
Lianmin Zheng committed
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120


# Temporary directory for prometheus multiprocess mode.
# Cleaned up automatically when this object is garbage collected.
prometheus_multiproc_dir: tempfile.TemporaryDirectory


def set_prometheus_multiproc_dir():
    """Prepare the directory for prometheus_client's multiprocess mode.

    SGLang uses prometheus multiprocess mode, and this must run before
    prometheus_client is imported
    (https://prometheus.github.io/client_python/multiprocess/).

    * If the user already set ``PROMETHEUS_MULTIPROC_DIR``, a
      TemporaryDirectory is created inside it and the variable is left as is.
    * Otherwise a fresh TemporaryDirectory is created and the variable is
      pointed at it.

    The TemporaryDirectory object is stored in the module-level
    ``prometheus_multiproc_dir`` so it stays alive.
    """
    global prometheus_multiproc_dir

    user_dir = os.environ.get("PROMETHEUS_MULTIPROC_DIR")
    if user_dir is not None:
        logger.debug("User set PROMETHEUS_MULTIPROC_DIR detected.")
        prometheus_multiproc_dir = tempfile.TemporaryDirectory(dir=user_dir)
    else:
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")


def add_prometheus_middleware(app):
    """Mount a Prometheus ``/metrics`` endpoint on ``app``.

    Metrics from all processes are aggregated through prometheus_client's
    ``MultiProcessCollector``. prometheus_client is imported lazily here
    because it must only be imported after ``PROMETHEUS_MULTIPROC_DIR`` is set.
    """
    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess

    registry = CollectorRegistry()
    multiprocess.MultiProcessCollector(registry)
    route = Mount("/metrics", make_asgi_app(registry=registry))
    # Workaround for a 307 redirect on /metrics: match the bare path as well
    # as any suffix.
    route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(route)
1131
1132


1133
1134
1135
1136
1137
1138
1139
1140
1141
def bind_port(port):
    """Bind a listening IPv4 TCP socket to ``port`` on all interfaces.

    SO_REUSEADDR is set before binding. If binding or listening fails the
    socket is closed before the error propagates, so no file descriptor leaks.

    Returns:
        The bound socket, listening with a backlog of 1. The caller owns it.

    Raises:
        OSError: If the port cannot be bound (e.g. it is already in use).
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allows address reuse
        sock.bind(("", port))
        sock.listen(1)
    except BaseException:
        sock.close()
        raise
    return sock


HAI's avatar
HAI committed
1142
1143
1144
1145
def get_amdgpu_memory_capacity():
    """Return the smallest AMD GPU memory pool size on this host, in MiB.

    Runs a ``rocminfo | grep | awk`` shell pipeline that extracts the
    ``Size:`` field of each GPU agent's "Pool 1" and divides it by 1024.

    Raises:
        RuntimeError: If the pipeline exits with a non-zero status, or the
            shell itself cannot be started.
        ValueError: If no memory value can be parsed. This includes the case
            where ``rocminfo`` is missing: the pipeline's exit status is
            awk's, which is 0 even on empty input.
    """
    try:
        # Run rocminfo and capture the output
        result = subprocess.run(
            [
                "rocminfo | grep 'gfx' -A 100 | grep 'Pool 1' -A 5 | grep 'Size:' | awk '{print $2}'"
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
            text=True,
        )
    except FileNotFoundError:
        # With shell=True this is only reachable if the shell itself is missing.
        raise RuntimeError(
            "rocminfo not found. Ensure AMD ROCm drivers are installed and accessible."
        )
    if result.returncode != 0:
        raise RuntimeError(f"rocminfo error: {result.stderr.strip()}")

    # Each line looks like "<size>(<hex size>)". Blank lines are skipped so
    # empty output reaches the explicit error below instead of float("").
    memory_values = [
        float(line.split("(")[0].strip()) / 1024
        for line in result.stdout.splitlines()
        if line.strip()
    ]
    if not memory_values:
        raise ValueError("No GPU memory values found.")

    # Return the minimum memory value
    return min(memory_values)


1175
1176
1177
1178
1179
1180
1181
def get_device_sm():
    """Return the current CUDA device's compute capability as ``major*10 + minor``.

    For example, capability (8, 0) gives 80. Returns 0 when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 0
    major, minor = torch.cuda.get_device_capability()
    return 10 * major + minor


HAI's avatar
HAI committed
1182
def get_nvgpu_memory_capacity():
    """Return the smallest total memory among visible NVIDIA GPUs, in MiB.

    Queries ``nvidia-smi --query-gpu=memory.total`` (CSV, no header, no
    units) and keeps only output lines that are plain decimal numbers.

    Raises:
        RuntimeError: If nvidia-smi is missing or exits with a non-zero status.
        ValueError: If no numeric memory value appears in its output.
    """
    numeric_line = re.compile(r"^\d+(\.\d+)?$")
    try:
        proc = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except FileNotFoundError:
        raise RuntimeError(
            "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
        )

    if proc.returncode != 0:
        raise RuntimeError(f"nvidia-smi error: {proc.stderr.strip()}")

    memory_values = []
    for line in proc.stdout.strip().split("\n"):
        if numeric_line.match(line.strip()):
            memory_values.append(float(line))

    if not memory_values:
        raise ValueError("No GPU memory values found.")
    return min(memory_values)
1212
1213


1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
def get_hpu_memory_capacity():
    """Return the smallest total memory among Habana HPUs, in MiB.

    Runs ``hl-smi --query | grep 'Total'`` through the shell and reads the
    second-to-last space-separated token of each matching line.

    Raises:
        RuntimeError: If the pipeline exits non-zero (grep exits 1 when no
            line matches, e.g. when hl-smi is unavailable) or the shell is
            missing.
        ValueError: If no memory value is found.
    """
    command = "hl-smi --query | grep 'Total'"
    try:
        proc = subprocess.run(
            [command],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
            text=True,
        )
    except FileNotFoundError:
        raise RuntimeError(
            "hl-smi not found. Ensure Habana drivers are installed and accessible."
        )

    if proc.returncode != 0:
        raise RuntimeError(f"hl-smi error: {proc.stderr.strip()}")

    lines = proc.stdout.strip().split("\n")
    memory_values = [float(line.split(" ")[-2]) for line in lines]
    if not memory_values:
        raise ValueError("No GPU memory values found.")
    return min(memory_values)


Lianmin Zheng's avatar
Lianmin Zheng committed
1245
def get_device_memory_capacity(device: str = None):
    """Return the minimum per-device memory capacity (MiB), or None.

    CUDA and ROCm/HIP are detected through ``is_cuda()`` / ``is_hip()``.
    Habana is used only when ``device == "hpu"``. Returns None when none of
    these apply.
    """
    if is_cuda():
        return get_nvgpu_memory_capacity()
    if is_hip():
        return get_amdgpu_memory_capacity()
    if device == "hpu":
        return get_hpu_memory_capacity()
    # GPU memory is not known yet or no GPU is available.
    return None


1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
def init_custom_process_group(
    backend=None,
    init_method=None,
    timeout=None,
    world_size=-1,
    rank=-1,
    store=None,
    group_name=None,
    pg_options=None,
):
    from torch.distributed.distributed_c10d import (
        Backend,
        PrefixStore,
        _new_process_group_helper,
        _world,
        default_pg_timeout,
        rendezvous,
    )

    assert (store is None) or (
        init_method is None
    ), "Cannot specify both init_method and store."

    if store is not None:
        assert world_size > 0, "world_size must be positive if using store"
        assert rank >= 0, "rank must be non-negative if using store"
    elif init_method is None:
        init_method = "env://"

    if backend:
        backend = Backend(backend)
    else:
        backend = Backend("undefined")

    if timeout is None:
        timeout = default_pg_timeout

    # backward compatible API
    if store is None:
        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iterator)
        store.set_timeout(timeout)

        # Use a PrefixStore to avoid accidental overrides of keys used by
        # different systems (e.g. RPC) in case the store is multi-tenant.
        store = PrefixStore(group_name, store)

    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
    # We need to determine the appropriate parameter name based on PyTorch version
    pg_options_param_name = (
        "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
    )
    pg, _ = _new_process_group_helper(
        world_size,
        rank,
        [],
        backend,
        store,
        group_name=group_name,
        **{pg_options_param_name: pg_options},
        timeout=timeout,
    )

    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}

    return pg


1331
1332
def crash_on_warnings():
    """Return whether warnings should crash the process.

    This is the value of the ``SGLANG_IS_IN_CI`` boolean environment flag,
    i.e. crashing is enabled while running CI tests.
    """
    in_ci = get_bool_env_var("SGLANG_IS_IN_CI")
    return in_ci
1334
1335


1336
1337
1338
1339
1340
@lru_cache(maxsize=None)
def print_warning_once(msg: str) -> None:
    """Log ``msg`` at WARNING level only the first time it is passed.

    Calls are memoized per distinct ``msg`` with an unbounded ``lru_cache``,
    so repeating a message is a no-op. ``msg`` must be hashable.
    """
    # stacklevel=2 attributes the record to the caller; in CPython the
    # lru_cache wrapper is implemented in C and adds no Python frame.
    logger.warning(msg, stacklevel=2)


1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
def get_device_name(device_id: int = 0) -> str:
    """Return the name of accelerator ``device_id``.

    Backends are probed in the order CUDA, XPU, HPU, and the first available
    one is queried. Returns None when none of them is available.
    """
    for backend_name in ("cuda", "xpu", "hpu"):
        backend = getattr(torch, backend_name, None)
        if backend is not None and backend.is_available():
            return backend.get_device_name(device_id)
    return None


1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
@lru_cache(maxsize=1)
def is_habana_available() -> bool:
    """Return True if the ``habana_frameworks`` package can be found (cached)."""
    spec = find_spec("habana_frameworks")
    return spec is not None


@lru_cache(maxsize=8)
def get_device(device_id: Optional[int] = None) -> str:
    """Return a torch device string for the first available accelerator.

    Backends are probed in the order CUDA, XPU, HPU. The result is
    ``"<type>"`` when ``device_id`` is None, else ``"<type>:<device_id>"``.
    Results are cached per ``device_id``.

    Raises:
        ImportError: If ``habana_frameworks`` is installed but
            ``habana_frameworks.torch.hpu`` fails to import (original error
            attached as the cause).
        RuntimeError: If no accelerator is available.
    """

    def _format(device_type: str) -> str:
        return device_type if device_id is None else f"{device_type}:{device_id}"

    if hasattr(torch, "cuda") and torch.cuda.is_available():
        return _format("cuda")

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return _format("xpu")

    if is_habana_available():
        try:
            # Presumably registers the `torch.hpu` backend used below.
            import habana_frameworks.torch.hpu  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
            ) from e
        if torch.hpu.is_available():
            return _format("hpu")

    raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")


@lru_cache(maxsize=1)
def get_device_count() -> int:
    """Return the device count of the first available accelerator backend.

    Probes CUDA, then XPU, then HPU (only if habana_frameworks is installed).
    Returns 0 when the count query raises RuntimeError, when the Habana
    import fails, or when no accelerator is available. Cached after the
    first call.
    """

    def _count_or_zero(count_fn):
        try:
            return count_fn()
        except RuntimeError:
            return 0

    if hasattr(torch, "cuda") and torch.cuda.is_available():
        return _count_or_zero(torch.cuda.device_count)

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return _count_or_zero(torch.xpu.device_count)

    if not is_habana_available():
        return 0  # No accelerators available

    try:
        import habana_frameworks.torch.hpu

        if torch.hpu.is_available():
            return torch.hpu.device_count()
    except (ImportError, RuntimeError):
        return 0
    return 0


1411
1412
1413
1414
1415
1416
1417
def get_device_core_count(device_id: int = 0) -> int:
    """Return the multiprocessor count of CUDA device ``device_id``, or 0 without CUDA."""
    cuda_ok = hasattr(torch, "cuda") and torch.cuda.is_available()
    if not cuda_ok:
        return 0
    props = torch.cuda.get_device_properties(device_id)
    return props.multi_processor_count


1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
    """Return the ``(major, minor)`` capability of accelerator ``device_id``.

    CUDA is checked first, then XPU, whose capability is parsed from the
    ``"version"`` string returned by ``torch.xpu.get_device_capability``. A
    later available backend overwrites an earlier result. HPU does not expose
    a capability yet and yields ``(None, None)``. ``(None, None)`` is also
    returned when no backend is available.
    """
    major = minor = None

    if hasattr(torch, "cuda") and torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(device_id)

    if hasattr(torch, "xpu") and torch.xpu.is_available():
        version_str = torch.xpu.get_device_capability(device_id)["version"]
        major_str, minor_str, *_ = version_str.split(".")
        major, minor = int(major_str), int(minor_str)

    if hasattr(torch, "hpu") and torch.hpu.is_available():
        # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
        # Update this once the support is available.
        # major, minor = torch.hpu.get_device_capability(device_id)
        major = minor = None

    return major, minor


1443
1444
1445
1446
1447
1448
1449
def get_compiler_backend() -> str:
    """Return the torch.compile backend name: "hpu_backend" on HPU, else "inductor"."""
    on_hpu = hasattr(torch, "hpu") and torch.hpu.is_available()
    return "hpu_backend" if on_hpu else "inductor"


1450
1451
1452
sglang_lib = Library("sglang", "FRAGMENT")  # noqa


1453
1454
1455
1456
1457
1458
# Some backends ship a PyTorch older than 2.4.0, which lacks
# `torch.library.custom_op`.
def supports_custom_op() -> bool:
    """Return True if this PyTorch build provides ``torch.library.custom_op``."""
    return getattr(torch.library, "custom_op", None) is not None


1459
1460
1461
1462
1463
1464
1465
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: List[str],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
):
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the vLLM library. If you
    want to register it to a different library, you can pass the library
    object to the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.
    """
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
    import torch.library

    if hasattr(torch.library, "infer_schema"):
        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    else:
        # for pytorch 2.4
        import torch._custom_op.impl

        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)

    my_lib = target_lib or sglang_lib
    my_lib.define(op_name + schema_str)
    my_lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
1496
1497


1498
def set_gpu_proc_affinity(
    tp_size: int,
    nnodes: int,
    gpu_id: int,
):
    """Pin the current process to a contiguous slice of physical CPU cores.

    The node's physical cores are split evenly among its ``tp_size // nnodes``
    GPUs, and ``gpu_id`` selects the slice. The start index wraps modulo the
    core count, so multiple DP replicas per node are handled. When
    hyper-threading is detected (logical count != physical count), CPU ids
    ``core + total_pcores`` are added as well. This assumes the OS numbers
    hyper-thread siblings that way (TODO confirm on unusual topologies).
    """
    pid = os.getpid()
    proc = psutil.Process(pid)

    gpus_per_node = tp_size // nnodes
    # physical cores per GPU (N.B. more cores than GPUs on a node)
    total_pcores = psutil.cpu_count(logical=False)
    cores_per_gpu = total_pcores // gpus_per_node

    first_core = (gpu_id * cores_per_gpu) % total_pcores
    core_range = range(first_core, first_core + cores_per_gpu)

    if psutil.cpu_count() != psutil.cpu_count(logical=False):
        # HT on: also bind each core's sibling logical CPU.
        siblings = (core + total_pcores for core in core_range)
        bind_cpu_ids = list(itertools.chain(core_range, siblings))
    else:
        # HT off
        bind_cpu_ids = list(core_range)

    proc.cpu_affinity(bind_cpu_ids)
    logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {proc.cpu_affinity()}")
1530
1531


1532
1533
1534
1535
1536
@lru_cache(maxsize=2)
def disable_request_logging() -> bool:
    """Return the SGLANG_DISABLE_REQUEST_LOGGING flag (read once, then cached)."""
    flag = get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
    return flag


1537
1538
1539
1540
1541
def dataclass_to_string_truncated(
    data, max_length=2048, skip_names: Optional[Set[str]] = None
):
    """Render ``data`` as a string, eliding the middle of long values.

    - ``str``: ``repr`` of the value. If longer than ``max_length``, the
      ``repr`` of the first and last ``max_length // 2`` characters, joined
      by `` ... ``.
    - ``list``/``tuple``: ``str`` of the value, truncated the same way by
      element count.
    - ``dict``: ``{'key': value, ...}`` with values rendered recursively.
    - dataclass instance: ``ClassName(field=value, ...)`` with fields
      rendered recursively.
    - anything else (including dataclass *types*): ``str(data)``.

    ``skip_names`` drops matching dict keys / dataclass field names at the
    top level only; it is not passed on to nested values.
    """
    if skip_names is None:
        skip_names = set()

    def _head_tail(seq):
        # Use an explicit start index for the tail: when half_length is 0,
        # seq[-0:] would return the whole sequence instead of nothing.
        half_length = max_length // 2
        return seq[:half_length], seq[len(seq) - half_length :]

    if isinstance(data, str):
        if len(data) > max_length:
            head, tail = _head_tail(data)
            return f"{repr(head)} ... {repr(tail)}"
        return f"{repr(data)}"

    if isinstance(data, (list, tuple)):
        if len(data) > max_length:
            head, tail = _head_tail(data)
            return str(head) + " ... " + str(tail)
        return str(data)

    if isinstance(data, dict):
        items = ", ".join(
            f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
            for k, v in data.items()
            if k not in skip_names
        )
        return "{" + items + "}"

    # is_dataclass() is also True for dataclass classes; only instances have
    # field values to read.
    if dataclasses.is_dataclass(data) and not isinstance(data, type):
        fields = ", ".join(
            f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
            for f in dataclasses.fields(data)
            if f.name not in skip_names
        )
        return f"{data.__class__.__name__}(" + fields + ")"

    return str(data)
Tanjiro's avatar
Tanjiro committed
1577
1578


1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
def permute_weight(x: torch.Tensor) -> torch.Tensor:
    b_ = x.shape[0]
    n_ = x.shape[1]
    k_ = x.shape[2]

    x_ = x
    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
    else:
1590
1591
        # return x_
        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 8), 2, 4)
1592
1593
1594
1595
1596
1597
1598

    x_ = x_.permute(0, 1, 3, 4, 2, 5)
    x_ = x_.contiguous()
    x_ = x_.view(*x.shape)
    return x_


1599
1600
class MultiprocessingSerializer:
    """Serialize objects with ``multiprocessing.reduction.ForkingPickler``.

    Any reducers registered with ForkingPickler apply, e.g. those installed
    by ``torch.multiprocessing`` if it has been imported.

    NOTE(review): ``deserialize`` unpickles its input; only feed it data from
    trusted peers.
    """

    @staticmethod
    def serialize(obj, output_str: bool = False):
        """
        Serialize a Python object using ForkingPickler.

        Args:
            obj: The object to serialize.
            output_str (bool): If True, return a base64-encoded string instead of raw bytes.

        Returns:
            bytes or str: The serialized object.
        """
        buffer = io.BytesIO()
        ForkingPickler(buffer).dump(obj)
        payload = buffer.getvalue()
        if not output_str:
            return payload
        # Base64 so the result can travel through text-only channels.
        return base64.b64encode(payload).decode("utf-8")

    @staticmethod
    def deserialize(data):
        """
        Deserialize a previously serialized object.

        Args:
            data (bytes or str): The serialized data, optionally base64-encoded.

        Returns:
            The deserialized Python object.
        """
        raw = base64.b64decode(data) if isinstance(data, str) else data
        return ForkingPickler.loads(raw)
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649


def debug_timing(func):
    """Decorate ``func`` to log its CUDA-event-timed duration at DEBUG level.

    When ``logger`` is enabled for DEBUG:
    - The call is bracketed by two ``torch.cuda.Event`` records, and the end
      event is synchronized.
    - The elapsed milliseconds and a tokens/s throughput are logged.
    - The token count is ``len(indices)``. ``indices`` comes from the ``indices``
      keyword or, failing that, the second positional argument; it counts as 0
      when neither is present or it is None.

    Otherwise ``func`` is called directly.

    ``functools.wraps`` preserves the wrapped function's ``__name__``,
    ``__doc__`` and ``__wrapped__``, so introspection still sees the real callee.
    """
    from functools import wraps

    # todo: replace with a more organized instrumentation
    @wraps(func)
    def wrapper(*args, **kwargs):
        if not logger.isEnabledFor(logging.DEBUG):
            return func(*args, **kwargs)

        tic = torch.cuda.Event(enable_timing=True)
        toc = torch.cuda.Event(enable_timing=True)
        tic.record()
        result = func(*args, **kwargs)
        toc.record()
        toc.synchronize()  # Wait for the function to complete without synchronizing all ops on the GPU

        elapsed = tic.elapsed_time(toc)
        indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
        num_tokens = len(indices) if indices is not None else 0
        throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
        logger.debug(
            f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
        )
        return result

    return wrapper
bjmsong's avatar
bjmsong committed
1663
1664
1665
1666
1667
1668


def nullable_str(val: str):
    """Return ``val``, or None if it is falsy (e.g. ``""``/None) or the literal ``"None"``.

    Suitable as an argparse ``type=`` so that ``--flag None`` yields None.
    """
    if val and val != "None":
        return val
    return None
1669
1670


1671
1672
1673
1674
1675
1676
1677
1678
1679
def pyspy_dump_schedulers():
    """Log a ``py-spy dump`` of the calling process's Python stacks.

    Only the current process (``os.getpid()``) is dumped, despite the name.
    The dump is logged at ERROR level so it is visible under default log levels.
    Failures are logged rather than raised: a non-zero ``py-spy`` exit, or
    ``py-spy`` missing from ``PATH``.
    """
    pid = os.getpid()
    try:
        # argv list instead of a shell string: no shell parsing is needed.
        result = subprocess.run(
            ["py-spy", "dump", "--pid", str(pid)],
            capture_output=True,
            text=True,
            check=True,
        )
        logger.error(f"Pyspy dump for PID {pid}:\n{result.stdout}")
    except subprocess.CalledProcessError as e:
        logger.error(f"Pyspy failed to dump PID {pid}. Error: {e.stderr}")
    except FileNotFoundError as e:
        # Without a shell, a missing executable raises instead of exiting 127.
        logger.error(f"Pyspy failed to dump PID {pid}. Error: {e}")
1683
1684
1685
1686
1687
1688
1689
1690
1691


def kill_itself_when_parent_died():
    """Arrange for this process to receive SIGKILL when its parent dies (Linux only).

    On Linux this calls ``prctl(PR_SET_PDEATHSIG, SIGKILL)`` through libc; the
    return value is not checked. On other platforms it only logs a warning.
    """
    if sys.platform != "linux":
        logger.warning("kill_itself_when_parent_died is only supported in linux.")
        return

    # sigkill this process when parent worker manager dies
    pdeathsig_option = 1  # PR_SET_PDEATHSIG
    ctypes.CDLL("libc.so.6").prctl(pdeathsig_option, signal.SIGKILL)
1693
1694


1695
def set_uvicorn_logging_configs():
    """Prefix uvicorn's default and access log lines with a timestamp.

    Mutates ``uvicorn.config.LOGGING_CONFIG`` in place. Both formatters get a
    ``[%(asctime)s]`` prefix and a ``%Y-%m-%d %H:%M:%S`` date format.
    """
    from uvicorn.config import LOGGING_CONFIG

    timestamp_fmt = "%Y-%m-%d %H:%M:%S"
    formatters = LOGGING_CONFIG["formatters"]

    default_formatter = formatters["default"]
    default_formatter["fmt"] = "[%(asctime)s] %(levelprefix)s %(message)s"
    default_formatter["datefmt"] = timestamp_fmt

    access_formatter = formatters["access"]
    access_formatter["fmt"] = (
        '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
    )
    access_formatter["datefmt"] = timestamp_fmt
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745


def get_ip() -> str:
    """Return an IP address for this host.

    Resolution order:
    1. ``SGLANG_HOST_IP``, then ``HOST_IP``, from the environment (first non-empty wins).
    2. The local address of a UDP socket connected to 8.8.8.8 (IPv4).
    3. The local address of a UDP socket connected to 2001:4860:4860::8888 (IPv6).
    4. ``"0.0.0.0"``, after emitting a warning.

    Each probe socket is closed via ``with``. A probe that raises ``OSError``
    falls through to the next step.
    """
    # SGLANG_HOST_IP takes precedence; HOST_IP is only consulted if it is unset or empty.
    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface: IPv4 first, then IPv6.
    # Google's public DNS servers, see
    # https://developers.google.com/speed/public-dns/docs/using#addresses
    probes = (
        (socket.AF_INET, ("8.8.8.8", 80)),
        (socket.AF_INET6, ("2001:4860:4860::8888", 80)),
    )
    for family, target in probes:
        try:
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect(target)  # Doesn't need to be reachable
                return s.getsockname()[0]
        except OSError:
            continue

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable"
        " SGLANG_HOST_IP or HOST_IP.",
        stacklevel=2,
    )
    return "0.0.0.0"


def get_open_port() -> int:
    """Return a TCP port number that could be bound a moment ago.

    If ``SGLANG_PORT`` is set, ports are probed upward from that value (IPv4,
    all interfaces) and the first one that binds is returned. Otherwise the OS
    assigns an ephemeral port, trying IPv4 and then IPv6. The probe socket is
    closed before returning, so the port is not reserved.

    Raises:
        ValueError: if ``SGLANG_PORT`` is not an integer.
        OSError: if no port from ``SGLANG_PORT`` up to 65535 can be bound, or
            the ephemeral bind fails for both address families.
    """
    port = os.getenv("SGLANG_PORT")
    if port is not None:
        start_port = int(port)
        max_port = 65535
        # Bounded scan: past 65535 bind() would fail with an opaque OverflowError.
        for candidate in range(start_port, max_port + 1):
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("", candidate))
                    return candidate
            except OSError:
                logger.info(
                    "Port %d is already in use, trying port %d",
                    candidate,
                    candidate + 1,
                )
        raise OSError(
            f"No free port found in range [{start_port}, {max_port}] "
            f"(SGLANG_PORT={port})"
        )

    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def is_valid_ipv6_address(address: str) -> bool:
    """Return True if ``address`` is accepted by ``ipaddress.IPv6Address``.

    The address must be bare: brackets or a ``:port`` suffix make it invalid.
    """
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        return False
    else:
        return True
1773
1774


Vincent's avatar
Vincent committed
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
def configure_ipv6(dist_init_addr):
    """Split a bracketed IPv6 endpoint ``"[ipv6]:port"`` into its parts.

    Args:
        dist_init_addr: address string such as ``"[::1]:5000"``.

    Returns:
        ``(port, host)`` where ``port`` is an ``int`` and ``host`` keeps its
        surrounding brackets (e.g. ``"[::1]"``).

    Raises:
        ValueError: if the address does not start with ``[``, has no ``]``, or
            the text between the brackets is not a valid IPv6 address. Also
            raised if ``]`` is not followed by ``:``, or the port is missing or
            not an integer.
    """
    addr = dist_init_addr
    # Without this check, "x::1]:80" would validate "::1" and return host "x::1]".
    if not addr.startswith("["):
        raise ValueError("invalid IPv6 address format: missing '['")
    end = addr.find("]")
    if end == -1:
        raise ValueError("invalid IPv6 address format: missing ']'")

    host = addr[: end + 1]

    # this only validates the address without brackets: we still need the below checks.
    # if it's invalid, immediately raise an error so we know it's not formatting issues.
    if not is_valid_ipv6_address(host[1:end]):
        raise ValueError(f"invalid IPv6 address: {host}")

    port_str = None
    if len(addr) > end + 1:
        if addr[end + 1] == ":":
            port_str = addr[end + 2 :]
        else:
            raise ValueError("invalid IPv6 address format: expected ':' after ']'")

    if not port_str:
        raise ValueError(
            "a port must be specified in IPv6 address (format: [ipv6]:port)"
        )

    try:
        port = int(port_str)
    except ValueError as err:
        raise ValueError(f"invalid port in IPv6 address: '{port_str}'") from err
    return port, host


1807
1808
1809
1810
1811
1812
1813
def rank0_log(msg: str):
    """Log ``msg`` at INFO level, only on tensor-model-parallel rank 0."""
    from sglang.srt.distributed import get_tensor_model_parallel_rank

    if get_tensor_model_parallel_rank() != 0:
        return
    logger.info(msg)


1814
1815
1816
1817
1818
def rank0_print(msg: str):
    """Print ``msg`` (flushed) to stdout, only on tensor-model-parallel rank 0."""
    from sglang.srt.distributed import get_tensor_model_parallel_rank

    if get_tensor_model_parallel_rank() != 0:
        return
    print(msg, flush=True)
1819
1820


HandH1998's avatar
HandH1998 committed
1821
1822
1823
1824
1825
1826
def get_cuda_version():
    """Return ``torch.version.cuda`` as a tuple of ints, or ``(0, 0)`` if it is unset.

    Every dot-separated component is kept, e.g. ``"12.4"`` -> ``(12, 4)``.
    """
    cuda_version = torch.version.cuda
    if not cuda_version:
        return (0, 0)
    return tuple(int(part) for part in cuda_version.split("."))


1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
def launch_dummy_health_check_server(host, port):
    """Serve a minimal FastAPI app whose health endpoints always return 200.

    ``GET /health`` and ``GET /health_generate`` both answer with an empty 200
    response. ``uvicorn.run`` blocks the calling thread while serving on
    ``host``:``port`` with the uvloop event loop and a 5-second keep-alive.
    """
    import uvicorn
    from fastapi import FastAPI, Response

    app = FastAPI()

    async def health():
        """Check the health of the http server."""
        return Response(status_code=200)

    async def health_generate():
        """Check the health of the http server."""
        return Response(status_code=200)

    # Equivalent to decorating each handler with @app.get(path).
    app.get("/health")(health)
    app.get("/health_generate")(health_generate)

    server_options = dict(host=host, port=port, timeout_keep_alive=5, loop="uvloop")
    uvicorn.run(app, **server_options)
1850
1851


1852
1853
1854
1855
def create_checksum(directory: str):
    """Placeholder for computing a checksum of ``directory``; always raises.

    Raises:
        NotImplementedError: unconditionally; no implementation exists yet.
    """
    raise NotImplementedError


1856
1857
1858
1859
1860
def set_cuda_arch():
    """Set ``TORCH_CUDA_ARCH_LIST`` to the current GPU's compute capability.

    This only takes effect when ``is_flashinfer_available()`` is true. The value
    is ``"<major>.<minor>"``, with ``+PTX`` appended when the capability is 9.0.
    Presumably this restricts JIT kernel builds to the local architecture --
    confirm with the build consumers.
    """
    if not is_flashinfer_available():
        return
    major, minor = torch.cuda.get_device_capability()
    arch = f"{major}.{minor}"
    ptx_suffix = "+PTX" if arch == "9.0" else ""
    os.environ["TORCH_CUDA_ARCH_LIST"] = arch + ptx_suffix
1861
1862


Lianmin Zheng's avatar
Lianmin Zheng committed
1863
1864
1865
1866
1867
1868
1869
def next_power_of_2(n: int):
    """Return the smallest power of two that is >= ``n``; 1 for ``n <= 0``."""
    if n <= 0:
        return 1
    return 1 << (n - 1).bit_length()


# Module import side effect: installs the pure-Python next_power_of_2 above as
# ``triton.next_power_of_2``, replacing any attribute triton already had under
# that name. NOTE(review): presumably done so code calling
# triton.next_power_of_2 gets this implementation -- confirm the intent.
setattr(triton, "next_power_of_2", next_power_of_2)


1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
@contextmanager
def empty_context(*args, **kwargs):
    """No-op context manager that accepts and ignores any arguments.

    It yields None and neither suppresses nor alters exceptions raised in the
    ``with`` body.
    """
    yield


1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
def add_prefix(name: str, prefix: str) -> str:
    """Add a weight path prefix to a module name.

    Args:
        name: base module name.
        prefix: weight prefix joined to the front of ``name`` with a ``.``.

    Returns:
        ``"prefix.name"`` if ``prefix`` is truthy, otherwise ``name`` unchanged.
    """
    if not prefix:
        return name
    return f"{prefix}.{name}"
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916


def is_remote_url(url: Union[str, Path]) -> bool:
    """
    Check if the URL is a remote URL of the format:
    <connector_type>://<host>:<port>/<model_name>

    ``Path`` objects are always treated as local. For strings, the test is
    whether the text (from its start) matches ``(.+)://(.*)``, i.e. it contains
    ``://`` preceded by at least one non-newline character.
    """
    if isinstance(url, Path):
        return False
    return bool(re.match(r"(.+)://(.*)", url))


def parse_connector_type(url: str) -> str:
    """
    Parse the connector type from the URL of the format:
    <connector_type>://<path>

    Returns ``""`` when the URL does not match. The capture is greedy, so with
    several ``://`` occurrences everything before the LAST one is returned
    (``"a://b://c"`` -> ``"a://b"``).
    """
    match = re.match(r"(.+)://(.*)", url)
    return match.group(1) if match else ""
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945


def retry(
    fn,
    max_retry: int,
    initial_delay: float = 2.0,
    max_delay: float = 60.0,
    should_retry: Callable[[Any], bool] = lambda e: True,
):
    """Call ``fn()`` until it succeeds, sleeping with jittered exponential backoff.

    ``fn`` is attempted at most ``max_retry + 1`` times. Before retry number
    ``i`` (0-based) the delay is ``min(initial_delay * 2**i, max_delay)``,
    scaled by a random factor in ``[0.75, 1.0)``.

    Args:
        fn: zero-argument callable to invoke.
        max_retry: number of retries allowed after the first attempt.
        initial_delay: base delay in seconds for the first retry.
        max_delay: cap on the un-jittered delay in seconds.
        should_retry: predicate on the caught exception; returning False stops
            immediately.

    Returns:
        Whatever ``fn()`` returns on its first successful call.

    Raises:
        Exception: when retries are exhausted or ``should_retry`` returns False.
            The last exception raised by ``fn`` is attached as ``__cause__``.
    """
    for try_index in itertools.count():
        try:
            return fn()
        except Exception as e:
            if try_index >= max_retry:
                raise Exception("retry() exceed maximum number of retries.") from e

            if not should_retry(e):
                raise Exception(
                    "retry() observe errors that should not be retried."
                ) from e

            # Jitter keeps concurrent callers from retrying in lockstep.
            delay = min(initial_delay * (2**try_index), max_delay) * (
                0.75 + 0.25 * random.random()
            )

            logger.warning(
                f"retry() failed once ({try_index}th try, maximum {max_retry} retries). Will delay {delay:.2f}s and retry. Error: {e}"
            )
            traceback.print_exc()

            time.sleep(delay)
Mick's avatar
Mick committed
1946
1947
1948
1949
1950
1951
1952
1953
1954


def flatten_nested_list(nested_list):
    """Recursively flatten nested ``list`` objects into one flat list.

    Only ``list`` instances are descended into. Any other value, including a
    tuple or string, becomes a single element, and a non-list argument yields
    ``[argument]``.
    """
    if not isinstance(nested_list, list):
        return [nested_list]
    flat = []
    for element in nested_list:
        flat.extend(flatten_nested_list(element))
    return flat
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975


class DeepEPMode(Enum):
    """Which DeepEP mode(s) are enabled.

    ``normal`` and ``low_latency`` each enable a single mode. ``auto`` enables
    both, and ``resolve`` turns it into ``low_latency`` for decode forward
    passes and ``normal`` otherwise.
    """

    normal = "normal"
    low_latency = "low_latency"
    auto = "auto"

    def enable_normal(self):
        """True for ``normal`` and ``auto``."""
        return self is DeepEPMode.normal or self is DeepEPMode.auto

    def enable_low_latency(self):
        """True for ``low_latency`` and ``auto``."""
        return self is DeepEPMode.low_latency or self is DeepEPMode.auto

    def resolve(self, forward_mode):
        """Map ``auto`` to a concrete mode; other modes are returned unchanged.

        ``forward_mode`` only needs an ``is_decode()`` method.
        """
        if self is not DeepEPMode.auto:
            return self
        return (
            DeepEPMode.low_latency if forward_mode.is_decode() else DeepEPMode.normal
        )
1976
1977
1978
1979
1980
1981
1982
1983
1984


def fast_topk(values, topk, dim):
    """Return ``(values, indices)`` of the ``topk`` largest entries along ``dim``.

    For ``topk == 1`` this uses ``torch.max(..., keepdim=True)``, which keeps the
    reduced dimension with size 1, matching ``torch.topk``'s output layout. It
    is presumably chosen for speed -- confirm with a benchmark. All other
    ``topk`` values go to ``torch.topk``.
    """
    if topk != 1:
        return torch.topk(values, topk, dim=dim)
    return torch.max(values, dim=dim, keepdim=True)
1985
1986


1987
def _check(cc_major):
    """Return True on CUDA with compute-capability major ``cc_major`` and CUDA >= 12.3.

    Returns False when ``is_cuda()`` is false. The CUDA version is read from
    ``torch.version.cuda``, and only its major.minor components are compared.
    """
    if not is_cuda():
        return False

    return torch.cuda.get_device_capability()[0] == cc_major and tuple(
        map(int, torch.version.cuda.split(".")[:2])
    ) >= (12, 3)


# Named functions rather than assigned lambdas (PEP 8 E731), so __name__ and
# tracebacks are meaningful.
def is_ampere_with_cuda_12_3():
    """True on a compute-capability 8.x GPU with CUDA >= 12.3.

    Note that major version 8 also includes 8.9 devices, not only Ampere.
    """
    return _check(8)


def is_hopper_with_cuda_12_3():
    """True on a compute-capability 9.x GPU with CUDA >= 12.3."""
    return _check(9)
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028


def get_free_port():
    """Return an ephemeral TCP port number chosen by the OS.

    Binds port 0 on IPv4 and falls back to IPv6 if the IPv4 bind raises
    ``OSError``. The socket is closed before returning, so another process
    could claim the port afterwards.
    """

    def _ephemeral_port(family):
        # Port 0 asks the kernel to pick any unused port.
        with socket.socket(family, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    try:
        return _ephemeral_port(socket.AF_INET)
    except OSError:
        return _ephemeral_port(socket.AF_INET6)


def get_local_ip_by_remote() -> str:
    """Return this host's local IP address as seen by an outbound UDP socket.

    A UDP socket is connected first to 8.8.8.8 (IPv4) and then to
    2001:4860:4860::8888 (IPv6), and the socket's local address is returned.
    Each probe socket is closed via ``with``.

    Raises:
        ValueError: if neither probe succeeds. The last ``OSError`` is
            chained as ``__cause__``.
    """
    # Google's public DNS servers, see
    # https://developers.google.com/speed/public-dns/docs/using#addresses
    probes = (
        (socket.AF_INET, ("8.8.8.8", 80)),
        (socket.AF_INET6, ("2001:4860:4860::8888", 80)),
    )
    last_error = None
    for family, target in probes:
        try:
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect(target)  # Doesn't need to be reachable
                return s.getsockname()[0]
        except OSError as err:
            last_error = err
    raise ValueError("Can not get local ip") from last_error
2030
2031
2032
2033
2034
2035


def is_page_size_one(server_args):
    """Return True when ``server_args.page_size`` equals 1."""
    page_size = server_args.page_size
    return page_size == 1


2036
2037
# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
def is_no_spec_infer_or_topk_one(server_args):
    """Return True if EAGLE speculative topk is unset, or is 1 with a page size of 1.

    A ``speculative_eagle_topk`` of None is taken to mean no speculative
    inference. With topk == 1, ``is_page_size_one(server_args)`` must also
    hold; any other topk yields False.
    """
    topk = server_args.speculative_eagle_topk
    if topk is None:
        return True
    return topk == 1 and is_page_size_one(server_args)


def is_fa3_default_architecture(hf_config):
    """Return True if ``hf_config.architectures[0]`` is one of the listed model classes.

    Presumably these are the architectures that default to the FA3 attention
    backend -- confirm with the caller. Only the first architecture is checked.
    False is returned when ``architectures`` is missing, is not a ``list``
    (tuples included), or is empty.
    """
    archs = getattr(hf_config, "architectures", None)
    if isinstance(archs, list) and archs:
        return archs[0] in {
            "Qwen2ForCausalLM",
            "Llama4ForConditionalGeneration",
            "LlamaForCausalLM",
            "MistralForCausalLM",
            "MixtralForCausalLM",
            "Gemma2ForCausalLM",
            "Gemma3ForConditionalGeneration",
            "Qwen3ForCausalLM",
            "Qwen3MoeForCausalLM",
        }
    return False
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074


# Can be more general if it is used in multiple places (keep it simple and thus not general now)
class BumpAllocator:
    def __init__(self, buffer_size: int, dtype, device):
        self._buffer = torch.zeros((buffer_size,), dtype=dtype, device=device)
        self._pointer = 0

    def allocate(self, size: int):
        assert self._pointer + size <= len(self._buffer)
        output = self._buffer[self._pointer : self._pointer + size]
        self._pointer += size
        return output