utils.py 24.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
Zhuohan Li's avatar
Zhuohan Li committed
7
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
8
import dataclasses
9
import functools
10
import os
11
import pickle
12
import socket
13
import sys
14
import time
15
import uuid
16
from collections import deque
17
from collections.abc import Sequence
18
from datetime import timedelta
19
from typing import Any
Zhuohan Li's avatar
Zhuohan Li committed
20
21

import torch
22
from torch.distributed import ProcessGroup, Store, TCPStore
23
24
25
26
27
28
from torch.distributed.distributed_c10d import (
    Backend,
    PrefixStore,
    _get_default_timeout,
    _unregister_process_group,
)
29
from torch.distributed.rendezvous import rendezvous
30

31
32
import vllm.envs as envs
from vllm.logger import init_logger
33
from vllm.utils.network_utils import get_tcp_uri
34
from vllm.utils.system_utils import suppress_stdout
35
36
37

logger = init_logger(__name__)

38
39
40
# We prefer to use os.sched_yield as it results in tighter polling loops,
# measured to be around 3e-7 seconds. However on earlier versions of Python
# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0).
# The hasattr() guard additionally covers platforms on which os.sched_yield
# is not provided at all, where calling it would raise AttributeError.
USE_SCHED_YIELD = hasattr(os, "sched_yield") and (
    sys.version_info[:3] >= (3, 11, 1)
    or (sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8)
)


def sched_yield():
    """Yield the CPU to other runnable threads/processes.

    Calls ``os.sched_yield()`` when ``USE_SCHED_YIELD`` is true and
    ``time.sleep(0)`` otherwise. Intended for use inside polling loops.
    """
    if USE_SCHED_YIELD:
        os.sched_yield()
    else:
        time.sleep(0)

Zhuohan Li's avatar
Zhuohan Li committed
52

53
54
55
def ensure_divisibility(numerator, denominator):
    """Assert that ``numerator`` is an exact multiple of ``denominator``.

    Raises:
        AssertionError: if ``numerator % denominator`` is non-zero.
    """
    remainder = numerator % denominator
    failure_message = "{} is not divisible by {}".format(numerator, denominator)
    assert remainder == 0, failure_message
58
59
60
61
62
63
64
65


