utils.py 23.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

16
17
"""Common utilities."""

Lianmin Zheng's avatar
Lianmin Zheng committed
18
import base64
19
import ipaddress
20
import json
21
import logging
Lianmin Zheng's avatar
Lianmin Zheng committed
22
import os
23
import pickle
Lianmin Zheng's avatar
Lianmin Zheng committed
24
import random
25
import resource
Lianmin Zheng's avatar
Lianmin Zheng committed
26
27
import socket
import time
28
import warnings
29
from importlib.metadata import PackageNotFoundError, version
Lianmin Zheng's avatar
Lianmin Zheng committed
30
from io import BytesIO
31
from typing import Any, Dict, List, Optional, Union
Lianmin Zheng's avatar
Lianmin Zheng committed
32
33

import numpy as np
34
import psutil
Lianmin Zheng's avatar
Lianmin Zheng committed
35
36
import requests
import torch
37
import torch.distributed as dist
38
import zmq
39
from fastapi.responses import ORJSONResponse
40
from packaging import version as pkg_version
41
from torch import nn
42
from torch.profiler import ProfilerActivity, profile, record_function
43
44
45
46
47
48
from triton.runtime.cache import (
    FileCacheManager,
    default_cache_dir,
    default_dump_dir,
    default_override_dir,
)
49

50
51
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
52

Liangsheng Yin's avatar
Liangsheng Yin committed
53
54
# Global switch for the timing helpers in this module; set to True by
# enable_show_time_cost().
show_time_cost = False
# Maps a timer name to its TimeInfo accumulator; entries are created by mark_start().
time_infos = {}
Lianmin Zheng's avatar
Lianmin Zheng committed
55
56


57
def is_hip() -> bool:
    """Tell whether this PyTorch build targets HIP on the AMD ROCm platform."""
    hip_version = torch.version.hip
    return hip_version is not None


62
63
64
65
66
67
68
69
def is_flashinfer_available():
    """
    Report whether flashinfer can be used.

    As of Oct. 6, 2024, flashinfer is only available on NVIDIA GPUs, so CUDA
    must be available and the build must not be a HIP (ROCm) one.
    """
    if not torch.cuda.is_available():
        return False
    return not is_hip()


70
71
72
73
74
75
76
77
def is_ipv6(address):
    """Return True when `address` parses as an IPv6 address, False otherwise."""
    try:
        ipaddress.IPv6Address(address)
    except ipaddress.AddressValueError:
        return False
    else:
        return True


Liangsheng Yin's avatar
Liangsheng Yin committed
78
79
80
81
def enable_show_time_cost():
    """Switch on the module-wide `show_time_cost` flag."""
    global show_time_cost
    show_time_cost = True

Lianmin Zheng's avatar
Lianmin Zheng committed
82

Liangsheng Yin's avatar
Liangsheng Yin committed
83
84
85
86
87
88
class TimeInfo:
    """Accumulated time for one named timer, together with its display settings."""

    def __init__(self, name, interval=0.1, color=0, indent=0):
        # Display / reporting configuration.
        self.name, self.interval = name, interval
        self.color, self.indent = color, indent

        # Running total, and the total recorded the last time check() fired.
        self.acc_time = 0
        self.last_acc_time = 0

    def check(self):
        """Return True once more than `interval` has accumulated since the
        last True result, remembering the current total when it does."""
        due = self.acc_time - self.last_acc_time > self.interval
        if due:
            self.last_acc_time = self.acc_time
        return due

    def pretty_print(self):
        """Print the accumulated time wrapped in an ANSI color escape and
        prefixed with two dashes per indent level."""
        prefix = f"\x1b[{self.color}m" + "-" * self.indent * 2
        print(f"{prefix}{self.name}: {self.acc_time:.3f}s\x1b[0m")
Lianmin Zheng's avatar
Lianmin Zheng committed
103
104


