"vllm/v1/executor/ray_utils.py" did not exist on "9925c179409f7b14aef55bcfde47b320bd1a263c"
utils.py 26.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import argparse
import multiprocessing
5
import time
6
import weakref
7
from collections import defaultdict
8
from collections.abc import Sequence
9
10
from dataclasses import dataclass
from enum import Enum, auto
11
from multiprocessing import Process, connection
12
13
14
from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
                    Union, overload)
15

16
import msgspec
17
import torch
18
import zmq
19

20
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
21
from vllm.logger import init_logger
22
from vllm.model_executor.models.utils import extract_layer_index
23
24
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
25
26
from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
                        get_tcp_uri, kill_process_tree)
27
from vllm.v1.executor.abstract import Executor
28

29
if TYPE_CHECKING:
Rui Qiao's avatar
Rui Qiao committed
30
31
    from ray.util.placement_group import PlacementGroup

32
    from vllm.attention.layer import Attention
33
    from vllm.v1.engine.coordinator import DPCoordinator
34

35
logger = init_logger(__name__)
36
37
38

T = TypeVar("T")

# Timeout, in milliseconds, passed to the ZMQ poller while waiting for core
# engines to complete their startup handshake.
STARTUP_POLL_PERIOD_MS = 10000


class ConstantList(Generic[T], Sequence):
    """Read-only sequence wrapper around a list.

    Reads (indexing, slicing, iteration, membership, length, ``index``) are
    delegated to the wrapped list. Every mutating operation raises an
    ``Exception``. The wrapped list is stored by reference rather than
    copied, so changes made to it directly remain visible through this
    wrapper.
    """

    def __init__(self, x: list[T]) -> None:
        self._x = x

    def append(self, item):
        raise Exception("Cannot append to a constant list")

    def extend(self, item):
        raise Exception("Cannot extend a constant list")

    def insert(self, item, value=None):
        # The optional second argument lets the list-style call
        # `insert(index, obj)` reach this error instead of failing earlier
        # with an argument-count TypeError.
        raise Exception("Cannot insert into a constant list")

    def pop(self, item=-1):
        # The default mirrors `list.pop()` so the argument-less call also
        # reports the intended error.
        raise Exception("Cannot pop from a constant list")

    def remove(self, item):
        raise Exception("Cannot remove from a constant list")

    def clear(self):
        raise Exception("Cannot clear a constant list")

    def index(self,
              item: T,
              start: int = 0,
              stop: Optional[int] = None) -> int:
        """Return the first position of `item` within `[start, stop)`.

        Raises:
            ValueError: if `item` is not found (raised by `list.index`).
        """
        return self._x.index(item, start,
                             stop if stop is not None else len(self._x))

    @overload
    def __getitem__(self, item: int) -> T:
        ...

    @overload
    def __getitem__(self, s: slice, /) -> list[T]:
        ...

    def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
        # Slicing returns a plain (mutable) list copy, not a ConstantList.
        return self._x[item]

    @overload
    def __setitem__(self, item: int, value: T):
        ...

    @overload
    def __setitem__(self, s: slice, value: list[T], /):
        ...

    def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
        raise Exception("Cannot set item in a constant list")

    def __delitem__(self, item):
        raise Exception("Cannot delete item from a constant list")

    def __iter__(self):
        return iter(self._x)

    def __contains__(self, item):
        return item in self._x

    def __len__(self):
        return len(self._x)

    def __repr__(self):
        return f"ConstantList({self._x})"

109

110
111
112
113
114
115
116
def get_engine_client_zmq_addr(local_only: bool,
                               host: str,
                               port: int = 0) -> str:
    """Pick the ZMQ address an engine client should use.

    Args:
        local_only: When true, the address comes from
            `get_open_zmq_ipc_path()` and `host`/`port` are ignored.
        host: Host passed to `get_tcp_uri()` when not local-only.
        port: Port passed to `get_tcp_uri()`; a falsy value (the default 0)
            means one is obtained from `get_open_port()`.

    Returns:
        The address string produced by the selected helper.
    """
    if local_only:
        return get_open_zmq_ipc_path()

    chosen_port = port if port else get_open_port()
    return get_tcp_uri(host, chosen_port)


