# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import os
import pickle
import socket
import sys
import time
import uuid
from collections import deque
from collections.abc import Sequence
from datetime import timedelta
from typing import Any

import torch
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import (
    Backend,
    PrefixStore,
    _get_default_timeout,
    _unregister_process_group,
)
from torch.distributed.rendezvous import rendezvous

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import get_tcp_uri, is_torch_equal_or_newer

logger = init_logger(__name__)

# Busy-wait loops in this module poll tighter with os.sched_yield() (measured
# at roughly 3e-7 seconds per call). On CPython releases older than 3.10.8 and
# 3.11.0 it does not release the GIL, so those interpreters fall back to
# time.sleep(0) instead (see `sched_yield`).
USE_SCHED_YIELD = sys.version_info[:3] >= (3, 11, 1) or (
    (3, 10, 8) <= sys.version_info[:3] < (3, 11)
)


def sched_yield():
    """Give up the CPU briefly inside a polling loop.

    Calls ``os.sched_yield()`` when ``USE_SCHED_YIELD`` is true, otherwise
    ``time.sleep(0)``.
    """
    if not USE_SCHED_YIELD:
        time.sleep(0)
        return
    os.sched_yield()


def ensure_divisibility(numerator, denominator):
    """Assert that ``numerator`` is an exact multiple of ``denominator``.

    The check is a plain ``assert`` and is therefore skipped when Python runs
    with ``-O``.

    Raises:
        AssertionError: with message ``"<numerator> is not divisible by
            <denominator>"`` when ``numerator % denominator != 0``.
    """
    assert numerator % denominator == 0, (
        f"{numerator} is not divisible by {denominator}"
    )


def divide(numerator, denominator):
    """Return ``numerator // denominator``, requiring an exact division.

    Divisibility is verified first by ``ensure_divisibility``, so a
    non-multiple raises its ``AssertionError`` instead of silently flooring.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor into ``num_partitions`` equal chunks along its last dim.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into. The
            size of the last dimension must be divisible by it (enforced by
            ``divide``).
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.

    Returns:
        A sequence of ``num_partitions`` tensors (a tuple when
        ``contiguous_split_chunks`` is True).
    """
    chunk_size = divide(tensor.shape[-1], num_partitions)
    chunks = torch.split(tensor, chunk_size, dim=-1)
    # torch.split hands back views, which are generally not contiguous.
    if not contiguous_split_chunks:
        return chunks
    return tuple(piece.contiguous() for piece in chunks)


def get_pp_indices(
    num_hidden_layers: int, pp_rank: int, pp_size: int
) -> tuple[int, int]:
    """Return the ``(start_layer, end_layer)`` range owned by stage ``pp_rank``.

    The split can be given explicitly through ``VLLM_PP_LAYER_PARTITION``: a
    comma-separated list with exactly ``pp_size`` layer counts that sum to
    ``num_hidden_layers``.

    Otherwise layers are distributed as evenly as possible. The remainder
    ``num_hidden_layers % pp_size`` is handed out one layer at a time,
    starting at the second-to-last partition and moving towards the first:

    * The last partition never gets an extra layer, because it often also
      holds an additional norm layer and we are trying to balance compute.
    * When ``pp_size > 2`` and the remainder is at most ``pp_size - 2``, only
      the middle partitions get extra layers; the first and last partitions
      hold the input and output embeddings, and we are trying to reduce the
      maximum memory consumption across partitions.

    Raises:
        ValueError: if ``VLLM_PP_LAYER_PARTITION`` cannot be parsed, has the
            wrong number of entries, or does not sum to ``num_hidden_layers``.
    """
    partition_list_str = envs.VLLM_PP_LAYER_PARTITION
    if partition_list_str is None:
        layers_per_partition, remaining_layers = divmod(num_hidden_layers, pp_size)
        partitions = [layers_per_partition] * pp_size
        if remaining_layers:
            # Indices -2, -3, ... each receive one of the leftover layers.
            for offset in range(2, remaining_layers + 2):
                partitions[-offset] += 1
            logger.info(
                "Hidden layers were unevenly partitioned: [%s]. "
                "This can be manually overridden using the "
                "VLLM_PP_LAYER_PARTITION environment variable",
                ",".join(str(p) for p in partitions),
            )
    else:
        try:
            partitions = [int(layer) for layer in partition_list_str.split(",")]
        except ValueError as err:
            raise ValueError(
                "Invalid partition string: {}".format(partition_list_str)
            ) from err
        if len(partitions) != pp_size:
            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
        if sum(partitions) != num_hidden_layers:
            raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")

    start_layer = sum(partitions[:pp_rank])
    return start_layer, start_layer + partitions[pp_rank]