Liangsheng Yin's avatar
Liangsheng Yin committed
105
106
107
108
def mark_start(name, interval=0.1, color=0, indent=0):
    """Start (or resume) the timer called `name`; a no-op unless timing is enabled."""
    global time_infos, show_time_cost
    if not show_time_cost:
        return
    torch.cuda.synchronize()
    info = time_infos.get(name, None)
    if info is None:
        info = TimeInfo(name, interval, color, indent)
        time_infos[name] = info
    # Subtract the start timestamp now; mark_end adds the end timestamp, which
    # leaves the elapsed duration accumulated in acc_time.
    info.acc_time -= time.time()
Lianmin Zheng's avatar
Lianmin Zheng committed
113
114


Liangsheng Yin's avatar
Liangsheng Yin committed
115
116
117
118
def mark_end(name):
    """Stop the timer called `name` and print it when its report interval has elapsed.

    A no-op unless timing is enabled. Raises KeyError if `name` has no entry
    in `time_infos`.
    """
    global time_infos, show_time_cost
    if not show_time_cost:
        return
    torch.cuda.synchronize()
    info = time_infos[name]
    # Completes the "-= start" performed in mark_start.
    info.acc_time += time.time()
    if info.check():
        info.pretty_print()
Lianmin Zheng's avatar
Lianmin Zheng committed
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143


def calculate_time(show=False, min_cost_ms=0.0):
    """
    Build a decorator that times a function between CUDA synchronization points.

    Args:
        show: When True, measure every call and print its cost.
        min_cost_ms: Only print calls that took longer than this many milliseconds.

    Returns:
        A decorator. The decorated function returns whatever the original
        returns and keeps the original's metadata (`__name__`, `__doc__`, ...).
    """
    import functools

    def wrapper(func):
        # functools.wraps copies func's metadata onto the wrapper so that
        # introspection and log output still show the real function.
        @functools.wraps(func)
        def inner_func(*args, **kwargs):
            torch.cuda.synchronize()
            if show:
                # perf_counter is monotonic, unlike time.time, so a wall-clock
                # adjustment cannot corrupt the measured duration.
                start_time = time.perf_counter()
            result = func(*args, **kwargs)
            torch.cuda.synchronize()
            if show:
                cost_time = (time.perf_counter() - start_time) * 1000
                if cost_time > min_cost_ms:
                    print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result

        return inner_func

    return wrapper


Zhang, Liangang's avatar
Zhang, Liangang committed
144
def get_available_gpu_memory(device, gpu_id, distributed=False):
    """
    Get available memory for the `device`:`gpu_id` accelerator, in GiB.

    When distributed is True, the available memory is the minimum available
    memory of all GPUs (computed with a MIN all-reduce).

    Args:
        device: Device type string; "cuda" and "xpu" are supported.
        gpu_id: Index of the device to query.
        distributed: Whether to reduce the value across ranks.

    Returns:
        The available memory in GiB (bytes divided by 2**30).

    Raises:
        ValueError: If `device` is neither "cuda" nor "xpu".
        AssertionError: If `gpu_id` is not smaller than the device count.
    """
    if device == "cuda":
        num_gpus = torch.cuda.device_count()
        assert gpu_id < num_gpus

        if torch.cuda.current_device() != gpu_id:
            print(
                f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
                "which may cause useless memory allocation for torch CUDA context.",
            )

        torch.cuda.empty_cache()
        free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

    elif device == "xpu":
        num_gpus = torch.xpu.device_count()
        assert gpu_id < num_gpus

        if torch.xpu.current_device() != gpu_id:
            print(
                f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
                "which may cause useless memory allocation for torch XPU context.",
            )
        torch.xpu.empty_cache()
        # NOTE(review): memory_allocated() is called without a device argument
        # while total_memory is read for gpu_id — confirm this is intended
        # when the current device differs from gpu_id.
        used_memory = torch.xpu.memory_allocated()
        total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

    else:
        # Without this branch an unknown device fell through and crashed
        # below with an UnboundLocalError on free_gpu_memory.
        raise ValueError(
            f"Unsupported device type: {device!r}; expected 'cuda' or 'xpu'."
        )

    if distributed:
        tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
            torch.device(device, gpu_id)
        )
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
        free_gpu_memory = tensor.item()

    return free_gpu_memory / (1 << 30)


