uniproc_executor.py 6.51 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import os
5
from multiprocessing import Lock
6
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
7

8
9
10
11
import torch
import torch.distributed as dist

import vllm.envs as envs
12
13
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
14
15
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
16
17
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                        run_method)
18
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
19
from vllm.v1.executor.utils import get_and_update_mm_cache
20
21
22
23
24
25
26
27
28
29
30
31
32
from vllm.worker.worker_base import WorkerWrapperBase

# Module-level logger, created via vLLM's init_logger with this module's name.
logger = init_logger(__name__)


class UniProcExecutor(ExecutorBase):
    """Executor that drives a single worker held in ``self.driver_worker``.

    Every ``collective_rpc`` call is run directly on that one worker and
    its result is returned wrapped in a one-element list.
    """

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.

        Creates the worker wrapper, builds the keyword arguments for the
        worker, creates the multimodal receiver cache, and then issues the
        ``init_worker``, ``init_device`` and ``load_model`` calls in order.
        """
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                               rpc_rank=0)
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
        local_rank = 0
        # set local rank as the device index if specified
        # (e.g. a device rendered as "<type>:<index>")
        device_info = str(self.vllm_config.device_config.device).split(":")
        if len(device_info) > 1:
            local_rank = int(device_info[1])
        rank = 0
        is_driver_worker = True
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )
        # Must be assigned before the first collective_rpc call below,
        # because collective_rpc reads self.mm_receiver_cache.
        self.mm_receiver_cache = worker_receiver_cache_from_config(
            self.vllm_config, MULTIMODAL_REGISTRY, Lock())

        # init_worker receives a list of kwargs dicts; there is one worker
        # here, so the list has a single entry.
        self.collective_rpc("init_worker", args=([kwargs], ))
        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        """Run ``method`` on the single driver worker.

        Args:
            method: Method name or callable, forwarded to ``run_method``.
            timeout: Accepted for interface compatibility; not used by
                this implementation.
            args: Positional arguments forwarded to the method.
            kwargs: Keyword arguments forwarded to the method; ``None``
                is treated as an empty dict.

        Returns:
            A one-element list containing the worker's result.
        """
        if kwargs is None:
            kwargs = {}
        # For "execute_model" calls, pass the arguments through the
        # multimodal receiver cache helper before running the worker.
        if self.mm_receiver_cache is not None and method == "execute_model":
            get_and_update_mm_cache(self.mm_receiver_cache, args)
        answer = run_method(self.driver_worker, method, args, kwargs)
        return [answer]

    def check_health(self) -> None:
        """No-op health check."""
        # UniProcExecutor will always be healthy as long as
        # it's running.
        return

    def reinitialize_distributed(
            self, reconfig_request: ReconfigureDistributedRequest) -> None:
        """Forward a reconfiguration request to the driver worker.

        After the worker has handled the request, this executor shuts
        itself down if the request's new data-parallel rank is
        ``ReconfigureRankType.SHUTDOWN_CURRENT_RANK``.
        """
        self.driver_worker.reinitialize_distributed(reconfig_request)
        if (reconfig_request.new_data_parallel_rank ==
                ReconfigureRankType.SHUTDOWN_CURRENT_RANK):
            self.shutdown()

    def shutdown(self) -> None:
        """Shut down the driver worker, if one exists."""
        # driver_worker is only assigned inside _init_executor, so it may
        # be missing if initialization failed early; tolerate that here.
        if worker := getattr(self, "driver_worker", None):
            worker.shutdown()

86
87

# The async executor name is bound to the very same class as the
# synchronous one; no separate implementation exists in this module.
UniProcExecutorAsync = UniProcExecutor
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114


class ExecutorWithExternalLauncher(UniProcExecutor):
    """An executor that uses external launchers to launch engines,
    specially designed for torchrun-compatible launchers, for
    offline inference with tensor parallelism.

    see https://github.com/vllm-project/vllm/issues/11400 for
    the motivation, and examples/offline_inference/torchrun_example.py
    for the usage example.

    The key idea: although it is tensor-parallel inference, we only
    create one worker per executor, users will launch multiple
    engines with torchrun-compatible launchers, and all these engines
    work together to process the same prompts. When scheduling is
    deterministic, all the engines will generate the same outputs,
    and they don't need to synchronize the states with each other.
    """
    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.

        Reads ``RANK`` and ``LOCAL_RANK`` from the environment (a missing
        variable raises ``KeyError``) and uses the ``env://`` distributed
        init method.
        """
        assert self.vllm_config.scheduler_config.delay_factor == 0.0, \
            ("ExecutorWithExternalLauncher needs deterministic "
             "execution, so it "
             "does not support delay_factor in scheduling")
        if envs.VLLM_USE_V1:
            assert not envs.VLLM_ENABLE_V1_MULTIPROCESSING, \
                ("To get deterministic execution in V1, "
                 "please set VLLM_ENABLE_V1_MULTIPROCESSING=0")
        self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config,
                                               rpc_rank=0)
        # engines are launched in torchrun-compatible launchers
        # so we can use the env:// method.
        # required env vars:
        # - RANK
        # - LOCAL_RANK
        # - MASTER_ADDR
        # - MASTER_PORT
        distributed_init_method = "env://"
        rank = int(os.environ["RANK"])
        local_rank = int(os.environ["LOCAL_RANK"])
        is_driver_worker = True
        kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=is_driver_worker,
        )
        # Must be assigned before the first collective_rpc call below,
        # because the inherited collective_rpc reads this attribute.
        self.mm_receiver_cache = worker_receiver_cache_from_config(
            self.vllm_config, MULTIMODAL_REGISTRY, Lock())

        self.collective_rpc("init_worker", args=([kwargs], ))
        self.collective_rpc("init_device")
        self.collective_rpc("load_model")

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """
        Determine the number of available KV blocks.
        Add an additional all_reduce to get the min across all ranks.
        Note that even if we have the same `gpu_memory_utilization` and
        `swap_space`, the available memory in every rank might still
        differ because NCCL can take different amounts of memory in
        different ranks. Therefore, it is necessary to test if all ranks
        agree on the same KV cache configuration.

        Returns:
            The two block counts from the parent implementation, each
            replaced by its minimum over all ranks in the world group.
        """
        a, b = super().determine_num_available_blocks()
        from vllm.distributed.parallel_state import get_world_group
        cpu_group = get_world_group().cpu_group
        # Reduce both counts in one collective: MIN is element-wise, so
        # this is equivalent to reducing each scalar separately.
        counts = torch.tensor([a, b], device="cpu", dtype=torch.int64)
        dist.all_reduce(counts, group=cpu_group, op=dist.ReduceOp.MIN)
        min_a, min_b = counts.tolist()
        return min_a, min_b