# uniproc_executor.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
from collections.abc import Callable
5
6
from concurrent.futures import Future, ThreadPoolExecutor
from functools import cached_property
7
from multiprocessing import Lock
8
from typing import Any
9

10
11
12
13
import torch
import torch.distributed as dist

import vllm.envs as envs
14
15
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
16
17
from vllm.utils import run_method
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
18
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
19
from vllm.v1.outputs import AsyncModelRunnerOutput
20
from vllm.v1.worker.worker_base import WorkerWrapperBase
21
22
23
24
25
26
27
28

# Module-level logger for this file, obtained from vLLM's init_logger helper.
logger = init_logger(__name__)


class UniProcExecutor(ExecutorBase):
    """Executor that drives a single worker inside the current process.

    Every RPC is dispatched directly to one in-process ``WorkerWrapperBase``
    (``rpc_rank=0``). When async scheduling is enabled, a one-thread pool
    resolves ``AsyncModelRunnerOutput`` results off the caller's thread.
    """

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Initialize the worker and load the model."""
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, rpc_rank=0)
        distributed_init_method, rank, local_rank = self._distributed_args()
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=True,
            shared_worker_lock=Lock(),
        )

        # A background thread is only needed when more than one batch can be
        # in flight (async scheduling). max_workers=1 makes submitted
        # get_output() calls run one at a time, in submission order.
        self.async_output_thread: ThreadPoolExecutor | None = None
        if self.max_concurrent_batches > 1:
            self.async_output_thread = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="WorkerAsyncOutput"
            )

        # init_worker is given a list of kwargs dicts (presumably one entry
        # per worker); this executor has exactly one worker.
        self.collective_rpc("init_worker", args=([kwargs],))
        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

    def _distributed_args(self) -> tuple[str, int, int]:
        """Return ``(distributed_init_method, rank, local_rank)``.

        The init method is built from this host's IP and an open port
        (via ``get_ip`` / ``get_open_port``). The global rank is always 0.
        The local rank is the device index parsed from the configured device
        string (the part after ``:``), or 0 when the device carries no index.
        """
        distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
        # Use the device index as the local rank when one is specified.
        device_info = str(self.vllm_config.device_config.device).split(":")
        local_rank = int(device_info[1]) if len(device_info) > 1 else 0
        return distributed_init_method, 0, local_rank

    @cached_property
    def max_concurrent_batches(self) -> int:
        """2 when async scheduling is enabled, otherwise 1."""
        return 2 if self.scheduler_config.async_scheduling else 1

    def collective_rpc(
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
        non_block: bool = False,
    ) -> list[Any]:
        """Run ``method`` on the single driver worker.

        Args:
            method: Name of the worker method, or a callable, forwarded to
                ``run_method``.
            timeout: Accepted for interface compatibility; not used by this
                in-process executor.
            args: Positional arguments for ``method``.
            kwargs: Keyword arguments for ``method`` (``None`` means none).
            non_block: When False, call synchronously and return
                ``[result]``; exceptions propagate to the caller. When True,
                return ``[Future]`` instead. Exceptions raised while running
                the method are stored on that future rather than raised.

        Returns:
            A one-element list holding either the result or a ``Future``.
        """
        if kwargs is None:
            kwargs = {}

        if not non_block:
            return [run_method(self.driver_worker, method, args, kwargs)]

        try:
            result = run_method(self.driver_worker, method, args, kwargs)
            if isinstance(result, AsyncModelRunnerOutput):
                # Resolve the async output on the background thread when one
                # exists; otherwise resolve it inline.
                if (async_thread := self.async_output_thread) is not None:
                    return [async_thread.submit(result.get_output)]
                result = result.get_output()
            future = Future[Any]()
            future.set_result(result)
        except Exception as e:
            future = Future[Any]()
            future.set_exception(e)
        return [future]

    def check_health(self) -> None:
        """No-op: an in-process executor is healthy as long as it is running."""

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Forward a distributed reconfiguration to the worker.

        If the request asks to shut down the current rank, the executor is
        shut down afterwards.
        """
        self.driver_worker.reinitialize_distributed(reconfig_request)
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()

    def shutdown(self) -> None:
        """Shut down the worker and the async-output thread pool, if present."""
        # getattr: _init_executor may have failed before these were assigned.
        if worker := getattr(self, "driver_worker", None):
            worker.shutdown()
        if (async_thread := getattr(self, "async_output_thread", None)) is not None:
            # Don't block on an in-flight get_output(). Already-submitted tasks
            # still complete, but new submissions are rejected.
            async_thread.shutdown(wait=False)


# Alias: the very same class as UniProcExecutor, exported under an
# async-specific name.
UniProcExecutorAsync = UniProcExecutor


class ExecutorWithExternalLauncher(UniProcExecutor):
    """An executor that uses external launchers to launch engines,
    specifically designed for torchrun-compatible launchers, for
    offline inference with tensor parallelism.

    See https://github.com/vllm-project/vllm/issues/11400 for
    the motivation, and examples/offline_inference/torchrun_example.py
    for the usage example.

    The key idea: although it is tensor-parallel inference, we only
    create one worker per executor. Users launch multiple engines
    with torchrun-compatible launchers, and all these engines
    work together to process the same prompts. When scheduling is
    deterministic, all the engines will generate the same outputs,
    and they don't need to synchronize their states with each other.
    """

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.

        Raises:
            AssertionError: if V1 is enabled while V1 multiprocessing is
                also enabled; that combination does not give the
                deterministic execution this executor relies on.
        """
        if envs.VLLM_USE_V1:
            assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
                "To get deterministic execution in V1, "
                "please set VLLM_ENABLE_V1_MULTIPROCESSING=0"
            )
        super()._init_executor()

    def _distributed_args(self) -> tuple[str, int, int]:
        """Return ``("env://", RANK, LOCAL_RANK)`` from the environment.

        Engines are launched by torchrun-compatible launchers, so the
        ``env://`` init method is used. Required env vars:

        - RANK
        - LOCAL_RANK
        - MASTER_ADDR
        - MASTER_PORT

        Raises:
            KeyError: if ``RANK`` or ``LOCAL_RANK`` is not set.
        """
        distributed_init_method = "env://"
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        return distributed_init_method, rank, local_rank

    def determine_num_available_blocks(self) -> tuple[int, int]:
        """
        Determine the number of available KV blocks.
        Add an additional all_reduce to get the min across all ranks.

        Note that even if we have the same `gpu_memory_utilization` and
        `swap_space`, the available memory in every rank might still
        differ because NCCL can take different amounts of memory in
        different ranks. Therefore, it is necessary to test if all ranks
        agree on the same KV cache configuration.
        """
        from vllm.distributed.parallel_state import get_world_group

        # Local (per-rank) pair of block counts. Presumably
        # (num_gpu_blocks, num_cpu_blocks); confirm against ExecutorBase.
        first_count, second_count = super().determine_num_available_blocks()

        # Reduce both counts with one collective: ReduceOp.MIN is applied
        # element-wise, so this equals two separate all_reduce calls while
        # costing a single round trip across ranks.
        cpu_group = get_world_group().cpu_group
        counts = torch.tensor(
            [first_count, second_count], device="cpu", dtype=torch.int64
        )
        dist.all_reduce(counts, group=cpu_group, op=dist.ReduceOp.MIN)
        min_first, min_second = counts.tolist()
        return min_first, min_second