Lianmin Zheng's avatar
Lianmin Zheng committed
186
def set_random_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy and PyTorch (plus all CUDA devices when CUDA is available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


195
def is_port_available(port):
    """Return whether `port` can be bound and listened on with an IPv4 TCP socket."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        try:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            probe.bind(("", port))
            probe.listen(1)
        except socket.error:
            return False
        return True


207
208
209
210
211
212
def is_multimodal_model(model_architectures):
    """Return True if any known multimodal architecture name is in `model_architectures`."""
    multimodal_archs = (
        "LlavaLlamaForCausalLM",
        "LlavaQwenForCausalLM",
        "LlavaMistralForCausalLM",
        "LlavaVidForCausalLM",
        "MllamaForConditionalGeneration",
        "Qwen2VLForConditionalGeneration",
    )
    return any(arch in model_architectures for arch in multimodal_archs)
Yuanhan Zhang's avatar
Yuanhan Zhang committed
219
220


221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
def is_attention_free_model(model_architectures):
    """Always return False; `model_architectures` is currently ignored."""
    return False


def model_has_inner_state(model_architectures):
    """Always return False; `model_architectures` is currently ignored."""
    return False


def is_embedding_model(model_architectures):
    """Return True if any known embedding/classification architecture name is in `model_architectures`."""
    embedding_archs = (
        "LlamaEmbeddingModel",
        "MistralModel",
        "LlamaForSequenceClassification",
        "LlamaForSequenceClassificationWithNormal_Weights",
    )
    return any(arch in model_architectures for arch in embedding_archs)


241
242
243
244
245
def is_generation_model(model_architectures, is_embedding: bool = False):
    """Return whether the model should be treated as a generative model."""
    # Two signals decide whether a model is generative:
    # 1. the model architecture, and
    # 2. the `is_embedding` server argument.
    non_generative_archs = (
        "LlamaEmbeddingModel",
        "MistralModel",
        "LlamaForSequenceClassification",
        "LlamaForSequenceClassificationWithNormal_Weights",
    )
    if any(arch in model_architectures for arch in non_generative_archs):
        return False
    return not is_embedding
255
256


Yuanhan Zhang's avatar
Yuanhan Zhang committed
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
def _find_signature_offsets(data, signature):
    """Return the offset of every non-overlapping occurrence of `signature` in `data`."""
    offsets = []
    pos = data.find(signature)
    while pos != -1:
        offsets.append(pos)
        # Resume right after this signature, as the original scan did.
        pos = data.find(signature, pos + len(signature))
    return offsets


def decode_video_base64(video_base64):
    """
    Decode a base64 string holding back-to-back encoded images into frames.

    The decoded bytes are split at every image start signature; each slice is
    opened with PIL and converted to a NumPy array.

    Args:
        video_base64: Base64-encoded concatenation of image files.

    Returns:
        A `(frames, size)` tuple. `frames` is the per-image arrays stacked
        along axis 0 and `size` is the PIL `size` of the last image. When no
        image signature is found, `(np.array([]), (0, 0))` is returned.
    """
    # Decode the base64 string
    video_bytes = base64.b64decode(video_base64)

    frame_format = "PNG"  # str(os.getenv('FRAME_FORMAT', "JPEG"))

    assert frame_format in [
        "PNG",
        "JPEG",
    ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"

    if frame_format == "PNG":
        # 8-byte PNG file signature.
        signature = b"\x89PNG\r\n\x1a\n"
    else:
        # JPEG start-of-image marker (0xFFD8).
        signature = b"\xff\xd8"

    # bytes.find scans in C instead of testing one byte per Python iteration.
    img_starts = _find_signature_offsets(video_bytes, signature)

    # Nothing to decode: return an empty array and a zero size so callers
    # never hit np.stack on an empty list.
    if not img_starts:
        return np.array([]), (0, 0)

    # Imported lazily so that PIL is only required when there are frames.
    from PIL import Image

    # Images are assumed to be back-to-back: each one ends where the next
    # begins, and the last one runs to the end of the byte string.
    boundaries = img_starts + [len(video_bytes)]
    frames = []
    for start_idx, end_idx in zip(boundaries, boundaries[1:]):
        img = Image.open(BytesIO(video_bytes[start_idx:end_idx]))
        frames.append(np.array(img))

    return np.stack(frames, axis=0), img.size
Lianmin Zheng's avatar
Lianmin Zheng committed
334
335


336
def load_image(image_file: Union[str, bytes]):
    """
    Load an image (or the frames of a video) from one of several input forms.

    Accepted inputs, checked in this order:
      * bytes: raw encoded image data.
      * "http://" / "https://" URL: fetched with `requests`, using the
        REQUEST_TIMEOUT environment variable as the timeout (default 3).
      * "data:" URL: the base64 payload after the first comma is decoded.
      * "video:" prefix: the remainder is passed to decode_video_base64.
      * str ending in png/jpg/jpeg/webp/gif (case-insensitive): opened as a path.
      * any other str: decoded as a base64-encoded image.

    Returns:
        `(image, image_size)`. `image_size` is None except for "video:" input,
        where both values come from decode_video_base64.

    Raises:
        ValueError: If `image_file` is neither str nor bytes.
    """
    # Validate first: the original's ValueError branch was unreachable
    # (non-str input failed on .startswith) and its message formatted `image`,
    # which is always None at that point, instead of the offending input.
    if not isinstance(image_file, (str, bytes)):
        raise ValueError(f"Invalid image: {image_file}")

    from PIL import Image

    image = image_size = None

    if isinstance(image_file, bytes):
        image = Image.open(BytesIO(image_file))
    elif image_file.startswith(("http://", "https://")):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, timeout=timeout)
        image = Image.open(BytesIO(response.content))
    # The prefixed forms are checked before the file-extension test: a base64
    # payload may happen to end in "png"/"jpg"/"gif" and must not be mistaken
    # for a file path.
    elif image_file.startswith("data:"):
        encoded = image_file.split(",", 1)[1]
        image = Image.open(BytesIO(base64.b64decode(encoded)))
    elif image_file.startswith("video:"):
        # Strip only the leading prefix rather than every "video:" occurrence.
        image, image_size = decode_video_base64(image_file[len("video:") :])
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    else:
        image = Image.open(BytesIO(base64.b64decode(image_file)))

    return image, image_size
361
362


363
364
365
366
367
def suppress_other_loggers():
    """Raise the log level of several vllm loggers and ignore one NumPy-array UserWarning."""
    from vllm.logger import logger as vllm_default_logger

    vllm_default_logger.setLevel(logging.WARN)

    # Logger name -> minimum level it is allowed to emit.
    quiet_levels = {
        "vllm.config": logging.ERROR,
        "vllm.distributed.device_communicators.pynccl": logging.WARN,
        "vllm.distributed.device_communicators.shm_broadcast": logging.WARN,
        "vllm.selector": logging.WARN,
        "vllm.utils": logging.ERROR,
    }
    for logger_name, level in quiet_levels.items():
        logging.getLogger(logger_name).setLevel(level)

    warnings.filterwarnings(
        "ignore", category=UserWarning, message="The given NumPy array is not writable"
    )

381

382
def assert_pkg_version(pkg: str, min_version: str, message: str):
    """
    Ensure `pkg` is installed with at least `min_version`.

    Raises:
        Exception: If the package is missing or older than `min_version`;
            `message` is appended to the error text.
    """
    try:
        installed_version = version(pkg)
    except PackageNotFoundError:
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed. "
            + message
        )

    too_old = pkg_version.parse(installed_version) < pkg_version.parse(min_version)
    if too_old:
        raise Exception(
            f"{pkg} is installed with version {installed_version}, which "
            f"is less than the minimum required version {min_version}. " + message
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
395
396


397
398
399
400
def kill_parent_process():
    """Kill the parent process and the parent's descendants, skipping the calling process itself."""
    me = psutil.Process()
    kill_child_process(me.parent().pid, skip_pid=me.pid)
402
403


404
405
def kill_child_process(pid, including_parent=True, skip_pid=None):
    """
    Kill every descendant of process `pid` and, optionally, the process itself.

    Args:
        pid: Process whose descendants are killed.
        including_parent: Also kill process `pid` when True.
        skip_pid: A descendant pid to leave alive.
    """

    def kill_quietly(proc):
        # A process that has already exited needs no further action.
        try:
            proc.kill()
        except psutil.NoSuchProcess:
            pass

    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return

    for child in parent.children(recursive=True):
        if child.pid != skip_pid:
            kill_quietly(child)

    if including_parent:
        kill_quietly(parent)


427
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
    """
    Replace vllm's slow p2p access check with a stub that always returns True.

    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
    `gpu_id` is accepted but not used.
    """
    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt

    def always_allowed(*arg, **kwargs):
        return True

    tgt.gpu_p2p_access_check = always_allowed
436
437


438
439
440
441
442
443
def monkey_patch_vllm_dummy_weight_loader():
    """
    Patch vllm's DummyModelLoader.load_model so that it also calls
    process_weights_after_loading on every module that has a quant method.
    """

    from vllm.model_executor.model_loader.loader import (
        CacheConfig,
        DeviceConfig,
        DummyModelLoader,
        LoRAConfig,
        ModelConfig,
        ParallelConfig,
        SchedulerConfig,
        _initialize_model,
        initialize_dummy_weights,
        nn,
        set_default_torch_dtype,
    )

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
    ) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            # Only model construction runs under the target device context.
            with torch.device(device_config.device):
                model = _initialize_model(
                    model_config,
                    self.load_config,
                    lora_config,
                    cache_config,
                )

            for _name, submodule in model.named_modules():
                method = getattr(submodule, "quant_method", None)
                if method is None:
                    continue
                method.process_weights_after_loading(submodule)

            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        return model.eval()

    DummyModelLoader.load_model = load_model