Rui Qiao's avatar
Rui Qiao committed
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
class CoreEngineState(Enum):
    """Startup handshake state of a core engine as tracked by the front-end.

    `wait_for_engine_startup` moves an engine from NEW to CONNECTED when it
    receives a "HELLO" message and from CONNECTED to READY when it receives
    a "READY" message.
    """
    # Initial state assigned in CoreEngine.__init__.
    NEW = auto()
    # "HELLO" received and the init message sent back.
    CONNECTED = auto()
    # "READY" received; startup is complete for this engine.
    READY = auto()


class CoreEngine:
    """Front-end record of a single core engine (one per data parallel rank).

    Attributes:
        local: Flag supplied by the caller; compared against the "local"
            field of handshake messages in `wait_for_engine_startup`.
        index: Index supplied by the caller.
        identity: `index` encoded as two little-endian bytes; matched
            against the identity frame of incoming handshake messages.
        state: Current `CoreEngineState`, initially `NEW`.
    """

    def __init__(self, index: int = 0, local: bool = True):
        self.local = local
        self.index = index
        # Two bytes, little-endian: decoded again with
        # int.from_bytes(..., "little") on the receiving side.
        self.identity = index.to_bytes(length=2, byteorder="little")

        self.state = CoreEngineState.NEW


@dataclass
class EngineZmqAddresses:
    """Bundle of ZMQ socket addresses passed to core engines.

    Carried inside `EngineHandshakeMetadata` during the startup handshake
    and passed directly to Ray engine actors by `CoreEngineActorManager`.
    """
    # ZMQ input socket addresses for each front-end client (requests)
    inputs: list[str]
    # ZMQ output socket addresses for each front-end client (responses)
    outputs: list[str]
    # ZMQ input socket address of DP coordinator if applicable
    coordinator_input: Optional[str] = None
    # ZMQ output socket address of DP coordinator if applicable
    coordinator_output: Optional[str] = None


@dataclass
class EngineHandshakeMetadata:
    """Metadata sent to each engine process during startup handshake,
    including addresses of the front-end ZMQ queues that they should
    connect to.
    """
    # Front-end ZMQ addresses the engine should connect to.
    addresses: EngineZmqAddresses
    # Data parallel settings; `wait_for_engine_startup` fills in the keys
    # "data_parallel_master_ip", "data_parallel_master_port" and
    # "data_parallel_size".
    parallel_config: dict[str, Union[int, str]]


156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
class APIServerProcessManager:
    """Manages a group of API server processes.

    Creates and starts the API server worker processes and terminates them
    on `close()` or when this object is garbage collected. The started
    processes are exposed through `self.processes` so that callers can
    monitor them.
    """

    def __init__(
        self,
        target_server_fn: Callable,
        listen_address: str,
        sock: Any,
        args: argparse.Namespace,
        num_servers: int,
        input_addresses: list[str],
        output_addresses: list[str],
        stats_update_address: Optional[str] = None,
    ):
        """Initialize and start API server worker processes.

        Args:
            target_server_fn: Function to call for each API server process
            listen_address: Address to listen for client connections
            sock: Socket for client connections
            args: Command line arguments
            num_servers: Number of API server processes to start
            input_addresses: Input addresses for each API server
            output_addresses: Output addresses for each API server
            stats_update_address: Optional stats update address

        Raises:
            Exception: whatever process creation or `start()` raises; any
                servers already started are shut down before re-raising.
        """
        self.listen_address = listen_address
        self.sock = sock
        self.args = args

        # Start API servers
        spawn_context = multiprocessing.get_context("spawn")
        self.processes: list[BaseProcess] = []

        # Shutdown only the API server processes on garbage collection
        # The extra processes are managed by their owners
        #
        # Registered before any process is started (the finalizer keeps a
        # reference to the same list that is filled below), so servers that
        # did start are still cleaned up if a later start() fails.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)

        try:
            # zip() stops at the shortest input, so at most `num_servers`
            # processes are created.
            for i, in_addr, out_addr in zip(range(num_servers),
                                            input_addresses,
                                            output_addresses):
                client_config = {
                    "input_address": in_addr,
                    "output_address": out_addr,
                    "client_index": i
                }
                if stats_update_address is not None:
                    client_config[
                        "stats_update_address"] = stats_update_address

                proc = spawn_context.Process(target=target_server_fn,
                                             name=f"ApiServer_{i}",
                                             args=(listen_address, sock, args,
                                                   client_config))
                self.processes.append(proc)
                proc.start()
        except Exception:
            # Do not leave a partially started group running.
            self._finalizer()
            raise

        logger.info("Started %d API server processes", len(self.processes))

    def close(self) -> None:
        """Terminate the API server processes (no-op after the first call)."""
        self._finalizer()