def divide(numerator, denominator):
    """Return ``numerator // denominator`` after checking divisibility.

    The check is delegated to ``ensure_divisibility``, which asserts that
    the division leaves no remainder.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient

Zhuohan Li's avatar
Zhuohan Li committed
66
67
68
69
70

def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
    """Split a tensor into equal chunks along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of chunks to produce; the size of the last
            dimension must be divisible by it (checked via ``divide``).
        contiguous_split_chunks: If True, call ``.contiguous()`` on every
            chunk before returning.

    Returns:
        A tuple of tensors, one per partition.
    """
    split_dim = tensor.dim() - 1
    chunk_width = divide(tensor.size()[split_dim], num_partitions)
    pieces = torch.split(tensor, chunk_width, dim=split_dim)

    # NOTE: torch.split does not create contiguous tensors by default.
    if not contiguous_split_chunks:
        return pieces
    return tuple(piece.contiguous() for piece in pieces)
93
94


95
96
97
def get_pp_indices(
    num_hidden_layers: int, pp_rank: int, pp_size: int
) -> tuple[int, int]:
    """Compute the ``[start, end)`` layer range owned by one pipeline rank.

    When ``envs.VLLM_PP_LAYER_PARTITION`` is set, it is parsed as a
    comma-separated list of per-partition layer counts; the list must have
    ``pp_size`` entries that sum to ``num_hidden_layers``.

    Otherwise every partition gets ``num_hidden_layers // pp_size`` layers
    and each leftover layer is added to one partition, starting with the
    second-to-last partition and moving toward the first. The last
    partition is skipped first because it often contains an additional norm
    layer and we are attempting to balance compute; when there are at most
    ``pp_size - 2`` leftover layers the first partition is untouched as
    well, since the first and last partitions hold the input and output
    embeddings and we are attempting to reduce maximum memory consumption.

    Returns:
        ``(start_layer, end_layer)`` for partition ``pp_rank``.

    Raises:
        ValueError: if the environment override cannot be parsed or does
            not match ``pp_size`` / ``num_hidden_layers``.
    """
    override = envs.VLLM_PP_LAYER_PARTITION
    if override is None:
        base_count = num_hidden_layers // pp_size
        partitions = [base_count] * pp_size

        leftover = num_hidden_layers % pp_size
        if leftover:
            # Hand out one extra layer each, walking backwards from the
            # second-to-last partition.
            for offset in range(2, leftover + 2):
                partitions[-offset] += 1
            logger.info(
                "Hidden layers were unevenly partitioned: [%s]. "
                "This can be manually overridden using the "
                "VLLM_PP_LAYER_PARTITION environment variable",
                ",".join(str(p) for p in partitions),
            )
    else:
        try:
            partitions = [int(layer) for layer in override.split(",")]
        except ValueError as err:
            raise ValueError(
                "Invalid partition string: {}".format(override)
            ) from err
        if len(partitions) != pp_size:
            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
        if sum(partitions) != num_hidden_layers:
            raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")

    first_layer = sum(partitions[:pp_rank])
    return (first_layer, first_layer + partitions[pp_rank])
141
142


143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def create_tcp_store(
    host: str,
    port: int,
    listen_socket: socket.socket | None = None,
    **kwargs: Any,
) -> TCPStore:
    """Build a ``TCPStore``, optionally taking ownership of ``listen_socket``.

    Args:
        host: forwarded to ``TCPStore`` as ``host_name``.
        port: forwarded to ``TCPStore`` as ``port``.
        listen_socket: when given, the socket is detached and its file
            descriptor is handed to ``TCPStore`` as ``master_listen_fd``.
        **kwargs: forwarded unchanged to ``TCPStore``.

    Returns:
        The constructed ``TCPStore``.
    """
    if listen_socket is not None:
        # After detach() the socket object no longer owns the descriptor,
        # so it has to be closed by hand if the store cannot be created.
        detached_fd = listen_socket.detach()
        try:
            store = TCPStore(
                host_name=host,
                port=port,
                master_listen_fd=detached_fd,
                **kwargs,
            )
        except Exception:
            socket.close(detached_fd)
            raise
        return store

    return TCPStore(host_name=host, port=port, **kwargs)


166
167
168
169
170
171
@dataclasses.dataclass
class StatelessProcessGroup:
    """A dataclass to hold a metadata store, and the rank, world_size of the
    group. Only use it to communicate metadata between processes.
    For data-plane communication, create NCCL-related objects.
    """
172

173
174
175
    rank: int
    world_size: int
    store: torch._C._distributed_c10d.Store
176

177
178
179
    data_expiration_seconds: int = 3600  # 1 hour

    # dst rank -> counter
180
    send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict)
181
    # src rank -> counter
182
    recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
183
    broadcast_send_counter: int = 0
184
    broadcast_recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict)
185
186

    # A deque to store the data entries, with key and timestamp.
187
    entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque)
188
189
190
191
192

    def __post_init__(self):
        assert self.rank < self.world_size
        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
193
        self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
194
195
196
197

    def send_obj(self, obj: Any, dst: int):
        """Send an object to a destination rank."""
        self.expire_data()
198
        key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
        self.store.set(key, pickle.dumps(obj))
        self.send_dst_counter[dst] += 1
        self.entries.append((key, time.time()))

    def expire_data(self):
        """Expire data that is older than `data_expiration_seconds` seconds."""
        while self.entries:
            # check the oldest entry
            key, timestamp = self.entries[0]
            if time.time() - timestamp > self.data_expiration_seconds:
                self.store.delete_key(key)
                self.entries.popleft()
            else:
                break

    def recv_obj(self, src: int) -> Any:
        """Receive an object from a source rank."""
        obj = pickle.loads(
217
218
            self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
        )
219
220
221
        self.recv_src_counter[src] += 1
        return obj

222
    def broadcast_obj(self, obj: Any | None, src: int) -> Any:
223
224
225
226
227
228
        """Broadcast an object from a source rank to all other ranks.
        It does not clean up after all ranks have received the object.
        Use it for limited times, e.g., for initialization.
        """
        if self.rank == src:
            self.expire_data()
229
            key = f"broadcast_from/{src}/{self.broadcast_send_counter}"
230
231
232
233
234
            self.store.set(key, pickle.dumps(obj))
            self.broadcast_send_counter += 1
            self.entries.append((key, time.time()))
            return obj
        else:
235
            key = f"broadcast_from/{src}/{self.broadcast_recv_src_counter[src]}"
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
            recv_obj = pickle.loads(self.store.get(key))
            self.broadcast_recv_src_counter[src] += 1
            return recv_obj

    def all_gather_obj(self, obj: Any) -> list[Any]:
        """All gather an object from all ranks."""
        gathered_objs = []
        for i in range(self.world_size):
            if i == self.rank:
                gathered_objs.append(obj)
                self.broadcast_obj(obj, src=self.rank)
            else:
                recv_obj = self.broadcast_obj(None, src=i)
                gathered_objs.append(recv_obj)
        return gathered_objs

252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
    def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
        """Broadcast a tensor from source rank to all other ranks."""
        if self.rank == src:
            tensor_bytes = pickle.dumps(tensor)
            self.expire_data()
            key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}"
            self.store.set(key, tensor_bytes)
            self.broadcast_send_counter += 1
            self.entries.append((key, time.time()))
            return tensor
        else:
            key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}"
            tensor = pickle.loads(self.store.get(key))
            self.broadcast_recv_src_counter[src] += 1
            return tensor

    def send(self, tensor: torch.Tensor, dst: int):
        """Send a tensor to a destination rank."""
        self.expire_data()
        key = f"send_tensor/{dst}/{self.send_dst_counter[dst]}"
        self.store.set(key, pickle.dumps(tensor))
        self.send_dst_counter[dst] += 1
        self.entries.append((key, time.time()))

    def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
        """Receive a tensor from a source rank."""
        key = f"send_tensor/{self.rank}/{self.recv_src_counter[src]}"
        received = pickle.loads(self.store.get(key))
        self.recv_src_counter[src] += 1
        tensor.copy_(received)
        return tensor

    def all_reduce(
        self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM
    ) -> torch.Tensor:
        """All-reduce a tensor across all ranks."""
        tensors = self.all_gather_obj(tensor)
        result = tensors[0].clone()
        for t in tensors[1:]:
            if op == torch.distributed.ReduceOp.SUM:
                result.add_(t)
            elif op == torch.distributed.ReduceOp.PRODUCT:
                result.mul_(t)
            elif op == torch.distributed.ReduceOp.MAX:
                result = torch.maximum(result, t)
            elif op == torch.distributed.ReduceOp.MIN:
                result = torch.minimum(result, t)
        return result

301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
    def barrier(self, timeout: float = 30.0):
        """A robust barrier to synchronize all ranks.


        Uses a multi-phase approach to ensure all processes reach the barrier
        before proceeding:

        1. Each process signals it has reached the barrier

        2. Each process signals that it has confirmed the arrival of all other
        ranks.

        3. Rank 0 waits for all other ranks to signal their departure to ensure
        that all ranks have departed the barrier first.

        Args:
            timeout: Maximum time in seconds to wait for each phase (in seconds)


        Raises:
            RuntimeError: If coordination fails or times out
        """
        # Generate a barrier ID that is globally unique
        try:
            if self.rank == 0:
                barrier_id = f"barrier_{uuid.uuid4()}"
                self.broadcast_obj(barrier_id, src=0)
            else:
                barrier_id = self.broadcast_obj(None, src=0)
        except Exception as e:
            raise RuntimeError("Failed to broadcast barrier_id") from e

        # Phase 1: Signal arrival at barrier
        # Wait for all processes to arrive
        # We need all ranks to confirm the arrival of all other ranks.
        # This is the key synchronization point.
        arrival_key = f"arrival_{barrier_id}_{self.rank}"
        try:
            self.store.set(arrival_key, b"1")
        except Exception as e:
            raise RuntimeError("Failed to signal barrier arrival") from e

        start_time = time.time()
        processes_arrived: set[int] = set()

        while len(processes_arrived) < self.world_size:
            # Check for timeout
            cur_time = time.time()
            if cur_time - start_time > timeout:
350
                raise RuntimeError(f"Barrier timed out after {timeout:.2f} seconds")
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396

            # Check for each process
            for i in range(self.world_size):
                if i in processes_arrived:
                    continue

                key = f"arrival_{barrier_id}_{i}"
                try:
                    # Try to get the key - if it exists, we'll get a value
                    # If it doesn't exist, it will throw an exception
                    self.store.get(key)
                    processes_arrived.add(i)
                except KeyError:
                    # Key doesn't exist yet
                    pass
                except Exception as check_e:
                    logger.debug("Error checking key existence: %s", check_e)
                    sched_yield()

            # Short sleep to avoid tight polling
            if len(processes_arrived) < self.world_size:
                sched_yield()

        # Phase 2: Signal departure from barrier
        # We only care to block at this stage in rank 0, which runs the
        # server side of the TCPStore. We want to make sure that all
        # clients have departed the barrier before rank 0 in case the
        # next thing after the barrier is a shutdown, including tearing
        # down the TCPStore. Other ranks can exit the barrier immediately
        # after signaling their departure.
        departure_key = f"departure_{barrier_id}_{self.rank}"
        try:
            self.store.set(departure_key, b"1")
        except Exception as e:
            raise RuntimeError("Failed to signal barrier departure") from e

        if self.rank != 0:
            return

        # Make rank 0 wait for all processes to signal departure
        start_time = time.time()
        processes_departed: set[int] = set()

        while len(processes_departed) < self.world_size:
            # Check for timeout
            if time.time() - start_time > timeout:
397
398
399
                raise RuntimeError(
                    f"Barrier departure timed out after {timeout:.2f} seconds"
                )
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423

            # Check for each process
            for i in range(self.world_size):
                if i in processes_departed:
                    continue

                key = f"departure_{barrier_id}_{i}"
                try:
                    # Try to get the key - if it exists, we'll get a value
                    # If it doesn't exist, it will throw an exception
                    self.store.get(key)
                    processes_departed.add(i)
                except KeyError:
                    # Key doesn't exist yet
                    pass
                except Exception as check_e:
                    logger.debug("Error checking key existence: %s", check_e)
                    sched_yield()

            # Short sleep to avoid tight polling
            if len(processes_departed) < self.world_size:
                sched_yield()

        # Clean up keys to avoid leaking memory in the store
424
        for i in range(self.world_size):
425
426
427
            try:
                self.store.delete_key(f"arrival_{barrier_id}_{i}")
            except Exception:
428
                logger.debug("Error deleting key: %s", f"arrival_{barrier_id}_{i}")
429
430
431
432

            try:
                self.store.delete_key(f"departure_{barrier_id}_{i}")
            except Exception:
433
                logger.debug("Error deleting key: %s", f"departure_{barrier_id}_{i}")
434
435
436

    @staticmethod
    def create(
437
438
        host: str,
        port: int,
439
440
441
        rank: int,
        world_size: int,
        data_expiration_seconds: int = 3600,
442
        store_timeout: int = 300,
443
        listen_socket: socket.socket | None = None,
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
    ) -> "StatelessProcessGroup":
        """A replacement for `torch.distributed.init_process_group` that does not
        pollute the global state.

        If we have process A and process B called `torch.distributed.init_process_group`
        to form a group, and then we want to form another group with process A, B, C,
        D, it is not possible in PyTorch, because process A and process B have already
        formed a group, and process C and process D cannot join that group. This
        function is a workaround for this issue.

        `torch.distributed.init_process_group` is a global call, while this function
        is a stateless call. It will return a `StatelessProcessGroup` object that can be
        used for exchanging metadata. With this function, process A and process B
        can call `StatelessProcessGroup.create` to form a group, and then process A, B,
        C, and D can call `StatelessProcessGroup.create` to form another group.