489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
# Original GroupCoordinator.all_gather, saved the first time the patch
# function runs so that it can be restored later.
vllm_all_gather_backup = None


def monkey_patch_vllm_all_gather(reverse: bool = False):
    """Monkey patch all-gather to remove in-place operations.

    Args:
        reverse: When True, restore the implementation saved on the first
            call instead of installing the patched one.
    """
    from torch.distributed import _functional_collectives as funcol
    from vllm.distributed.parallel_state import GroupCoordinator

    global vllm_all_gather_backup
    if vllm_all_gather_backup is None:
        vllm_all_gather_backup = GroupCoordinator.all_gather

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert (
            -input_.dim() <= dim < input_.dim()
        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()

        # The functional collective returns a new tensor, so no output buffer
        # is pre-allocated (the original allocated one with torch.empty and
        # immediately discarded it).
        output_tensor = funcol.all_gather_tensor(
            input_, gather_dim=0, group=self.device_group
        ).view((world_size,) + input_size)

        # Move the gathered axis to `dim` and merge it into that dimension.
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
        )
        return output_tensor

    if reverse:
        setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
    else:
        setattr(GroupCoordinator, "all_gather", all_gather)


535
536
537
538
539
540
def maybe_set_triton_cache_manager() -> None:
    """Point Triton at the custom cache manager unless
    TRITON_CACHE_MANAGER is already set in the environment."""
    if os.environ.get("TRITON_CACHE_MANAGER", None) is not None:
        return
    manager = "sglang.srt.utils:CustomCacheManager"
    logger.debug("Setting Triton cache manager to: %s", manager)
    os.environ["TRITON_CACHE_MANAGER"] = manager


