utils.py 22.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

16
17
"""Common utilities."""

Lianmin Zheng's avatar
Lianmin Zheng committed
18
import base64
19
import fcntl
20
import logging
Lianmin Zheng's avatar
Lianmin Zheng committed
21
22
import os
import random
23
import resource
Lianmin Zheng's avatar
Lianmin Zheng committed
24
import socket
25
import struct
Lianmin Zheng's avatar
Lianmin Zheng committed
26
import time
27
from importlib.metadata import PackageNotFoundError, version
Lianmin Zheng's avatar
Lianmin Zheng committed
28
from io import BytesIO
Lianmin Zheng's avatar
Lianmin Zheng committed
29
from typing import List, Optional
Lianmin Zheng's avatar
Lianmin Zheng committed
30
31

import numpy as np
32
import psutil
Lianmin Zheng's avatar
Lianmin Zheng committed
33
34
import requests
import torch
35
import torch.distributed as dist
36
from fastapi.responses import JSONResponse
37
from packaging import version as pkg_version
Lianmin Zheng's avatar
Lianmin Zheng committed
38
from starlette.middleware.base import BaseHTTPMiddleware
39
from torch.nn.parameter import Parameter
40
41
42
43
44
45
from triton.runtime.cache import (
    FileCacheManager,
    default_cache_dir,
    default_dump_dir,
    default_override_dir,
)
46

47
48
logger = logging.getLogger(__name__)

# Master switch for the mark_start/mark_end profiling helpers (off by default).
show_time_cost: bool = False
# Per-section TimeInfo records, keyed by the section name passed to mark_start.
time_infos: dict = {}


def enable_show_time_cost():
    """Switch on the mark_start/mark_end timing helpers for this process."""
    global show_time_cost
    show_time_cost = True

Lianmin Zheng's avatar
Lianmin Zheng committed
58

Liangsheng Yin's avatar
Liangsheng Yin committed
59
60
61
62
63
64
class TimeInfo:
    """Accumulated time for one named profiling section, printed periodically."""

    def __init__(self, name, interval=0.1, color=0, indent=0):
        self.name = name
        # Minimum newly-accumulated time (same units as acc_time) between prints.
        self.interval = interval
        # ANSI SGR code emitted before the line by pretty_print.
        self.color = color
        # Nesting level; rendered as two dashes per level.
        self.indent = indent
        self.acc_time = 0
        self.last_acc_time = 0

    def check(self):
        """Return True, and move the checkpoint, once more than `interval` has accrued."""
        due = self.acc_time - self.last_acc_time > self.interval
        if due:
            self.last_acc_time = self.acc_time
        return due

    def pretty_print(self):
        """Print the section name and accumulated seconds in the configured color."""
        dashes = "-" * (self.indent * 2)
        print(
            f"\x1b[{self.color}m" + dashes + f"{self.name}: {self.acc_time:.3f}s\x1b[0m"
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
79
80


Liangsheng Yin's avatar
Liangsheng Yin committed
81
82
83
84
def mark_start(name, interval=0.1, color=0, indent=0):
    """Start (or resume) timing the section `name`; a no-op unless timing is enabled.

    The TimeInfo for `name` is created on first use with the given interval,
    color and indent; later calls reuse the existing record.
    """
    if not show_time_cost:
        return
    # Drain queued CUDA work so the timestamp is taken after it completes.
    torch.cuda.synchronize()
    info = time_infos.get(name)
    if info is None:
        info = TimeInfo(name, interval, color, indent)
        time_infos[name] = info
    # Subtract the start timestamp now; mark_end adds the end timestamp back.
    info.acc_time -= time.time()
Lianmin Zheng's avatar
Lianmin Zheng committed
89
90


Liangsheng Yin's avatar
Liangsheng Yin committed
91
92
93
94
def mark_end(name):
    """Stop timing section `name`; print its total once its interval has elapsed.

    Requires a prior mark_start(name) while timing was enabled (otherwise
    looking up the section raises KeyError).
    """
    if show_time_cost:
        # Drain queued CUDA work so the timestamp is taken after it completes.
        torch.cuda.synchronize()
        info = time_infos[name]
        info.acc_time += time.time()
        if info.check():
            info.pretty_print()
Lianmin Zheng's avatar
Lianmin Zheng committed
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119


def calculate_time(show=False, min_cost_ms=0.0):
    """Decorator factory that times each call of the wrapped function.

    CUDA is synchronized before and after the call (when a CUDA device is
    available), so that queued kernels are included in the measurement.

    Args:
        show: when True, print the elapsed time of each call.
        min_cost_ms: only calls strictly slower than this many milliseconds are printed.

    Returns:
        A decorator. The wrapped function keeps the original's name and docstring.
    """
    import functools

    def wrapper(func):
        # Decide once per decorated function. CPU-only hosts skip the sync
        # instead of raising from torch.cuda.synchronize().
        use_cuda = torch.cuda.is_available()

        @functools.wraps(func)
        def inner_func(*args, **kwargs):
            if use_cuda:
                torch.cuda.synchronize()
            # perf_counter is monotonic and high-resolution, suitable for intervals.
            start_time = time.perf_counter()
            result = func(*args, **kwargs)
            if use_cuda:
                torch.cuda.synchronize()
            if show:
                cost_time = (time.perf_counter() - start_time) * 1000
                if cost_time > min_cost_ms:
                    print(f"Function {func.__name__} took {cost_time} ms to run.")
            return result

        return inner_func

    return wrapper


120
def get_available_gpu_memory(gpu_id, distributed=False):
    """
    Get available memory for cuda:gpu_id device, in GiB.

    When distributed is True, the available memory is the minimum available memory of all GPUs
    (an all-reduce with MIN over the default torch.distributed process group).
    """
    num_gpus = torch.cuda.device_count()
    assert gpu_id < num_gpus

    current_device = torch.cuda.current_device()
    if current_device != gpu_id:
        logger.warning(
            "Current device is not %s, but %s, which may cause useless memory "
            "allocation for torch CUDA context.",
            gpu_id,
            current_device,
        )

    # Return cached-but-unused blocks to the driver before querying free memory.
    torch.cuda.empty_cache()
    free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)

    if distributed:
        tensor = torch.tensor(
            free_gpu_memory, dtype=torch.float32, device=torch.device("cuda", gpu_id)
        )
        torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
        free_gpu_memory = tensor.item()

    # Bytes -> GiB.
    return free_gpu_memory / (1 << 30)