221
class CoreEngineProcManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of background processes used by the AsyncLLM and LLMEngine.
    """

    def __init__(
        self,
        target_fn: Callable,
        local_engine_count: int,
        start_index: int,
        local_start_index: int,
        vllm_config: VllmConfig,
        on_head_node: bool,
        handshake_address: str,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        """Create and start one engine core process per local engine.

        Args:
            target_fn: Callable run in each child process; it receives the
                keyword arguments assembled below.
            local_engine_count: Number of processes to create.
            start_index: Value of `dp_rank` for the first process; later
                processes get consecutive values.
            local_start_index: Value of `local_dp_rank` for the first
                process; later processes get consecutive values.
            vllm_config: Forwarded to `target_fn`.
            on_head_node: Forwarded to `target_fn`.
            handshake_address: Forwarded to `target_fn`.
            executor_class: Forwarded to `target_fn`.
            log_stats: Forwarded to `target_fn`.
        """
        context = get_mp_context()
        common_kwargs = {
            "vllm_config": vllm_config,
            "on_head_node": on_head_node,
            "handshake_address": handshake_address,
            "executor_class": executor_class,
            "log_stats": log_stats,
        }

        self.processes: list[BaseProcess] = []
        for index in range(local_engine_count):
            local_index = local_start_index + index
            global_index = start_index + index
            # Start EngineCore in background process.
            self.processes.append(
                context.Process(target=target_fn,
                                name=f"EngineCore_{global_index}",
                                kwargs=common_kwargs | {
                                    "dp_rank": global_index,
                                    "local_dp_rank": local_index,
                                }))

        # Registered before starting so that cleanup is available even if
        # start() raises part-way through.
        self._finalizer = weakref.finalize(self, shutdown, self.processes)
        try:
            for proc in self.processes:
                proc.start()
        finally:
            # Kill other procs if not all are running.
            if self.finished_procs():
                self.close()

    def close(self):
        """Shutdown all procs."""
        self._finalizer()

    def join_first(self):
        """Wait for any process to exit."""
        connection.wait(self.sentinels())

    def sentinels(self) -> list:
        """Return the sentinel handle of every managed process."""
        return [proc.sentinel for proc in self.processes]

    def finished_procs(self) -> dict[str, int]:
        """Returns dict of proc name -> exit code for any finished procs."""
        return {
            proc.name: proc.exitcode
            for proc in self.processes if proc.exitcode is not None
        }
Robert Shaw's avatar
Robert Shaw committed
287
288


Rui Qiao's avatar
Rui Qiao committed
289
290
291
292
class CoreEngineActorManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of core engine Ray actors used by the AsyncLLM and LLMEngine.

    Different from CoreEngineProcManager, this class manages
    core engines for both local and remote nodes.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        placement_groups: Optional[list["PlacementGroup"]] = None,
        local_dp_ranks: Optional[list[int]] = None,
    ):
        """Create one engine actor per data parallel rank and start them.

        Args:
            vllm_config: Base configuration; each actor receives its own
                deep copy with `parallel_config.placement_group` set.
            addresses: ZMQ addresses forwarded to every actor.
            executor_class: Forwarded to every actor.
            log_stats: Forwarded to every actor.
            placement_groups: Optional pre-created placement groups, one per
                data parallel rank. When omitted they are created by
                `create_dp_placement_groups` and removed again in `close()`.
            local_dp_ranks: Local rank for each entry of `placement_groups`;
                required when `placement_groups` is given.
        """
        import copy

        import ray
        from ray.util.scheduling_strategies import (
            PlacementGroupSchedulingStrategy)

        from vllm.v1.engine.core import DPEngineCoreActor

        self.local_engine_actors: list[ray.ActorHandle] = []
        self.remote_engine_actors: list[ray.ActorHandle] = []
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_engine_count = \
            vllm_config.parallel_config.data_parallel_size_local
        world_size = vllm_config.parallel_config.world_size

        if ray.is_initialized():
            logger.info(
                "Ray is already initialized. Skipping Ray initialization.")
        else:
            ray.init()

        if placement_groups is not None:
            assert local_dp_ranks is not None, (
                "local_dp_ranks must be provided if "
                "placement_groups is provided")
            assert len(placement_groups) == len(local_dp_ranks), (
                "placement_groups and local_dp_ranks must "
                "have the same length")
            logger.info("Using provided placement groups")
            # TODO(rui): validate passed-in placement groups
            # Caller-owned groups are not removed in close().
            self.created_placement_groups = []
        else:
            placement_groups, local_dp_ranks = \
                CoreEngineActorManager.create_dp_placement_groups(vllm_config)
            self.created_placement_groups = placement_groups
        assert len(placement_groups) == dp_size, (
            "Number of placement groups must match data parallel size")

        refs = []
        for index in range(dp_size):
            local_index = local_dp_ranks[index]
            dp_vllm_config = copy.deepcopy(vllm_config)
            pg = placement_groups[index]
            dp_vllm_config.parallel_config.placement_group = pg
            # The first `local_engine_count` ranks are treated as running
            # on the head node.
            on_head_node = index < local_engine_count
            actor = ray.remote(DPEngineCoreActor).options(
                scheduling_strategy=PlacementGroupSchedulingStrategy(
                    placement_group=pg,
                    # Groups built by create_dp_placement_groups hold
                    # `world_size` GPU bundles followed by one CPU bundle,
                    # so this index selects the CPU bundle.
                    # NOTE(review): caller-provided groups are assumed to
                    # follow the same layout - confirm.
                    placement_group_bundle_index=world_size,
                )).remote(vllm_config=dp_vllm_config,
                          executor_class=executor_class,
                          log_stats=log_stats,
                          on_head_node=on_head_node,
                          addresses=addresses,
                          dp_rank=index,
                          local_dp_rank=local_index)
            if on_head_node:
                self.local_engine_actors.append(actor)
            else:
                self.remote_engine_actors.append(actor)
            refs.append(actor.wait_for_init.remote())

        # Block until every actor reports that initialization finished.
        ray.get(refs)
        self.run_refs = []
        for actor in self.local_engine_actors + self.remote_engine_actors:
            self.run_refs.append(actor.run.remote())

    @staticmethod
    def create_dp_placement_groups(
            vllm_config: VllmConfig
    ) -> tuple[list["PlacementGroup"], list[int]]:
        """Create one Ray placement group per data parallel rank.

        The node whose IP equals `data_parallel_master_ip` is handled first
        and receives `data_parallel_size_local` groups. Remaining nodes
        receive as many groups as their available GPUs allow, until
        `data_parallel_size` groups exist.

        Returns:
            `(placement_groups, local_dp_ranks)`; the two lists are
            parallel, and each local rank is the position of the group
            among those created for the same node.
        """

        import ray
        from ray._private.state import available_resources_per_node
        from ray.util.state import list_nodes

        logger.info("Creating placement groups for data parallel")
        dp_master_ip = \
            vllm_config.parallel_config.data_parallel_master_ip
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_engine_count = \
            vllm_config.parallel_config.data_parallel_size_local

        # False sorts before True, which puts the master node first.
        nodes = sorted(list_nodes(),
                       key=lambda node: node.node_ip != dp_master_ip)
        assert nodes[0].node_ip == dp_master_ip, (
            "The first node must be the head node")
        assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
            "There can only be one head node")

        available_resources = available_resources_per_node()
        world_size = vllm_config.parallel_config.world_size
        placement_groups: list[PlacementGroup] = []
        local_dp_ranks: list[int] = []

        for node in nodes:
            node_ip = node.node_ip
            node_resources = available_resources[node.node_id]
            # For now, each DP rank can only be assigned to one node
            # TODO(rui): support allocating a single DP rank
            # to multiple nodes
            # Resource quantities may be floats and a node may report no
            # GPU entry at all; normalize to an int so range() accepts it.
            available_engine_count = \
                int(node_resources.get("GPU", 0)) // world_size
            if node_ip == dp_master_ip:
                assert available_engine_count >= local_engine_count, (
                    "Not enough resources to allocate DP ranks "
                    f"on DP master node {node_ip}")
                for i in range(local_engine_count):
                    # The small "node:<ip>" request pins the bundles to the
                    # master node.
                    bundles = [{
                        "GPU": 1.0,
                        "node:" + dp_master_ip: 0.001
                    }] * world_size + [{
                        "CPU": 1.0
                    }]
                    pg = ray.util.placement_group(
                        name=f"dp_rank_{len(placement_groups)}",
                        strategy="STRICT_PACK",
                        bundles=bundles,
                    )
                    placement_groups.append(pg)
                    local_dp_ranks.append(i)
            else:
                for i in range(available_engine_count):
                    if len(placement_groups) == dp_size:
                        break
                    bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
                    pg = ray.util.placement_group(
                        name=f"dp_rank_{len(placement_groups)}",
                        strategy="STRICT_PACK",
                        bundles=bundles,
                    )
                    placement_groups.append(pg)
                    local_dp_ranks.append(i)
        return placement_groups, local_dp_ranks

    def get_run_refs(self):
        """Return the object refs of each actor's `run` call."""
        return self.run_refs

    def close(self):
        """Kill all engine actors and remove placement groups created here."""
        import ray
        for actor in self.local_engine_actors + self.remote_engine_actors:
            ray.kill(actor)
        for pg in self.created_placement_groups:
            ray.util.remove_placement_group(pg)
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551