class CustomCacheManager(FileCacheManager):
    """File cache manager whose default cache directory is suffixed with the process id."""

    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
    def __init__(self, key, override=False, dump=False):
        self.key = key
        self.lock_path = None

        if dump:
            self.cache_dir = os.path.join(default_dump_dir(), self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
            return

        if override:
            # No lock file and no directory creation in override mode.
            self.cache_dir = os.path.join(default_override_dir(), self.key)
            return

        # Default mode: TRITON_CACHE_DIR wins over Triton's default location.
        self.cache_dir = (
            os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
        )
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")

        # The pid suffix gives every process its own cache directory.
        per_process_root = f"{self.cache_dir}_{os.getpid()}"
        self.cache_dir = os.path.join(per_process_root, self.key)
        self.lock_path = os.path.join(self.cache_dir, "lock")
        os.makedirs(self.cache_dir, exist_ok=True)


573
574
575
576
577
578
579
580
def set_ulimit(target_soft_limit=65535):
    """Raise the soft RLIMIT_NOFILE limit to `target_soft_limit` if it is
    currently lower; a ValueError from setrlimit is logged as a warning."""
    resource_type = resource.RLIMIT_NOFILE
    soft, hard = resource.getrlimit(resource_type)
    if soft >= target_soft_limit:
        return

    try:
        resource.setrlimit(resource_type, (target_soft_limit, hard))
    except ValueError as e:
        logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")
582
583


584
def add_api_key_middleware(app, api_key: str):
    """
    Register an HTTP middleware on `app` that enforces bearer-token auth.

    Requests pass through untouched when the method is OPTIONS or the path
    starts with "/health". Every other request must carry the header
    `Authorization: Bearer <api_key>`; otherwise a 401 JSON response is
    returned.

    Args:
        app: Application object exposing a `middleware("http")` decorator.
        api_key: The expected API key.
    """
    import hmac

    expected = ("Bearer " + api_key).encode("utf-8")

    @app.middleware("http")
    async def authentication(request, call_next):
        if request.method == "OPTIONS":
            return await call_next(request)
        if request.url.path.startswith("/health"):
            return await call_next(request)
        provided = request.headers.get("Authorization") or ""
        # Constant-time comparison: a plain != on the untrusted header stops
        # at the first differing character, which leaks timing information.
        # Both sides are encoded so that non-ASCII values are accepted.
        if not hmac.compare_digest(provided.encode("utf-8"), expected):
            return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
        return await call_next(request)
594
595


596
def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
    """
    Resolve the model and tokenizer paths.

    When SGLANG_USE_MODELSCOPE is set in the environment and `model_path`
    does not exist locally, both paths are replaced by the results of
    modelscope's snapshot_download (weights are skipped for the tokenizer).
    Otherwise the inputs are returned unchanged.
    """
    use_modelscope = "SGLANG_USE_MODELSCOPE" in os.environ
    if not use_modelscope or os.path.exists(model_path):
        return model_path, tokenizer_path

    from modelscope import snapshot_download

    downloaded_model = snapshot_download(model_path)
    downloaded_tokenizer = snapshot_download(
        tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
    )
    return downloaded_model, downloaded_tokenizer
606
607
608
609


def configure_logger(server_args, prefix: str = ""):
    """Reconfigure the root logger using `server_args.log_level`; `prefix`
    is inserted after the timestamp of every record."""
    log_format = f"[%(asctime)s{prefix}] %(message)s"
    # Millisecond variant: f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
    level = getattr(logging, server_args.log_level.upper())
    # force=True replaces any handlers already installed on the root logger.
    logging.basicConfig(
        level=level,
        format=log_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
617
618
619
620
621
622
623
624
625
626
627


# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
def replace_submodule(
    model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
    """Swap the submodule at dotted path `module_name` for `new_module` and return `new_module`."""
    # Everything before the last dot addresses the parent; an empty parent
    # path resolves to `model` itself.
    parent_path, _, target_name = module_name.rpartition(".")
    setattr(model.get_submodule(parent_path), target_name, new_module)
    return new_module
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647


def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: Optional[Dict[str, Any]],
):
    """Attach attributes to a weight tensor.

    Existing attributes are never overwritten: an assertion fails if the
    tensor already has an attribute with one of the given names.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor,
            or None to do nothing.
    """
    if weight_attrs is None:
        return
    for key in weight_attrs:
        assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
        setattr(weight, key, weight_attrs[key])
648
649
650


def broadcast_pyobj(
    data: List[Any],
    rank: int,
    dist_group: Optional[torch.distributed.ProcessGroup] = None,
):
    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend.

    Rank 0 pickles `data` and broadcasts first its byte length, then the
    bytes. Other ranks receive the length, then the bytes, and unpickle them.
    A length of 0 means an empty list, and no payload is sent.

    Returns:
        `data` itself on rank 0; the unpickled object (or `[]`) elsewhere.
    """
    if rank != 0:
        # Receiver: learn the payload size first.
        size_tensor = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size_tensor, src=0, group=dist_group)
        payload_size = size_tensor.item()
        if payload_size == 0:
            return []

        payload = torch.empty(payload_size, dtype=torch.uint8)
        dist.broadcast(payload, src=0, group=dist_group)
        return pickle.loads(bytes(payload.cpu().numpy()))

    # Sender (rank 0).
    if len(data) == 0:
        empty_size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(empty_size, src=0, group=dist_group)
        return data

    serialized = pickle.dumps(data)
    payload = torch.ByteTensor(np.frombuffer(serialized, dtype=np.uint8))
    size_tensor = torch.tensor([len(serialized)], dtype=torch.long)
    dist.broadcast(size_tensor, src=0, group=dist_group)
    dist.broadcast(payload, src=0, group=dist_group)
    return data
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716


# Number of pytorch_profile() calls completed so far; used in output file names.
step_counter = 0


def pytorch_profile(name, func, *args, data_size=-1):
    """
    Run `func(*args)` under the PyTorch profiler and export a chrome trace.

    Two files are written under ./trace: `size_<step>.json` holding
    `data_size`, and `<name>_<step>.json` holding the trace.

    Args:
        name (string): the name of recorded function.
        func: the function to be profiled.
        args: the arguments of the profiled function.
        data_size (int): some measurement of the computation complexity.
            Usually, it could be the batch size.

    Returns:
        Whatever `func(*args)` returns.
    """
    global step_counter
    os.makedirs("trace", exist_ok=True)
    profiler = profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        # on_trace_ready=tensorboard_trace_handler('./log_dir'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    )
    with profiler as prof, record_function(name):
        with open(f"trace/size_{step_counter}.json", "w") as size_file:
            json.dump({"size": data_size}, size_file)
        result = func(*args)
    prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
    step_counter += 1
    return result
717
718
719
720
721
722
723


def first_rank_print(*args, **kwargs):
    """Forward to print() only when the current CUDA device index is 0."""
    if torch.cuda.current_device() != 0:
        return
    print(*args, **kwargs)
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739


def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
    """
    Create a ZeroMQ socket on `ipc://<endpoint>`.

    A PUSH socket connects to the endpoint and a PULL socket binds to it.
    Both get a high-water mark of 0 and a 100000000-byte buffer set on their
    send (PUSH) or receive (PULL) side.

    Args:
        context: The zmq context used to create the socket.
        socket_type: `zmq.PUSH` or `zmq.PULL`.
        endpoint: IPC endpoint, without the "ipc://" scheme.

    Returns:
        The configured socket.

    Raises:
        ValueError: If `socket_type` is neither PUSH nor PULL.
    """
    # Named `sock` so the stdlib `socket` module imported by this file is not shadowed.
    sock = context.socket(socket_type)
    if socket_type == zmq.PUSH:
        sock.setsockopt(zmq.SNDHWM, 0)
        sock.setsockopt(zmq.SNDBUF, 100000000)
        sock.connect(f"ipc://{endpoint}")
    elif socket_type == zmq.PULL:
        sock.setsockopt(zmq.RCVHWM, 0)
        sock.setsockopt(zmq.RCVBUF, 100000000)
        sock.bind(f"ipc://{endpoint}")
    else:
        # Release the socket created above before reporting the error;
        # otherwise it is leaked.
        sock.close()
        raise ValueError(f"Unsupported socket type: {socket_type}")

    return sock