@dataclasses.dataclass
class StatelessProcessGroup:
    """A dataclass to hold a metadata store, and the rank, world_size of the
    group. Only use it to communicate metadata between processes.
    For data-plane communication, create NCCL-related objects.
    """
147

148
149
150
    rank: int
    world_size: int
    store: torch._C._distributed_c10d.Store
151
152

    # stores a reference to the socket so that the file descriptor stays alive
153
    socket: socket.socket | None
154

155
156
157
    data_expiration_seconds: int = 3600  # 1 hour

    # dst rank -> counter
158
    send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
159
    # src rank -> counter
160
    recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
161
    broadcast_send_counter: int = 0
162
    broadcast_recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
163
164

    # A deque to store the data entries, with key and timestamp.
165
    entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque)
166
167
168
169
170

    def __post_init__(self):
        assert self.rank < self.world_size
        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
171
        self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
172
173
174
175

    def send_obj(self, obj: Any, dst: int):
        """Send an object to a destination rank."""
        self.expire_data()
176
        key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
        self.store.set(key, pickle.dumps(obj))
        self.send_dst_counter[dst] += 1
        self.entries.append((key, time.time()))

    def expire_data(self):
        """Expire data that is older than `data_expiration_seconds` seconds."""
        while self.entries:
            # check the oldest entry
            key, timestamp = self.entries[0]
            if time.time() - timestamp > self.data_expiration_seconds:
                self.store.delete_key(key)
                self.entries.popleft()
            else:
                break

    def recv_obj(self, src: int) -> Any:
        """Receive an object from a source rank."""
        obj = pickle.loads(
195
196
            self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
        )
197
198
199
        self.recv_src_counter[src] += 1
        return obj

200
    def broadcast_obj(self, obj: Any | None, src: int) -> Any:
201
202
203
204
205
206
        """Broadcast an object from a source rank to all other ranks.
        It does not clean up after all ranks have received the object.
        Use it for limited times, e.g., for initialization.
        """
        if self.rank == src:
            self.expire_data()
207
            key = f"broadcast_from/{src}/{self.broadcast_send_counter}"
208
209
210
211
212
            self.store.set(key, pickle.dumps(obj))
            self.broadcast_send_counter += 1
            self.entries.append((key, time.time()))
            return obj
        else:
213
            key = f"broadcast_from/{src}/{self.broadcast_recv_src_counter[src]}"
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
            recv_obj = pickle.loads(self.store.get(key))
            self.broadcast_recv_src_counter[src] += 1
            return recv_obj

    def all_gather_obj(self, obj: Any) -> list[Any]:
        """All gather an object from all ranks."""
        gathered_objs = []
        for i in range(self.world_size):
            if i == self.rank:
                gathered_objs.append(obj)
                self.broadcast_obj(obj, src=self.rank)
            else:
                recv_obj = self.broadcast_obj(None, src=i)
                gathered_objs.append(recv_obj)
        return gathered_objs

230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
    def barrier(self, timeout: float = 30.0):
        """A robust barrier to synchronize all ranks.


        Uses a multi-phase approach to ensure all processes reach the barrier
        before proceeding:

        1. Each process signals it has reached the barrier

        2. Each process signals that it has confirmed the arrival of all other
        ranks.

        3. Rank 0 waits for all other ranks to signal their departure to ensure
        that all ranks have departed the barrier first.

        Args:
            timeout: Maximum time in seconds to wait for each phase (in seconds)


        Raises:
            RuntimeError: If coordination fails or times out
        """
        # Generate a barrier ID that is globally unique
        try:
            if self.rank == 0:
                barrier_id = f"barrier_{uuid.uuid4()}"
                self.broadcast_obj(barrier_id, src=0)
            else:
                barrier_id = self.broadcast_obj(None, src=0)
        except Exception as e:
            raise RuntimeError("Failed to broadcast barrier_id") from e

        # Phase 1: Signal arrival at barrier
        # Wait for all processes to arrive
        # We need all ranks to confirm the arrival of all other ranks.
        # This is the key synchronization point.
        arrival_key = f"arrival_{barrier_id}_{self.rank}"
        try:
            self.store.set(arrival_key, b"1")
        except Exception as e:
            raise RuntimeError("Failed to signal barrier arrival") from e

        start_time = time.time()
        processes_arrived: set[int] = set()

        while len(processes_arrived) < self.world_size:
            # Check for timeout
            cur_time = time.time()
            if cur_time - start_time > timeout:
279
                raise RuntimeError("Barrier timed out after %f seconds", timeout)
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325

            # Check for each process
            for i in range(self.world_size):
                if i in processes_arrived:
                    continue

                key = f"arrival_{barrier_id}_{i}"
                try:
                    # Try to get the key - if it exists, we'll get a value
                    # If it doesn't exist, it will throw an exception
                    self.store.get(key)
                    processes_arrived.add(i)
                except KeyError:
                    # Key doesn't exist yet
                    pass
                except Exception as check_e:
                    logger.debug("Error checking key existence: %s", check_e)
                    sched_yield()

            # Short sleep to avoid tight polling
            if len(processes_arrived) < self.world_size:
                sched_yield()

        # Phase 2: Signal departure from barrier
        # We only care to block at this stage in rank 0, which runs the
        # server side of the TCPStore. We want to make sure that all
        # clients have departed the barrier before rank 0 in case the
        # next thing after the barrier is a shutdown, including tearing
        # down the TCPStore. Other ranks can exit the barrier immediately
        # after signaling their departure.
        departure_key = f"departure_{barrier_id}_{self.rank}"
        try:
            self.store.set(departure_key, b"1")
        except Exception as e:
            raise RuntimeError("Failed to signal barrier departure") from e

        if self.rank != 0:
            return

        # Make rank 0 wait for all processes to signal departure
        start_time = time.time()
        processes_departed: set[int] = set()

        while len(processes_departed) < self.world_size:
            # Check for timeout
            if time.time() - start_time > timeout:
326
                raise RuntimeError("Barrier departure timed out after %f s", timeout)
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350

            # Check for each process
            for i in range(self.world_size):
                if i in processes_departed:
                    continue

                key = f"departure_{barrier_id}_{i}"
                try:
                    # Try to get the key - if it exists, we'll get a value
                    # If it doesn't exist, it will throw an exception
                    self.store.get(key)
                    processes_departed.add(i)
                except KeyError:
                    # Key doesn't exist yet
                    pass
                except Exception as check_e:
                    logger.debug("Error checking key existence: %s", check_e)
                    sched_yield()

            # Short sleep to avoid tight polling
            if len(processes_departed) < self.world_size:
                sched_yield()

        # Clean up keys to avoid leaking memory in the store
351
        for i in range(self.world_size):
352
353
354
            try:
                self.store.delete_key(f"arrival_{barrier_id}_{i}")
            except Exception:
355
                logger.debug("Error deleting key: %s", f"arrival_{barrier_id}_{i}")
356
357
358
359

            try:
                self.store.delete_key(f"departure_{barrier_id}_{i}")
            except Exception:
360
                logger.debug("Error deleting key: %s", f"departure_{barrier_id}_{i}")
361
362
363

    @staticmethod
    def create(
364
365
        host: str,
        port: int,
366
367
368
        rank: int,
        world_size: int,
        data_expiration_seconds: int = 3600,
369
        store_timeout: int = 300,
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
    ) -> "StatelessProcessGroup":
        """A replacement for `torch.distributed.init_process_group` that does not
        pollute the global state.

        If we have process A and process B called `torch.distributed.init_process_group`
        to form a group, and then we want to form another group with process A, B, C,
        D, it is not possible in PyTorch, because process A and process B have already
        formed a group, and process C and process D cannot join that group. This
        function is a workaround for this issue.

        `torch.distributed.init_process_group` is a global call, while this function
        is a stateless call. It will return a `StatelessProcessGroup` object that can be
        used for exchanging metadata. With this function, process A and process B
        can call `StatelessProcessGroup.create` to form a group, and then process A, B,
        C, and D can call `StatelessProcessGroup.create` to form another group.
385
        """  # noqa
386
387
388
389
390
391
392
393
394
395
396
397
        launch_server = rank == 0
        if launch_server:
            # listen on the specified interface (instead of 0.0.0.0)
            listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            listen_socket.bind((host, port))
            listen_socket.listen()
            listen_fd = listen_socket.fileno()
        else:
            listen_socket = None
            listen_fd = None

398
399
400
401
        store = TCPStore(
            host_name=host,
            port=port,
            world_size=world_size,
402
            is_master=launch_server,
403
            timeout=timedelta(seconds=store_timeout),
404
405
            use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215
            master_listen_fd=listen_fd,
406
        )
