"tests/kernels/test_blocksparse_attention.py" did not exist on "1951f47847aae1746f318435dfb746d22f73b4ed"
uniproc_executor.py 7.07 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import os
4
from collections.abc import Callable
5
6
from concurrent.futures import Future, ThreadPoolExecutor
from functools import cached_property
7
from multiprocessing import Lock
8
from typing import Any
9

10
11
12
13
import torch
import torch.distributed as dist

import vllm.envs as envs
14
from vllm.logger import init_logger
15
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
16
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
17
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
18
from vllm.v1.executor.abstract import Executor
19
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
20
from vllm.v1.serial_utils import run_method
21
from vllm.v1.worker.worker_base import WorkerWrapperBase
22
23
24
25

# Module-level logger, named after this module via __name__.
logger = init_logger(__name__)


26
class UniProcExecutor(Executor):
    """Executor that drives a single in-process worker (``rpc_rank=0``).

    Every RPC is dispatched directly to ``self.driver_worker`` through
    ``run_method``. When async scheduling is enabled
    (``max_concurrent_batches > 1``), one background thread is used to
    resolve ``AsyncModelRunnerOutput`` results for non-blocking calls.
    """

    def _init_executor(self) -> None:
        """Initialize the worker and load the model."""
        self.driver_worker = WorkerWrapperBase(rpc_rank=0)
        distributed_init_method, rank, local_rank = self._distributed_args()
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=True,
            shared_worker_lock=Lock(),
        )

        # Only created when more than one batch can be in flight. With a
        # single worker thread, submitted tasks run one at a time, in
        # submission order.
        self.async_output_thread: ThreadPoolExecutor | None = None
        if self.max_concurrent_batches > 1:
            self.async_output_thread = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="WorkerAsyncOutput"
            )

        self.driver_worker.init_worker(all_kwargs=[kwargs])
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def _distributed_args(self) -> tuple[str, int, int]:
        """Return ``(distributed_init_method, rank, local_rank)``.

        The init method is built from this host's IP and an open port, and
        the rank is always 0. The local rank is the device index parsed from
        the configured device string (e.g. ``"cuda:1"`` -> 1), or 0 when the
        device string carries no index.
        """
        distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
        # set local rank as the device index if specified
        device_info = str(self.vllm_config.device_config.device).split(":")
        local_rank = int(device_info[1]) if len(device_info) > 1 else 0
        return distributed_init_method, 0, local_rank

    @cached_property
    def max_concurrent_batches(self) -> int:
        """2 when async scheduling is enabled, otherwise 1."""
        return 2 if self.scheduler_config.async_scheduling else 1

    def collective_rpc(  # type: ignore[override]
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
        non_block: bool = False,
        single_value: bool = False,
    ) -> Any:
        """Run ``method`` on the single driver worker.

        Args:
            method: Worker method name or callable, passed to ``run_method``.
            timeout: Accepted for interface compatibility; unused here.
            args: Positional arguments for ``method``.
            kwargs: Keyword arguments for ``method`` (``None`` means none).
            non_block: If True, return a ``Future`` instead of the result.
            single_value: If True, return the bare result; otherwise wrap it
                in a one-element list.

        Returns:
            The result (or ``[result]``). When ``non_block`` is set, a
            ``Future`` resolving to it is returned instead. In that case any
            exception raised while running the method is stored on the
            returned future rather than raised.
        """
        if kwargs is None:
            kwargs = {}

        if not non_block:
            result = run_method(self.driver_worker, method, args, kwargs)
            return result if single_value else [result]

        try:
            result = run_method(self.driver_worker, method, args, kwargs)
            if isinstance(result, AsyncModelRunnerOutput):
                if (async_thread := self.async_output_thread) is not None:
                    # Resolve the async output on the background thread.
                    if single_value:
                        return async_thread.submit(result.get_output)

                    def get_output_list() -> list[Any]:
                        return [result.get_output()]

                    return async_thread.submit(get_output_list)
                # No background thread: resolve synchronously.
                result = result.get_output()
            future = Future[Any]()
            future.set_result(result if single_value else [result])
        except Exception as e:
            future = Future[Any]()
            future.set_exception(e)
        return future

    def execute_model(  # type: ignore[override]
        self, scheduler_output: SchedulerOutput, non_block: bool = False
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Run ``execute_model`` on the worker (a Future if ``non_block``)."""
        return self.collective_rpc(
            "execute_model",
            args=(scheduler_output,),
            non_block=non_block,
            single_value=True,
        )

    def sample_tokens(  # type: ignore[override]
        self, grammar_output: GrammarOutput | None, non_block: bool = False
    ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
        """Run ``sample_tokens`` on the worker (a Future if ``non_block``)."""
        return self.collective_rpc(
            "sample_tokens",
            args=(grammar_output,),
            non_block=non_block,
            single_value=True,
        )

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Fetch draft token ids from the worker (blocking call)."""
        return self.collective_rpc("take_draft_token_ids", single_value=True)

    def check_health(self) -> None:
        # UniProcExecutor will always be healthy as long as
        # it's running.
        return

    def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest
    ) -> None:
        """Forward a distributed reconfiguration request to the worker.

        If the request's new data-parallel rank is
        ``SHUTDOWN_CURRENT_RANK``, this executor is shut down afterwards.
        """
        self.driver_worker.reinitialize_distributed(reconfig_request)
        if (
            reconfig_request.new_data_parallel_rank
            == ReconfigureRankType.SHUTDOWN_CURRENT_RANK
        ):
            self.shutdown()

    def shutdown(self) -> None:
        """Shut down the driver worker and release the async output thread."""
        if worker := self.driver_worker:
            worker.shutdown()
        # The background thread created in _init_executor was previously never
        # shut down. getattr guards against shutdown() running before that
        # attribute was assigned. ThreadPoolExecutor.shutdown is idempotent,
        # so repeated shutdown() calls are safe.
        if (async_thread := getattr(self, "async_output_thread", None)) is not None:
            async_thread.shutdown(wait=False)