def wait_for_engine_startup(
    handshake_socket: zmq.Socket,
    addresses: EngineZmqAddresses,
    core_engines: list[CoreEngine],
    parallel_config: ParallelConfig,
    cache_config: CacheConfig,
    proc_manager: Optional[CoreEngineProcManager],
    coord_process: Optional[Process],
):
    """Block until every core engine has completed the startup handshake.

    Each engine is expected to send a "HELLO" message followed by a "READY"
    message on `handshake_socket`. On "HELLO" an `EngineHandshakeMetadata`
    reply is sent back; on "READY" the engine's `num_gpu_blocks` is added to
    `cache_config.num_gpu_blocks`. The `state` of the matching entry in
    `core_engines` is updated along the way.

    Args:
        handshake_socket: Socket on which handshake messages arrive and
            replies are sent.
        addresses: Addresses included in the reply to each "HELLO".
        core_engines: Engines expected to connect; matched by `identity`.
        parallel_config: Source of the local engine count and of the data
            parallel settings sent to engines.
        cache_config: Updated in place with the summed `num_gpu_blocks`.
        proc_manager: Optional manager whose process sentinels are polled
            to detect an engine process exiting during startup.
        coord_process: Optional coordinator process, polled the same way.

    Raises:
        RuntimeError: if a polled process exits, if a message arrives from
            an unknown engine identity, if the message's "local" flag does
            not match the engine, or if a status arrives out of order.
    """

    # Wait for engine core process(es) to send ready messages.
    local_count = parallel_config.data_parallel_size_local
    remote_count = len(core_engines) - local_count
    # [local, remote] counts
    conn_pending, start_pending = [local_count, remote_count], [0, 0]
    poller = zmq.Poller()
    poller.register(handshake_socket, zmq.POLLIN)

    # Also poll the process sentinels so that a process exiting during
    # startup wakes the poller.
    if proc_manager is not None:
        for sentinel in proc_manager.sentinels():
            poller.register(sentinel, zmq.POLLIN)
    if coord_process is not None:
        poller.register(coord_process.sentinel, zmq.POLLIN)
    while any(conn_pending) or any(start_pending):
        events = poller.poll(STARTUP_POLL_PERIOD_MS)
        if not events:
            # Poll timed out: report what is still outstanding and retry.
            if any(conn_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) "
                    "to connect.", *conn_pending)
            if any(start_pending):
                logger.debug(
                    "Waiting for %d local, %d remote core engine proc(s) "
                    "to start.", *start_pending)
            continue
        # Anything other than a lone handshake-socket event means at least
        # one registered sentinel fired.
        if len(events) > 1 or events[0][0] != handshake_socket:
            # One of the local core processes exited.
            finished = proc_manager.finished_procs() if proc_manager else {}
            if coord_process is not None and coord_process.exitcode is not None:
                finished[coord_process.name] = coord_process.exitcode
            raise RuntimeError("Engine core initialization failed. "
                               "See root cause above. "
                               f"Failed core proc(s): {finished}")

        # Receive HELLO and READY messages from the input socket.
        eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
        eng_index = int.from_bytes(eng_identity, "little")
        engine = next((e for e in core_engines if e.identity == eng_identity),
                      None)
        if engine is None:
            raise RuntimeError(f"Message from engine with unexpected data "
                               f"parallel rank: {eng_index}")
        msg = msgspec.msgpack.decode(ready_msg_bytes)
        status, local = msg["status"], msg["local"]
        if local != engine.local:
            raise RuntimeError(f"{status} message from "
                               f"{'local' if local else 'remote'} "
                               f"engine {eng_index}, expected it to be "
                               f"{'local' if engine.local else 'remote'}")

        if status == "HELLO" and engine.state == CoreEngineState.NEW:

            # Send init message with DP config info.
            init_message = msgspec.msgpack.encode(
                EngineHandshakeMetadata(
                    addresses=addresses,
                    parallel_config={
                        "data_parallel_master_ip":
                        parallel_config.data_parallel_master_ip,
                        "data_parallel_master_port":
                        parallel_config.data_parallel_master_port,
                        "data_parallel_size":
                        parallel_config.data_parallel_size,
                    }))
            handshake_socket.send_multipart((eng_identity, init_message),
                                            copy=False)
            # The engine moves from "waiting to connect" to "waiting to
            # start"; index 0 tracks local engines, index 1 remote ones.
            conn_pending[0 if local else 1] -= 1
            start_pending[0 if local else 1] += 1
            engine.state = CoreEngineState.CONNECTED
        elif status == "READY" and (engine.state == CoreEngineState.CONNECTED):
            # Setup KV cache config with initialization state from
            # engine core process. Sum values from all engines in DP case.
            num_gpu_blocks = cache_config.num_gpu_blocks or 0
            num_gpu_blocks += msg["num_gpu_blocks"]
            cache_config.num_gpu_blocks = num_gpu_blocks

            start_pending[0 if local else 1] -= 1
            engine.state = CoreEngineState.READY
        else:
            # Status arrived out of order for this engine's current state.
            raise RuntimeError(f"Unexpected {status} message for "
                               f"{'local' if local else 'remote'} engine "
                               f"{eng_index} in {engine.state} state.")

        logger.debug("%s from %s core engine process %s.", status,
                     "local" if local else "remote", eng_index)