Lianmin Zheng's avatar
Lianmin Zheng committed
147
def set_random_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy, and PyTorch (plus every CUDA device if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


156
def is_port_available(port):
    """Return whether a TCP port can be bound and listened on, on all interfaces.

    Port numbers outside 0-65535 are reported as unavailable instead of raising.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            # SO_REUSEADDR lets bind() succeed for a port still in TIME_WAIT.
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("", port))
            s.listen(1)
            return True
        except (OSError, OverflowError):
            # OSError (alias socket.error): port in use / not permitted.
            # OverflowError: bind() rejects ports outside 0-65535.
            return False


Lianmin Zheng's avatar
Lianmin Zheng committed
168
def allocate_init_ports(
    port: Optional[int] = None,
    additional_ports: Optional[List[int]] = None,
    dp_size: int = 1,
):
    """Allocate ports for all connections.

    Returns ``(main_port, other_ports)``:

    - ``main_port`` is ``port`` if it is free, otherwise a substitute (a
      warning is logged).
    - ``other_ports`` has ``3 + dp_size`` entries (tokenizer, controller,
      detokenizer, and one NCCL port per data-parallel replica).

    Requested ports are used in the order given when free. Missing ports are
    searched upward from the last usable requested port, or from 10000 if
    none is usable.
    """
    requested = [port] + list(additional_ports or [])
    # dict.fromkeys de-duplicates while keeping the caller's order; a plain set()
    # reorders, so ret_ports[0] could differ from `port` even when it is free.
    # None (an unset `port`) is dropped rather than handed to bind().
    ret_ports = [
        p for p in dict.fromkeys(requested) if p is not None and is_port_available(p)
    ]
    cur_port = ret_ports[-1] + 1 if ret_ports else 10000

    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
    num_ports_needed = 4 + dp_size
    while len(ret_ports) < num_ports_needed:
        if cur_port not in ret_ports and is_port_available(cur_port):
            ret_ports.append(cur_port)
        cur_port += 1

    if port is not None and ret_ports[0] != port:
        logger.warning(
            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
        )

    return ret_ports[0], ret_ports[1:num_ports_needed]
195

Lianmin Zheng's avatar
Lianmin Zheng committed
196

Lianmin Zheng's avatar
Lianmin Zheng committed
197
def get_int_token_logit_bias(tokenizer, vocab_size):
    """Get the logit bias for integer-only tokens.

    A token keeps bias 0 if, after stripping whitespace, it decodes to digits or
    to an empty string, or if it is the EOS token. Every other token gets -1e5.

    NOTE: the ``vocab_size`` argument is not used. The array has
    ``tokenizer.vocab_size`` entries (original note: "a bug when model's vocab
    size > tokenizer.vocab_size").
    """
    size = tokenizer.vocab_size
    bias = np.full(size, -1e5, dtype=np.float32)
    for token_id in range(size):
        text = tokenizer.decode([token_id]).strip()
        allowed = text.isdigit() or not text or token_id == tokenizer.eos_token_id
        if allowed:
            bias[token_id] = 0.0
    return bias


def is_multimodal_model(model):
    """Return whether `model` names a vision-language model.

    `model` may be a model name/path string or a ModelConfig (its ``path`` is
    used). Detection is a substring test on the lower-cased name: "llava"
    (which also covers "llava-next") or "yi-vl".

    Raises:
        ValueError: if `model` is neither a str nor a ModelConfig.
    """
    from sglang.srt.model_config import ModelConfig

    if isinstance(model, str):
        name = model
    elif isinstance(model, ModelConfig):
        name = model.path
    else:
        raise ValueError("unrecognized type")

    name = name.lower()
    return any(marker in name for marker in ("llava", "yi-vl"))


def _find_signature_offsets(data, signature):
    """Return start offsets of non-overlapping occurrences of `signature` in `data`."""
    offsets = []
    pos = data.find(signature)
    while pos != -1:
        offsets.append(pos)
        pos = data.find(signature, pos + len(signature))
    return offsets


def decode_video_base64(video_base64, frame_format="PNG"):
    """Decode base64 data holding back-to-back image frames into a NumPy array.

    Frames are split at each image start signature. Each frame runs until the
    next signature, and the last one runs to the end of the data.

    Args:
        video_base64: base64-encoded concatenation of the frame images.
        frame_format: "PNG" (default) or "JPEG"; selects the start signature.

    Returns:
        ``(frames, size)``. ``frames`` is ``np.stack`` of the decoded frames
        along a new axis 0. ``size`` is PIL's ``Image.size`` of the last frame.
        If no frame signature is found, returns ``(np.array([]), (0, 0))``.

    Raises:
        ValueError: if ``frame_format`` is not "PNG" or "JPEG".
    """
    signatures = {
        "PNG": b"\x89PNG\r\n\x1a\n",  # 89 50 4E 47 0D 0A 1A 0A
        "JPEG": b"\xff\xd8",  # JPEG SOI marker
    }
    if frame_format not in signatures:
        raise ValueError("FRAME_FORMAT must be either 'PNG' or 'JPEG'")

    video_bytes = base64.b64decode(video_base64)
    img_starts = _find_signature_offsets(video_bytes, signatures[frame_format])
    if not img_starts:
        # Return an empty array and size tuple if no frames were found.
        return np.array([]), (0, 0)

    # Imported only when there is something to decode.
    from PIL import Image

    # Assuming each image is back-to-back, one image ends where the next starts.
    img_ends = img_starts[1:] + [len(video_bytes)]
    frames = []
    for start_idx, end_idx in zip(img_starts, img_ends):
        img = Image.open(BytesIO(video_bytes[start_idx:end_idx]))
        frames.append(np.array(img))

    return np.stack(frames, axis=0), img.size
Lianmin Zheng's avatar
Lianmin Zheng committed
303
304
305
306
307


def load_image(image_file):
    """Load an image (or video frames) from a string source.

    Supported inputs, checked in this order:

    - ``http://`` / ``https://`` URLs, fetched with ``requests``. The timeout
      in seconds comes from the ``REQUEST_TIMEOUT`` env var (default 3).
    - ``data:`` URIs; the text after the first comma is base64-decoded.
    - ``video:``-prefixed base64, decoded with ``decode_video_base64``.
    - Paths ending in png/jpg/jpeg/webp/gif (case-insensitive), opened from disk.
    - Anything else, treated as a raw base64-encoded image.

    Returns:
        ``(image, image_size)``. ``image_size`` is only set for ``video:``
        input and is None otherwise. For video input, ``image`` is the stacked
        frame array.

    Raises:
        requests.HTTPError: if a URL fetch returns an error status.
    """
    from PIL import Image

    image = image_size = None

    if image_file.startswith(("http://", "https://")):
        timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
        response = requests.get(image_file, timeout=timeout)
        # Surface HTTP errors directly rather than as an unidentifiable image.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    # The prefix checks come before the extension check. Otherwise a data URI
    # or video payload whose base64 text happens to end in "png"/"gif"/...
    # would be opened as a file path.
    elif image_file.startswith("data:"):
        image_file = image_file.split(",")[1]
        image = Image.open(BytesIO(base64.b64decode(image_file)))
    elif image_file.startswith("video:"):
        image_file = image_file.replace("video:", "")
        image, image_size = decode_video_base64(image_file)
    elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
        image = Image.open(image_file)
    else:
        image = Image.open(BytesIO(base64.b64decode(image_file)))

    return image, image_size
326
327


328
329
330
331
332
def suppress_other_loggers():
    """Lower the verbosity of noisy vLLM loggers (WARN, or ERROR for vllm.config)."""
    from vllm.logger import logger as vllm_default_logger

    vllm_default_logger.setLevel(logging.WARN)

    quieter_levels = {
        "vllm.config": logging.ERROR,
        "vllm.distributed.device_communicators.pynccl": logging.WARN,
        "vllm.distributed.device_communicators.shm_broadcast": logging.WARN,
        "vllm.selector": logging.WARN,
        "vllm.utils": logging.WARN,
    }
    for logger_name, level in quieter_levels.items():
        logging.getLogger(logger_name).setLevel(level)
341
342


343
def assert_pkg_version(pkg: str, min_version: str, message: str):
    """Raise unless the distribution `pkg` is installed at version >= `min_version`.

    `message` is appended to the error text (e.g. an install hint).

    Raises:
        Exception: if `pkg` is not installed or is older than `min_version`.
    """
    try:
        installed_version = version(pkg)
    except PackageNotFoundError:
        raise Exception(
            f"{pkg} with minimum required version {min_version} is not installed. "
            + message
        )

    too_old = pkg_version.parse(installed_version) < pkg_version.parse(min_version)
    if too_old:
        raise Exception(
            f"{pkg} is installed with version {installed_version}, which "
            f"is less than the minimum required version {min_version}. " + message
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
356
357


358
359
360
361
def kill_parent_process():
    """Kill the parent process and all children of the parent process.

    The calling process itself is spared. Signal 9 (SIGKILL on POSIX) is used.
    Children that exit between being listed and being signalled are skipped
    rather than aborting the sweep.
    """
    current_process = psutil.Process()
    parent_process = current_process.parent()
    children = parent_process.children(recursive=True)
    for child in children:
        if child.pid == current_process.pid:
            continue
        try:
            os.kill(child.pid, 9)
        except ProcessLookupError:
            # Already gone; keep killing the rest.
            pass
    os.kill(parent_process.pid, 9)


369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
def kill_child_process(pid, including_parent=True):
    """Kill every descendant of process `pid`, then `pid` itself if `including_parent`.

    A `pid` that does not exist, and processes that exit before being killed,
    are silently ignored.
    """
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return

    # Descendants first, parent last (same order as before).
    targets = parent.children(recursive=True)
    if including_parent:
        targets.append(parent)

    for proc in targets:
        try:
            proc.kill()
        except psutil.NoSuchProcess:
            pass


389
def monkey_patch_vllm_p2p_access_check(gpu_id: int):
    """
    Monkey patch the slow p2p access check in vllm.
    NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.

    `gpu_id` is accepted but not used.
    """
    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt

    def _always_allowed(*args, **kwargs):
        return True

    tgt.gpu_p2p_access_check = _always_allowed
398
399


400
401
402
403
404
405
def monkey_patch_vllm_dummy_weight_loader():
    """
    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.

    The replacement ``DummyModelLoader.load_model`` builds the model, runs the
    per-module post-load hooks, then fills the weights with random values.
    """

    from vllm.model_executor.model_loader.loader import (
        CacheConfig,
        DeviceConfig,
        DummyModelLoader,
        LoRAConfig,
        ModelConfig,
        MultiModalConfig,
        ParallelConfig,
        SchedulerConfig,
        _initialize_model,
        initialize_dummy_weights,
        nn,
        set_default_torch_dtype,
    )

    def _run_post_load_hooks(model):
        # Call the weight post-processing hooks this patch exists to add.
        for _, submodule in model.named_modules():
            quant_method = getattr(submodule, "quant_method", None)
            if quant_method is not None:
                quant_method.process_weights_after_loading(submodule)
            # FIXME: Remove this after Mixtral is updated
            # to use quant_method.
            if hasattr(submodule, "process_weights_after_loading"):
                submodule.process_weights_after_loading()

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        multimodal_config: Optional[MultiModalConfig],
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
    ) -> nn.Module:
        with set_default_torch_dtype(model_config.dtype):
            # Only model construction happens under the target device context.
            with torch.device(device_config.device):
                model = _initialize_model(
                    model_config,
                    self.load_config,
                    lora_config,
                    multimodal_config,
                    cache_config,
                )
            _run_post_load_hooks(model)
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)
        return model.eval()

    DummyModelLoader.load_model = load_model


458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
vllm_all_gather_backup = None


def monkey_patch_vllm_all_gather(reverse: bool = False):
    """Monkey patch all-gather to remove in-place operations.

    Replaces ``GroupCoordinator.all_gather`` with a version built on the
    functional collective ``funcol.all_gather_tensor``. The original method is
    saved on the first call, and ``reverse=True`` restores it.
    """
    from torch.distributed import _functional_collectives as funcol
    from vllm.distributed.parallel_state import GroupCoordinator

    global vllm_all_gather_backup
    if vllm_all_gather_backup is None:
        vllm_all_gather_backup = GroupCoordinator.all_gather

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert (
            -input_.dim() <= dim < input_.dim()
        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()

        # Gather along dim 0, then view as (world_size, *input_size). The result
        # comes from the collective itself, so no output buffer is preallocated.
        output_tensor = funcol.all_gather_tensor(
            input_, gather_dim=0, group=self.device_group
        ).view((world_size,) + input_size)

        # Move the rank axis next to `dim` and merge the two, giving the
        # per-rank inputs concatenated along `dim`.
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
        )
        return output_tensor

    if reverse:
        setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
    else:
        setattr(GroupCoordinator, "all_gather", all_gather)


504
505
506
507
508
509
def maybe_set_triton_cache_manager() -> None:
    """Point Triton at sglang's CustomCacheManager via TRITON_CACHE_MANAGER.

    An already-set TRITON_CACHE_MANAGER is left untouched.
    """
    if os.environ.get("TRITON_CACHE_MANAGER", None) is not None:
        return
    manager = "sglang.srt.utils:CustomCacheManager"
    logger.debug("Setting Triton cache manager to: %s", manager)
    os.environ["TRITON_CACHE_MANAGER"] = manager


class CustomCacheManager(FileCacheManager):
    """Triton file cache manager whose default cache directory is per process.

    The default directory (``$TRITON_CACHE_DIR`` or Triton's default) gets a
    ``_<pid>`` suffix, so each process uses its own tree. Dump and override
    modes use Triton's dump/override directories unchanged.
    """

    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
    def __init__(self, key, override=False, dump=False):
        self.key = key
        self.lock_path = None

        if dump:
            root = default_dump_dir()
        elif override:
            root = default_override_dir()
        else:
            # create cache directory if it doesn't exist
            root = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
            if not root:
                raise RuntimeError("Could not create or locate cache dir")
            root = f"{root}_{os.getpid()}"

        self.cache_dir = os.path.join(root, self.key)

        # Dump and default directories get a lock file path and are created on
        # disk; the override directory gets neither.
        if dump or not override:
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)


Lianmin Zheng's avatar
Lianmin Zheng committed
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
API_KEY_HEADER_NAME = "X-API-Key"


class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """Reject requests whose ``X-API-Key`` header does not match the configured key.

    Missing or mismatched keys get HTTP 403 with ``{"detail": "Invalid API Key"}``.
    Matching requests are passed to the next handler unchanged.
    """

    def __init__(self, app, api_key: str):
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request, call_next):
        from hmac import compare_digest

        # extract API key from the request headers
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
        expected = self.api_key
        # Constant-time comparison, so response timing does not reveal how much
        # of a guessed key is correct. Compare UTF-8 bytes because
        # compare_digest rejects non-ASCII str arguments.
        if (
            not api_key_header
            or expected is None
            or not compare_digest(
                api_key_header.encode("utf-8"), expected.encode("utf-8")
            )
        ):
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )
        response = await call_next(request)
        return response
560
561
562
563
564
565
566
567
568
569
570
571
572


def get_ip_address(ifname):
    """
    Get the IPv4 address of a network interface via the SIOCGIFADDR ioctl (Linux).

    :param ifname: Name of the network interface (e.g., 'eth0'); only the first 15 characters are used
    :return: IP address of the network interface, as a dotted-quad string
    :raises OSError: if the ioctl fails (e.g. unknown interface or no IPv4 address)
    """
    # The socket is only needed as a handle for the ioctl; close it afterwards.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        ifreq = fcntl.ioctl(
            s.fileno(),
            0x8915,  # SIOCGIFADDR
            struct.pack("256s", bytes(ifname[:15], "utf-8")),
        )
    # Bytes 20..24 of the returned struct ifreq hold the IPv4 address.
    return socket.inet_ntoa(ifreq[20:24])


def send_addrs_to_rank_0(model_port_args, server_args):
    """Send this node's IP and TP ports to node 0 over a temporary gloo group.

    For non-zero node ranks with dp_size == 1. This also fills the first
    ``tp_size // nnodes`` entries of ``model_port_args.model_tp_ips`` with
    this node's IP.
    """
    assert server_args.node_rank != 0 and server_args.dp_size == 1

    # Interface: SGLANG_SOCKET_IFNAME, else NCCL_SOCKET_IFNAME, else eth0.
    ifname = os.environ.get(
        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
    )
    ip_str = get_ip_address(ifname)

    num_tp_ports = server_args.tp_size // server_args.nnodes
    model_port_args.model_tp_ips[:num_tp_ports] = [ip_str] * num_tp_ports

    # Payload: the four IPv4 octets followed by this node's TP ports.
    ip_octets = [int(octet) for octet in ip_str.split(".")]
    addrs_tensor = torch.tensor(
        ip_octets + model_port_args.model_tp_ports, dtype=torch.int
    )

    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{server_args.nccl_init_addr}",
        rank=server_args.node_rank,
        world_size=server_args.nnodes,
    )
    dist.send(addrs_tensor, dst=0)
    print(
        f"Node {server_args.node_rank} sent: ip_address {ip_octets} and ports {model_port_args.model_tp_ports}"
    )

    dist.barrier()
    dist.destroy_process_group()
607
608
609
610
611


def receive_addrs(model_port_args, server_args):
    """On node 0, collect each other node's IP and TP ports over a temporary gloo group.

    With ``n = tp_size // nnodes``, node k's IP and ports are written into
    ``model_port_args.model_tp_ips`` / ``model_tp_ports`` at ``[k * n, (k + 1) * n)``.
    Node 0's own slice of ``model_tp_ips`` is set to its local IP.
    """
    assert server_args.node_rank == 0 and server_args.dp_size == 1

    # Interface: SGLANG_SOCKET_IFNAME, else NCCL_SOCKET_IFNAME, else eth0.
    ifname = os.environ.get(
        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
    )
    local_ip = get_ip_address(ifname)

    num_tp_ports = server_args.tp_size // server_args.nnodes
    model_port_args.model_tp_ips[:num_tp_ports] = [local_ip] * num_tp_ports

    dist.init_process_group(
        backend="gloo",
        init_method=f"tcp://{server_args.nccl_init_addr}",
        rank=server_args.node_rank,
        world_size=server_args.nnodes,
    )

    for src_rank in range(1, server_args.nnodes):
        # Expected payload: 4 IPv4 octets followed by num_tp_ports ports.
        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
        dist.recv(tensor, src=src_rank)
        values = tensor.tolist()
        ip = ".".join(str(octet) for octet in values[:4])
        ports = values[4:]

        lo = num_tp_ports * src_rank
        hi = num_tp_ports * (src_rank + 1)
        model_port_args.model_tp_ips[lo:hi] = [ip] * num_tp_ports
        model_port_args.model_tp_ports[lo:hi] = ports

        print(f"Node 0 received from rank {src_rank}: {values}")

    dist.barrier()
    dist.destroy_process_group()
643
644
645
646
647
648
649
650
651
652
653


def set_ulimit(target_soft_limit=65535):
    """Raise the soft RLIMIT_NOFILE (open file descriptors) to `target_soft_limit`.

    - The target is clamped to the current hard limit, since an unprivileged
      process cannot exceed it.
    - The soft limit is never lowered; an unlimited soft limit is left as is.
    - Failures are logged as warnings, not raised.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    # Without this guard, RLIM_INFINITY (-1 on Linux) compares below any
    # target and would be lowered.
    if current_soft == resource.RLIM_INFINITY:
        return
    if current_hard != resource.RLIM_INFINITY:
        target_soft_limit = min(target_soft_limit, current_hard)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except (ValueError, OSError) as e:
            logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")
654
655
656
657
658
659
660
661
662
663


def is_llama3_405b_fp8(model_config):
    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads.

    The check matches ``model_config.hf_config`` on:

    - a LlamaForCausalLM architecture;
    - the 405B shape: hidden 16384, intermediate 53248, 126 layers, 16 KV heads;
    - a dict ``quantization_config`` whose ``quant_method`` is "fbgemm_fp8".

    Configs lacking any of these fields (e.g. ``architectures=None`` or no
    ``quant_method``) yield False instead of raising.
    """
    hf_config = model_config.hf_config
    architectures = getattr(hf_config, "architectures", None) or []
    quant_config = getattr(hf_config, "quantization_config", None)
    return (
        bool(architectures)
        and architectures[0] == "LlamaForCausalLM"
        and getattr(hf_config, "hidden_size", None) == 16384
        and getattr(hf_config, "intermediate_size", None) == 53248
        and getattr(hf_config, "num_hidden_layers", None) == 126
        and getattr(hf_config, "num_key_value_heads", None) == 16
        and isinstance(quant_config, dict)
        and quant_config.get("quant_method") == "fbgemm_fp8"
    )


def monkey_patch_vllm_qvk_linear_loader():
    """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints.

    Wraps ``QKVParallelLinear.weight_loader``. Suppose a "k" or "v" shard
    arrives with ``head_size * total_num_kv_heads * 2`` rows. Then only the
    first ``head_size`` rows of every ``2 * head_size``-row group are kept
    before delegating to the original loader. (Presumably the checkpoint
    stores each KV head twice; confirm against the checkpoint layout.)
    """
    from vllm.model_executor.layers.linear import QKVParallelLinear

    origin_weight_loader = QKVParallelLinear.weight_loader

    def get_original_weight(loaded_weight, head_dim):
        n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
        dim = loaded_weight.shape[1]
        # View rows as (head, pair_slot, head_dim, dim) and keep pair_slot 0,
        # i.e. rows [2*i*head_dim, (2*i+1)*head_dim) for every head i. Unlike
        # an in-place compaction, this never writes into `loaded_weight`.
        original_kv_weight = loaded_weight.reshape(n_kv_head, 2, head_dim, dim)[
            :, 0
        ].reshape(n_kv_head * head_dim, dim)
        assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
        return original_kv_weight

    def weight_loader_srt(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        if (
            loaded_shard_id in ["k", "v"]
            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
        ):
            loaded_weight = get_original_weight(loaded_weight, self.head_size)

        origin_weight_loader(self, param, loaded_weight, loaded_shard_id)

    setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)