llm_engine.py 11 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from collections.abc import Mapping
4
from copy import copy
5
from typing import Any, Callable, Optional, Union
6

7
8
from typing_extensions import TypeVar

9
import vllm.envs as envs
10
from vllm.config import ParallelConfig, VllmConfig
11
from vllm.distributed import stateless_destroy_torch_distributed_process_group
12
from vllm.engine.arg_utils import EngineArgs
13
from vllm.inputs import PromptType
14
15
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
16
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
17
from vllm.outputs import RequestOutput
18
19
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
20
from vllm.sampling_params import SamplingParams
21
from vllm.transformers_utils.tokenizer_group import (
22
    TokenizerGroup, init_tokenizer_from_configs)
23
from vllm.usage.usage_lib import UsageContext
24
from vllm.utils import Device
25
from vllm.v1.engine.core_client import EngineCoreClient
26
from vllm.v1.engine.output_processor import OutputProcessor
27
from vllm.v1.engine.parallel_sampling import ParentRequest
28
from vllm.v1.engine.processor import Processor
29
from vllm.v1.executor.abstract import Executor
30
from vllm.v1.metrics.loggers import StatLoggerFactory
31
32
33

logger = init_logger(__name__)

34
_R = TypeVar("_R", default=Any)
35

36
37

class LLMEngine:
38
    """Legacy LLMEngine for backwards compatibility."""
39
40
41

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Wire together the tokenizer, the processors and the engine core.

        Args:
            vllm_config: Aggregated engine configuration; its model, cache,
                parallel, scheduler and LoRA sub-configs are read here.
            executor_class: Executor type forwarded to the engine core client.
            log_stats: Accepted, but not consulted here: stats logging is
                passed as disabled to both the output processor and the
                engine core client (see the FIXME below).
            usage_context: Not consulted by this constructor.
            stat_loggers: Must be None; any other value is rejected.
            mm_registry: Multimodal registry handed to the input processor.
            use_cached_outputs: Not consulted by this constructor.
            multiprocess_mode: Forwarded to ``EngineCoreClient.make_client``.
                When False, the data-parallel group (if needed) and the
                ``model_executor`` alias are set up in this process.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is falsy.
            NotImplementedError: If ``stat_loggers`` is not None.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        if stat_loggers is not None:
            raise NotImplementedError(
                "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
                "Set VLLM_USE_V1=0 and file and issue on Github.")

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        # important: the dp group must exist before the engine_core is built.
        # In the decoupled engine case this is handled in EngineCoreProc.
        dp_config = vllm_config.parallel_config
        needs_local_dp_group = (not multiprocess_mode
                                and dp_config.data_parallel_size > 1)
        self.dp_group = (dp_config.stateless_init_dp_group()
                         if needs_local_dp_group else None)
        self.should_execute_dummy_batch = False

        # Tokenizer (+ ensure liveness if running in another process).
        tokenizer_group = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            lora_config=vllm_config.lora_config)
        self.tokenizer = tokenizer_group

        # Processor (convert Inputs --> EngineCoreRequests)
        self.processor = Processor(vllm_config=vllm_config,
                                   tokenizer=tokenizer_group,
                                   mm_registry=mm_registry)

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(tokenizer_group,
                                                log_stats=False)

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        core_client = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,  # FIXME: implement
        )
        self.engine_core = core_client

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = core_client.engine_core.model_executor  # type: ignore

104
105
106
107
108
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Build an engine from an already-constructed ``VllmConfig``.

        The executor class is resolved with ``Executor.get_class`` and
        multiprocess mode follows ``envs.VLLM_ENABLE_V1_MULTIPROCESSING``.

        Args:
            vllm_config: Engine configuration passed through to ``cls``.
            usage_context: Passed through to ``cls``.
            stat_loggers: Passed through to ``cls``.
            disable_log_stats: Inverted to produce the ``log_stats`` argument.

        Returns:
            A new instance of ``cls``.
        """
        executor_class = Executor.get_class(vllm_config)
        log_stats = not disable_log_stats
        use_multiprocessing = envs.VLLM_ENABLE_V1_MULTIPROCESSING
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=use_multiprocessing,
        )

119
120
121
122
123
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        Args:
            engine_args: Source of the engine config and of the
                ``disable_log_stats`` flag.
            usage_context: Used to create the engine config and passed
                through to ``cls``.
            stat_loggers: Passed through to ``cls``.
            enable_multiprocessing: Requested multiprocess mode; forced to
                True when ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` is set.

        Returns:
            A new instance of ``cls``.
        """
        # Derive the engine configs and the matching executor type.
        config = engine_args.create_engine_config(usage_context)
        executor_cls = Executor.get_class(config)

        # The environment switch wins over the caller-supplied flag.
        use_multiprocessing = enable_multiprocessing
        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            use_multiprocessing = True

        # Create the LLMEngine.
        return cls(
            vllm_config=config,
            executor_class=executor_cls,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=use_multiprocessing,
        )

    def get_num_unfinished_requests(self) -> int:
146
        return self.output_processor.get_num_unfinished_requests()
147
148

    def has_unfinished_requests(self) -> bool:
149
        has_unfinished = self.output_processor.has_unfinished_requests()
150
        if self.dp_group is None:
151
152
153
154
155
156
157
158
159
            return has_unfinished
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Combine the local unfinished flag with the data-parallel group's.

        Args:
            has_unfinished: Whether this engine still has unfinished requests.

        Returns:
            The value produced by ``ParallelConfig.has_unfinished_dp`` for
            ``self.dp_group`` (presumably true when any rank still has work;
            confirm against ``ParallelConfig``).
        """
        group_pending = ParallelConfig.has_unfinished_dp(self.dp_group,
                                                         has_unfinished)
        # Nothing to do locally while the group is still busy: flag a dummy
        # batch so the next step() call runs one instead of real work.
        idle_while_group_busy = not has_unfinished and group_pending
        if idle_while_group_busy:
            self.should_execute_dummy_batch = True
        return group_pending
160
161
162
163
164

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        return outputs

165
    def abort_request(self, request_ids: list[str]) -> None:
166
167
        """Remove request_ids from EngineCore and Detokenizer."""

168
        request_ids = self.output_processor.abort_requests(request_ids)
169
170
        self.engine_core.abort_requests(request_ids)

171
172
173
174
175
176
177
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Process a prompt and register it with the output processor and core.

        All arguments are handed positionally to
        ``self.processor.process_inputs``. When ``params`` is a
        ``SamplingParams`` whose ``n`` is not 1, the processed request is
        fanned out into ``n`` child requests under a ``ParentRequest``;
        otherwise a single request is registered.
        """
        # Process raw inputs into the request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, prompt_adapter_request,
            priority)

        num_samples = params.n if isinstance(params, SamplingParams) else 1

        if num_samples != 1:
            # Fan out child requests (for n>1).
            parent = ParentRequest(request_id, params)
            final_idx = num_samples - 1
            for child_idx in range(num_samples):
                child_id, child_params = parent.get_child_info(child_idx)
                # The last child reuses the processed request object itself;
                # every earlier child works on a shallow copy of it.
                if child_idx == final_idx:
                    child = request
                else:
                    child = copy(request)
                child.request_id = child_id
                child.sampling_params = child_params

                # Make a new RequestState and queue, then hand the child
                # to EngineCore.
                self.output_processor.add_request(child, prompt_str, parent,
                                                  child_idx)
                self.engine_core.add_request(child)
            return

        # Single-sample path: make a new RequestState and queue, then add
        # the request to EngineCore.
        self.output_processor.add_request(request, prompt_str, None, 0)
        self.engine_core.add_request(request)
211

212
    def step(self) -> list[RequestOutput]:
213

214
215
216
217
218
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

219
        # 1) Get EngineCoreOutput from the EngineCore.
220
        outputs = self.engine_core.get_output()
221

222
223
        # 2) Process EngineCoreOutputs.
        processed_outputs = self.output_processor.process_outputs(
224
            outputs.outputs)
225

226
227
        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
228

229
        return processed_outputs.request_outputs
230

231
232
233
    def get_vllm_config(self):
        return self.vllm_config

234
    def get_model_config(self):
235
        return self.model_config
236

237
    def start_profile(self):
238
        self.engine_core.profile(True)
239

240
    def stop_profile(self):
241
        self.engine_core.profile(False)
242

243
    def reset_prefix_cache(self, device: Optional[Device] = None):
244
245
        self.engine_core.reset_prefix_cache()

246
247
248
    def sleep(self, level: int = 1):
        self.engine_core.sleep(level)

249
250
    def wake_up(self, tags: Optional[list[str]] = None):
        self.engine_core.wake_up(tags)
251

252
253
254
    def is_sleeping(self) -> bool:
        return self.engine_core.is_sleeping()

255
256
    def get_tokenizer_group(self) -> TokenizerGroup:
        if self.tokenizer is None:
257
258
259
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

260
        return self.tokenizer
261
262
263
264
265
266
267
268
269

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

270
    def list_loras(self) -> set[int]:
271
272
273
274
275
276
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)
277

278
279
280
281
282
283
284
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

285
286
287
    def __del__(self):
        if dp_group := getattr(self, "dp_group", None):
            stateless_destroy_torch_distributed_process_group(dp_group)