def wait_for_completion_or_failure(
        api_server_manager: APIServerProcessManager,
        engine_manager: Optional[Union[CoreEngineProcManager,
                                       CoreEngineActorManager]] = None,
        coordinator: Optional["DPCoordinator"] = None) -> None:
    """Wait for all processes to complete or detect if any fail.

    Raises an exception if any process exits with a non-zero status.
    On exit, whether normal, interrupted or failed, the API servers, the
    coordinator and the engine manager are all closed.

    Args:
        api_server_manager: The manager for API servers.
        engine_manager: The manager for engine processes.
            If CoreEngineProcManager, it manages local engines;
            if CoreEngineActorManager, it manages all engines.
        coordinator: The coordinator for data parallel.

    Raises:
        RuntimeError: if a monitored process exits with a non-zero exit
            code. A KeyboardInterrupt is logged and not re-raised.
    """

    try:
        logger.info("Waiting for API servers to complete ...")
        # Create a mapping of sentinels to their corresponding processes
        # for efficient lookup
        sentinel_to_proc: dict[Any, BaseProcess] = {
            proc.sentinel: proc
            for proc in api_server_manager.processes
        }

        if coordinator:
            sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

        actor_run_refs = []
        if isinstance(engine_manager, CoreEngineProcManager):
            for proc in engine_manager.processes:
                sentinel_to_proc[proc.sentinel] = proc
        elif isinstance(engine_manager, CoreEngineActorManager):
            actor_run_refs = engine_manager.get_run_refs()

        # Check if any process terminates
        while sentinel_to_proc or actor_run_refs:
            # Wait for any process to terminate
            # (iterating the dict yields its keys, i.e. the sentinels).
            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
                                                         timeout=5)

            # Process any terminated processes
            for sentinel in ready_sentinels:
                proc = sentinel_to_proc.pop(sentinel)

                # Check if process exited with error
                if proc.exitcode != 0:
                    raise RuntimeError(
                        f"Process {proc.name} (PID: {proc.pid}) "
                        f"died with exit code {proc.exitcode}")

            if actor_run_refs:
                import ray

                # Keep only the refs that have not finished yet.
                # NOTE(review): finished refs are dropped without being
                # fetched, so an actor failure is not raised from here -
                # confirm this is intended.
                _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)

    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down API servers...")
    except Exception as e:
        logger.exception("Exception occurred while running API servers: %s",
                         str(e))
        raise
    finally:
        logger.info("Terminating remaining processes ...")
        api_server_manager.close()
        if coordinator:
            coordinator.close()
        if engine_manager:
            engine_manager.close()
