uniproc_executor.py 7.21 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
from collections.abc import Callable
5
6
from concurrent.futures import Future, ThreadPoolExecutor
from functools import cached_property
7
from multiprocessing import Lock
8
from typing import Any
9

10
11
12
13
import torch
import torch.distributed as dist

import vllm.envs as envs
14
from vllm.logger import init_logger
15
from vllm.platforms import current_platform
16
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
17
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
18
from vllm.v1.executor.abstract import Executor
19
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
20
from vllm.v1.serial_utils import run_method
21
from vllm.v1.worker.worker_base import WorkerWrapperBase
22
23
24
25

# Module-level logger for this file.
logger = init_logger(__name__)


26
class UniProcExecutor(Executor):
    """Executor that drives a single worker directly in the current process.

    Worker methods are invoked via ``run_method`` on ``self.driver_worker``.
    When async scheduling is enabled (``scheduler_config.async_scheduling``),
    up to two batches may be in flight and ``AsyncModelRunnerOutput`` results
    are resolved on a dedicated single-thread pool.
    """

    def _init_executor(self) -> None:
        """Initialize the worker and load the model."""
        self.driver_worker = WorkerWrapperBase(rpc_rank=0)
        distributed_init_method, rank, local_rank = self._distributed_args()
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=True,
            shared_worker_lock=Lock(),
        )

        # Only created when more than one batch can be in flight; it resolves
        # AsyncModelRunnerOutput.get_output() off the calling thread.
        self.async_output_thread: ThreadPoolExecutor | None = None
        if self.max_concurrent_batches > 1:
            self.async_output_thread = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="WorkerAsyncOutput"
            )

        self.driver_worker.init_worker(all_kwargs=[kwargs])
        self.driver_worker.init_device()

        if envs.VLLM_ELASTIC_EP_SCALE_UP_LAUNCH:
            self.driver_worker.elastic_ep_execute("load_model")
        else:
            self.driver_worker.load_model()
        current_platform.update_block_size_for_backend(self.vllm_config)

    def _distributed_args(self) -> tuple[str, int, int]:
        """Return (distributed_init_method, rank, local_rank).

        The single in-process worker is always rank 0. The local rank is the
        device index when the configured device's string form contains one
        (``"<type>:<index>"``), and 0 otherwise.
        """
        distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
        # set local rank as the device index if specified
        device_info = str(self.vllm_config.device_config.device).split(":")
        local_rank = int(device_info[1]) if len(device_info) > 1 else 0
        return distributed_init_method, 0, local_rank

    @cached_property
    def max_concurrent_batches(self) -> int:
        """Number of batches allowed in flight: 2 with async scheduling, else 1."""
        return 2 if self.scheduler_config.async_scheduling else 1

    def collective_rpc(  # type: ignore[override]
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
        non_block: bool = False,
        single_value: bool = False,
    ) -> Any:
        """Invoke ``method`` on the single driver worker.

        Args:
            method: Worker method name or callable, dispatched via
                ``run_method``.
            timeout: Accepted for interface compatibility; unused here.
            args: Positional arguments for ``method``.
            kwargs: Keyword arguments for ``method`` (``None`` means ``{}``).
            non_block: If True, return a ``Future`` instead of the result.
            single_value: If True, return the bare result; otherwise wrap it
                in a one-element list.

        Returns:
            When blocking, the result (or ``[result]``). When non-blocking, a
            ``Future`` resolving to the same value. Any exception raised while
            running the method is stored in that future instead of being
            raised. An ``AsyncModelRunnerOutput`` is resolved on the async
            output thread when one exists, and inline otherwise.
        """
        if kwargs is None:
            kwargs = {}

        if not non_block:
            result = run_method(self.driver_worker, method, args, kwargs)
            return result if single_value else [result]

        try:
            result = run_method(self.driver_worker, method, args, kwargs)
            if isinstance(result, AsyncModelRunnerOutput):
                if (async_thread := self.async_output_thread) is not None:
                    if single_value:
                        return async_thread.submit(result.get_output)

                    def get_output_list() -> list[Any]:
                        return [result.get_output()]

                    return async_thread.submit(get_output_list)
                result = result.get_output()
            future = Future[Any]()
            future.set_result(result if single_value else [result])
        except Exception as e:
            future = Future[Any]()
            future.set_exception(e)
        return future

    def execute_model(  # type: ignore[override]
        self, scheduler_output: SchedulerOutput, non_block: bool = False
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Run the worker's ``execute_model`` on ``scheduler_output``.

        Returns the output directly, or a ``Future`` when ``non_block`` is
        True. In that case, a future that has already completed with an
        exception is re-raised immediately.
        """
        output = self.collective_rpc(
            "execute_model",
            args=(scheduler_output,),
            non_block=non_block,
            single_value=True,
        )
        # In non-blocking mode, surface any exception as early as possible.
        if non_block and output.done():
            # Raise the exception in-line if the task failed.
            output.result()
        return output

    def sample_tokens(  # type: ignore[override]
        self, grammar_output: GrammarOutput | None, non_block: bool = False
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Run the worker's ``sample_tokens`` with ``grammar_output``.

        Returns the output directly, or a ``Future`` when ``non_block`` is
        True.
        """
        return self.collective_rpc(
            "sample_tokens",
            args=(grammar_output,),
            non_block=non_block,
            single_value=True,
        )

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the worker's ``take_draft_token_ids`` result (blocking)."""
        return self.collective_rpc("take_draft_token_ids", single_value=True)

    def check_health(self) -> None:
        # UniProcExecutor will always be healthy as long as
        # it's running.
        return

    def shutdown(self) -> None:
        """Shut down the async output thread pool (if any), then the worker.

        Safe to call on a partially-initialized executor: attributes that
        were never set are skipped.
        """
        # Stop the output thread and cancel queued retrieval tasks so they
        # cannot run after the worker is torn down; otherwise the pool's
        # thread would stay alive until interpreter exit.
        if (async_thread := getattr(self, "async_output_thread", None)) is not None:
            async_thread.shutdown(wait=False, cancel_futures=True)
        if worker := getattr(self, "driver_worker", None):
            worker.shutdown()

    @classmethod
    def supports_async_scheduling(cls) -> bool:
        return True
class ExecutorWithExternalLauncher(UniProcExecutor):
    """An executor that relies on an external launcher to start the engines.

    It is specifically designed for torchrun-compatible launchers, for
    offline inference with tensor parallelism.

    See https://github.com/vllm-project/vllm/issues/11400 for
    the motivation, and examples/offline_inference/torchrun_example.py
    for a usage example.

    The key idea: although this is tensor-parallel inference, each executor
    creates only one worker. Users launch multiple engines with a
    torchrun-compatible launcher, and all of these engines work together to
    process the same prompts. When scheduling is deterministic, every engine
    generates the same outputs, so the engines do not need to synchronize
    their states with each other.
    """

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.

        Raises:
            ValueError: If ``VLLM_ENABLE_V1_MULTIPROCESSING`` is enabled.
        """
        # Validate with an explicit raise rather than ``assert`` so the check
        # still runs when Python is started with ``-O``.
        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            raise ValueError(
                "To get deterministic execution, "
                "please set VLLM_ENABLE_V1_MULTIPROCESSING=0"
            )
        super()._init_executor()

    def _distributed_args(self) -> tuple[str, int, int]:
        """Return (distributed_init_method, rank, local_rank) from the env.

        Engines are launched by a torchrun-compatible launcher, so the
        ``env://`` init method is used. The required environment variables
        are ``RANK``, ``LOCAL_RANK``, ``MASTER_ADDR`` and ``MASTER_PORT``.
        ``RANK`` and ``LOCAL_RANK`` are read here, and a missing one raises
        ``KeyError``.
        """
        distributed_init_method = "env://"
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        return distributed_init_method, rank, local_rank

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return available memory, reduced to the minimum across all ranks."""
        # we need to get the min across all ranks.
        memory = super().determine_available_memory()
        from vllm.distributed.parallel_state import get_world_group

        cpu_group = get_world_group().cpu_group
        # ``memory`` is already a list[int] (per the return annotation), so
        # build a 1-D tensor from it directly instead of nesting it in another
        # list. The element-wise MIN then works for any number of entries.
        memory_tensor = torch.tensor(memory, device="cpu", dtype=torch.int64)
        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        return memory_tensor.tolist()