407
408
409
410
411

        return StatelessProcessGroup(
            rank=rank,
            world_size=world_size,
            store=store,
412
            socket=listen_socket,
413
414
            data_expiration_seconds=data_expiration_seconds,
        )
415
416


417
418
419
420
421
422
423
def init_gloo_process_group(
    backend: Backend,
    prefix_store: PrefixStore,
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> ProcessGroup:
    """
    Stateless init ProcessGroup with gloo backend compatible with
    different torch versions.

    Builds a bare ``ProcessGroup`` on ``prefix_store`` and registers a
    ``ProcessGroupGloo`` instance as its backend for the CPU device.
    """
    # The ProcessGroup constructor and `_set_default_backend` usage differ
    # between torch < 2.6 and torch >= 2.6.
    torch_at_least_2_6 = is_torch_equal_or_newer("2.6")
    if torch_at_least_2_6:
        pg = ProcessGroup(prefix_store, group_rank, group_size)
    else:
        pg = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
            ProcessGroup.Options(backend=backend),
        )

    from torch.distributed.distributed_c10d import ProcessGroupGloo

    gloo_backend = ProcessGroupGloo(
        prefix_store, group_rank, group_size, timeout=timeout
    )
    gloo_type = ProcessGroup.BackendType.GLOO
    if torch_at_least_2_6:
        # _set_default_backend is supported in torch >= 2.6
        pg._set_default_backend(gloo_type)
    gloo_backend._set_sequence_number_for_group()

    pg._register_backend(torch.device("cpu"), gloo_type, gloo_backend)
    return pg


def stateless_init_torch_distributed_process_group(
    host: str, port: int, rank: int, world_size: int, backend: str
) -> ProcessGroup:
    """
    A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state. The created ProcessGroup object can be used for
    some operations such as `allreduce`, because it does not depend on the
    global rank. However, some operations such as `broadcast` cannot be used
    because it depends on the global rank.

    TODO: ask for help from PyTorch team if we need the `broadcast` operation.

    This function is useful when we are not sure about the total number of
    processes in the process group. For example, we may have processes
    1, 2, ..., 8 that want to communicate, and process 9 might be the same
    process as process 1, or it might be a different process; process 10
    might be the same process as process 5, or it might be a different process.
    In this case, how can we reliably form a communication channel within
    processes 9 and 10, without affecting the communication channel within
    processes 1, 2, ..., 8?

    One possible solution is to figure out whether processes 9 and 10 are the
    same as processes 1 and 5 beforehand, and then form a communication
    channel based on that information, adjusting the ranks and world_size etc.
    However, figuring out that information is not always easy, and it would
    interfere with the main communication channel.

    Our solution is to always form a communication channel with processes
    1, 2, ..., 8, and then use this function to form another communication
    channel with processes 9 and 10. This way, regardless of whether processes
    9 and 10 are the same as processes 1 and 5, the main communication channel
    is always formed with processes 1, 2, ..., 8, and the additional
    communication channel is formed with processes 9 and 10.

    The ``"gloo"`` backend is built locally by `init_gloo_process_group`; any
    other backend is delegated to the current platform.
    """
    init_method = get_tcp_uri(host, port)
    backend = Backend(backend)  # it is basically string
    timeout = _get_default_timeout(backend)

    store, group_rank, group_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout)
    )
    store.set_timeout(timeout)

    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)

    pg_kwargs = {
        "backend": backend,
        "prefix_store": prefix_store,
        "group_rank": group_rank,
        "group_size": group_size,
        "timeout": timeout,
    }
    if backend == "gloo":
        return init_gloo_process_group(**pg_kwargs)

    from vllm.platforms import current_platform

    return current_platform.stateless_init_device_torch_dist_pg(**pg_kwargs)


def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
    """
    Destroy ProcessGroup returned by
        stateless_init_torch_distributed_process_group().

    Shuts the group down (``pg.shutdown()`` on torch >= 2.7, the private
    ``_shutdown_backend`` helper on older releases) and then calls
    ``_unregister_process_group(pg.group_name)``.
    """
    if not is_torch_equal_or_newer("2.7"):
        # Lazy import for non-CUDA backends.
        from torch.distributed.distributed_c10d import _shutdown_backend

        _shutdown_backend(pg)
    else:
        pg.shutdown()

    _unregister_process_group(pg.group_name)