619
620


Robert Shaw's avatar
Robert Shaw committed
621
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(procs: list[BaseProcess]):
    """Stop the given processes, escalating if they do not exit.

    Every live process is first asked to terminate. The processes then
    share a single 5 second budget to exit. Any process still alive
    afterwards that has a PID is killed with `kill_process_tree`.

    Args:
        procs: Processes to stop; ones that are not alive are skipped.
    """
    # Shutdown the process.
    for proc in procs:
        if proc.is_alive():
            proc.terminate()

    # Allow 5 seconds for remaining procs to terminate.
    # The deadline is shared: each join only gets what is left of it.
    deadline = time.monotonic() + 5
    for proc in procs:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        if proc.is_alive():
            proc.join(remaining)

    for proc in procs:
        if proc.is_alive() and (pid := proc.pid) is not None:
            kill_process_tree(pid)
Robert Shaw's avatar
Robert Shaw committed
641

642
643

def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner. Must be
            empty on entry; it is filled in place.

    Raises:
        NotImplementedError: if two layer names map to the same layer index.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name)].append(layer_name)

    for layer_index in sorted(index2name):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder layer
            # has different layer_name but the same layer_index.
            raise NotImplementedError(
                f"Multiple KV caches share layer index {layer_index}: "
                f"{layer_names}")
        layer_name = layer_names[0]
        runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]