459
        """  # noqa
460
        launch_server = rank == 0
461
        if launch_server and listen_socket is None:
462
463
464
465
            listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            listen_socket.bind((host, port))
            listen_socket.listen()
466
467
468
469
        store = create_tcp_store(
            host,
            port,
            listen_socket=listen_socket,
470
            world_size=world_size,
471
            is_master=launch_server,
472
            timeout=timedelta(seconds=store_timeout),
473
            use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215
474
        )
475
476
477
478
479

        return StatelessProcessGroup(
            rank=rank,
            world_size=world_size,
            store=store,
480
481
            data_expiration_seconds=data_expiration_seconds,
        )
482
483


484
485
486
487
488
489
490
491
492
493
@functools.lru_cache(maxsize=1)
def get_cached_tcp_store_client(host: str, port: int) -> TCPStore:
    """Return a client-side ``TCPStore`` for ``(host, port)``, memoized.

    The cache holds a single entry: repeated calls with the same
    ``(host, port)`` get the same store object back, while a call with a
    different ``(host, port)`` replaces the cached entry.
    """
    client = TCPStore(
        host_name=host,
        port=port,
        is_master=False,
        wait_for_workers=False,
    )
    return client


494
495
496
497
498
499
def init_gloo_process_group(
    prefix_store: PrefixStore,
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> ProcessGroup:
    """
    Stateless init ProcessGroup with gloo backend compatible with
    different torch versions.

    A bare ``ProcessGroup`` is created on ``prefix_store`` and a
    ``ProcessGroupGloo`` backend is registered on it for the CPU device.
    Everything runs with stdout suppressed.
    """
    with suppress_stdout():
        process_group = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )
        from torch.distributed.distributed_c10d import ProcessGroupGloo

        gloo_backend = ProcessGroupGloo(
            prefix_store, group_rank, group_size, timeout=timeout
        )
        gloo_type = ProcessGroup.BackendType.GLOO
        cpu_device = torch.device("cpu")

        # Order kept as-is: default backend, then sequence number, then
        # backend registration.
        process_group._set_default_backend(gloo_type)
        gloo_backend._set_sequence_number_for_group()
        process_group._register_backend(cpu_device, gloo_type, gloo_backend)
    return process_group


524
def stateless_init_torch_distributed_process_group(
    host: str,
    port: int,
    rank: int,
    world_size: int,
    backend: str,
    group_name: str | None = None,
    return_store: bool = False,
    listen_socket: socket.socket | None = None,
) -> ProcessGroup | tuple[ProcessGroup, Store]:
    """
    A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state. The created ProcessGroup object can be used for
    some operations such as `allreduce`, because it does not depend on the
    global rank. However, some operations such as `broadcast` cannot be used
    because it depends on the global rank.

    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.

    This function is useful when we are not sure about the total number of
    processes in the process group. For example, processes 1, 2, ..., 8 may
    want to communicate, while process 9 might or might not be the same
    process as process 1, and process 10 might or might not be the same
    process as process 5. How can processes 9 and 10 then reliably form a
    communication channel without affecting the channel among processes
    1, 2, ..., 8?

    One option is to find out beforehand whether processes 9 and 10 coincide
    with processes 1 and 5 and to adjust ranks and world_size accordingly.
    However, figuring that out is not always easy, and it interferes with
    the main communication channel.

    Our solution is to always form the main channel with processes
    1, 2, ..., 8 and to use this function to form a separate channel for
    processes 9 and 10. The main channel is then unaffected regardless of
    whether processes 9 and 10 coincide with processes 1 and 5.

    When *listen_socket* is provided, the rendezvous step
    is skipped and a ``TCPStore`` server is created directly using the
    pre-bound socket.  This is useful for eliminating TOCTOU races
    between port allocation and binding.

    Returns:
        The process group, or ``(process_group, store)`` when
        ``return_store`` is true.
    """
    init_method = get_tcp_uri(host, port)
    backend = Backend(backend)  # it is basically string
    timeout = _get_default_timeout(backend)

    if listen_socket is None:
        rendezvous_iter = rendezvous(init_method, rank, world_size, timeout=timeout)
        store, rank, world_size = next(rendezvous_iter)
    else:
        store = create_tcp_store(
            host,
            port,
            listen_socket=listen_socket,
            world_size=world_size,
            is_master=True,
            timeout=timeout,
            multi_tenant=True,
        )
    store.set_timeout(timeout)

    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)

    # Arguments shared by both backend initialisation paths.
    pg_kwargs = dict(
        prefix_store=prefix_store,
        group_rank=rank,
        group_size=world_size,
        timeout=timeout,
    )
    if backend == "gloo":
        pg = init_gloo_process_group(**pg_kwargs)
    else:
        from vllm.platforms import current_platform

        pg = current_platform.stateless_init_device_torch_dist_pg(
            backend=backend,
            **pg_kwargs,
        )

    if group_name is not None:
        from torch._C._distributed_c10d import _register_process_group

        pg._set_group_name(group_name)
        _register_process_group(group_name, pg)

    return (pg, store) if return_store else pg

626

627
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
    """
    Tear down a ProcessGroup returned by
        stateless_init_torch_distributed_process_group().

    The group is shut down first and then unregistered under its
    ``group_name``.
    """
    pg.shutdown()
    registered_name = pg.group_name
    _unregister_process_group(registered_name)
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673


def get_worker_rank_suffix(global_rank: int | None = None) -> str:
    """Build a descriptive suffix identifying this worker.

    The result looks like 'dp0_pp0_tp0_dcp0_ep0_rank0' and covers all
    parallel dimensions: DP, PP, TP, DCP, EP.

    Args:
        global_rank: Optional global rank; when given, ``_rank<N>`` is
                     appended after the parallel dimension ranks.

    Returns:
        The suffix string. If reading any of the group ranks raises, the
        result falls back to ``rank<N>`` (or an empty string when
        ``global_rank`` is None).
    """
    from vllm.distributed.parallel_state import (
        get_dcp_group,
        get_dp_group,
        get_ep_group,
        get_pp_group,
        get_tp_group,
    )

    try:
        # Evaluated in this order; any failure triggers the fallback below.
        labelled_ranks = [
            ("dp", get_dp_group().rank_in_group),
            ("pp", get_pp_group().rank_in_group),
            ("tp", get_tp_group().rank_in_group),
            ("dcp", get_dcp_group().rank_in_group),
            ("ep", get_ep_group().rank_in_group),
        ]
        parts = [f"{label}{value}" for label, value in labelled_ranks]
        if global_rank is not None:
            parts.append(f"rank{global_rank}")
        return "_".join(parts)
    except Exception:
        # Fallback if parallel state not initialized
        return "" if global_rank is None else f"rank{global_rank}"