139

140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
class ExecutorWithExternalLauncher(UniProcExecutor):
    """An executor that uses external launchers to launch engines,
    specially designed for torchrun-compatible launchers, for
    offline inference with tensor parallelism.

    see https://github.com/vllm-project/vllm/issues/11400 for
    the motivation, and examples/offline_inference/torchrun_example.py
    for the usage example.

    The key idea: although it is tensor-parallel inference, we only
    create one worker per executor, users will launch multiple
    engines with torchrun-compatible launchers, and all these engines
    work together to process the same prompts. When scheduling is
    deterministic, all the engines will generate the same outputs,
    and they don't need to synchronize the states with each other.
    """

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.

        Requires ``VLLM_ENABLE_V1_MULTIPROCESSING`` to be disabled (checked
        with an ``assert``) before delegating to ``UniProcExecutor``.
        """
        assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, (
            "To get deterministic execution, "
            "please set VLLM_ENABLE_V1_MULTIPROCESSING=0"
        )
        super()._init_executor()

    def _distributed_args(self) -> tuple[str, int, int]:
        """Return ``("env://", RANK, LOCAL_RANK)`` read from the environment.

        Raises:
            KeyError: if ``RANK`` or ``LOCAL_RANK`` is not set.
        """
        # engines are launched in torchrun-compatible launchers
        # so we can use the env:// method.
        # required env vars:
        # - RANK
        # - LOCAL_RANK
        # - MASTER_ADDR
        # - MASTER_PORT
        distributed_init_method = "env://"
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        return distributed_init_method, rank, local_rank

    def determine_available_memory(self) -> list[int]:  # in bytes
        """Return available memory, reduced to the minimum across all ranks.

        Each rank computes its local value(s) via the parent implementation.
        An element-wise MIN all-reduce over the world CPU group then makes
        every rank agree on the same, most constrained budget.
        """
        # we need to get the min across all ranks.
        memory = super().determine_available_memory()
        from vllm.distributed.parallel_state import get_world_group

        cpu_group = get_world_group().cpu_group
        # `memory` is already a list, so build a 1-D tensor from it directly.
        # The element-wise MIN and `tolist()` stay correct for any list length
        # (presumably one entry per local worker). The previous
        # `tensor([memory]).item()` only worked for exactly one entry.
        memory_tensor = torch.tensor(memory, device="cpu", dtype=torch.int64)
        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
        return memory_tensor.tolist()