686
687
688


def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    """
    Copy the first length elements of a tensor into another tensor in a
    non-blocking manner.

    Used to copy pinned CPU tensor data to pre-allocated GPU tensors.

    Args:
        from_tensor: Source tensor; only `from_tensor[:length]` is read.
        to_tensor: Destination tensor; `to_tensor[:length]` is overwritten
            in place.
        length: Number of leading elements to copy.

    Returns the sliced target tensor.
    """
    return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
699
700


701
702
703
def report_usage_stats(
        vllm_config,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
    """Report usage statistics if enabled.

    Does nothing when `is_usage_stats_enabled()` is false. Otherwise the
    model architecture name and a fixed set of configuration values are
    passed to `usage_message.report_usage`.

    Args:
        vllm_config: Configuration object whose `model_config`,
            `parallel_config`, `cache_config`, `lora_config` and
            `prompt_adapter_config` attributes are read.
        usage_context: Context passed through to the usage reporter.
    """

    if not is_usage_stats_enabled():
        return

    # Imported here rather than at module level.
    # NOTE(review): presumably to keep the model loader out of module
    # import time - confirm.
    from vllm.model_executor.model_loader import get_architecture_class_name

    usage_message.report_usage(
        get_architecture_class_name(vllm_config.model_config),
        usage_context,
        extra_kvs={
            # Common configuration
            "dtype":
            str(vllm_config.model_config.dtype),
            "tensor_parallel_size":
            vllm_config.parallel_config.tensor_parallel_size,
            "block_size":
            vllm_config.cache_config.block_size,
            "gpu_memory_utilization":
            vllm_config.cache_config.gpu_memory_utilization,

            # Quantization
            "quantization":
            vllm_config.model_config.quantization,
            "kv_cache_dtype":
            str(vllm_config.cache_config.cache_dtype),

            # Feature flags
            "enable_lora":
            bool(vllm_config.lora_config),
            "enable_prompt_adapter":
            bool(vllm_config.prompt_adapter_config),
            "enable_prefix_caching":
            vllm_config.cache_config.enable_prefix_caching,
            "enforce_eager":
            vllm_config.model_config.enforce_eager,
            "disable_custom_all_reduce":
            vllm_config.parallel_config.disable_custom_all_reduce,
        })