llm_engine.py 13.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Mapping
5
from copy import copy
6
from typing import Any, Callable, Optional, Union
7

8
import torch.nn as nn
9
10
from typing_extensions import TypeVar

11
import vllm.envs as envs
12
from vllm.config import ParallelConfig, VllmConfig
13
from vllm.distributed import stateless_destroy_torch_distributed_process_group
14
from vllm.distributed.parallel_state import get_dp_group
15
from vllm.engine.arg_utils import EngineArgs
16
from vllm.inputs import PromptType
17
18
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
19
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
20
from vllm.outputs import PoolingRequestOutput, RequestOutput
21
from vllm.pooling_params import PoolingParams
22
from vllm.sampling_params import SamplingParams
23
from vllm.tasks import SupportedTask
24
from vllm.tracing import init_tracer
25
26
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               init_tokenizer_from_configs)
27
from vllm.usage.usage_lib import UsageContext
28
from vllm.utils import Device
29
from vllm.v1.engine.core_client import EngineCoreClient
30
from vllm.v1.engine.output_processor import OutputProcessor
31
from vllm.v1.engine.parallel_sampling import ParentRequest
32
from vllm.v1.engine.processor import Processor
33
from vllm.v1.executor.abstract import Executor
34
35
36
37
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
                                     StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
38
from vllm.v1.worker.worker_base import WorkerBase
39
40
41

logger = init_logger(__name__)

# Result type of callables dispatched to workers through
# `LLMEngine.collective_rpc` / `LLMEngine.apply_model`. Uses a PEP 696
# TypeVar default of `Any` when the type cannot be inferred.
_R = TypeVar("_R", default=Any)


class LLMEngine:
    """Legacy LLMEngine for backwards compatibility.

    A synchronous facade over the V1 engine pieces:

    * a ``Processor`` converts raw inputs into EngineCoreRequests,
    * an ``EngineCoreClient`` executes them, and
    * an ``OutputProcessor`` converts EngineCoreOutputs back into
      ``RequestOutput`` / ``PoolingRequestOutput`` objects.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Build the input/output processors and the engine-core client.

        Args:
            vllm_config: Full engine configuration.
            executor_class: Executor implementation handed to
                ``EngineCoreClient.make_client``.
            log_stats: Collect iteration stats and record them through a
                ``PrometheusStatLogger``.
            usage_context: Accepted for API compatibility; unused here.
            stat_loggers: Custom stat-logger factories. Not supported in
                V1; must be ``None``.
            mm_registry: Multimodal registry passed to the ``Processor``.
            use_cached_outputs: Accepted for API compatibility; unused here.
            multiprocess_mode: Passed to ``EngineCoreClient.make_client``.
                When false, the in-process engine core's ``model_executor``
                is exposed as ``self.model_executor``.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is false.
            NotImplementedError: If ``stat_loggers`` is not ``None``.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        if stat_loggers is not None:
            raise NotImplementedError(
                "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
                "Set VLLM_USE_V1=0 and file an issue on Github.")

        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        self.log_stats = log_stats
        self.stat_logger: Optional[StatLoggerBase] = None
        if self.log_stats:
            self.stat_logger = PrometheusStatLogger(vllm_config)

        parallel_config = vllm_config.parallel_config
        executor_backend = parallel_config.distributed_executor_backend
        # With the "external_launcher" backend and DP > 1, an existing DP
        # group is reused (see below) rather than created by this engine.
        self.external_launcher_dp = (parallel_config.data_parallel_size > 1
                                     and executor_backend
                                     == "external_launcher")

        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        if (not multiprocess_mode and parallel_config.data_parallel_size > 1
                and not self.external_launcher_dp):
            self.dp_group = parallel_config.stateless_init_dp_group()
        else:
            self.dp_group = None
        # Set by has_unfinished_requests_dp; consumed by step().
        self.should_execute_dummy_batch = False

        if self.model_config.skip_tokenizer_init:
            self.tokenizer = None
        else:
            # Tokenizer (+ ensure liveness if running in another process).
            self.tokenizer = init_tokenizer_from_configs(
                model_config=vllm_config.model_config)

        # Processor (convert Inputs --> EngineCoreRequests)
        self.processor = Processor(vllm_config=vllm_config,
                                   tokenizer=self.tokenizer,
                                   mm_registry=mm_registry)

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)
        if self.observability_config.otlp_traces_endpoint is not None:
            tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        if self.external_launcher_dp:
            # If we use DP in external launcher mode, we reuse the
            # existing DP group used for data communication.
            self.dp_group = get_dp_group().cpu_group

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an engine from an already-built ``VllmConfig``.

        The executor class is chosen by ``Executor.get_class`` and the
        multiprocess mode comes from ``envs.VLLM_ENABLE_V1_MULTIPROCESSING``.
        """
        return cls(vllm_config=vllm_config,
                   executor_class=Executor.get_class(vllm_config),
                   log_stats=(not disable_log_stats),
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` can only force
        multiprocessing on; it never turns off an explicit
        ``enable_multiprocessing=True``.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(vllm_config=vllm_config,
                   executor_class=executor_class,
                   log_stats=not engine_args.disable_log_stats,
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=enable_multiprocessing)

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests the output processor still tracks."""
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Return whether any work remains for this engine.

        Without a DP group, local unfinished requests or running DP engines
        (as reported by the engine core) count. With a DP group, the local
        flag is aggregated across ranks by ``has_unfinished_requests_dp``.
        """
        has_unfinished = self.output_processor.has_unfinished_requests()
        if self.dp_group is None:
            return has_unfinished or self.engine_core.dp_engines_running()
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Aggregate ``has_unfinished`` across the DP group.

        If another rank still has work while this one does not, the next
        ``step()`` runs a dummy batch instead of fetching outputs.
        """
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished)
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """Return ``outputs`` unchanged (kept for V0 API compatibility)."""
        return outputs

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the engine core."""
        return self.engine_core.get_supported_tasks()

    def abort_request(self, request_ids: list[str]) -> None:
        """Remove request_ids from EngineCore and the OutputProcessor.

        Only the ids returned by ``OutputProcessor.abort_requests`` are
        forwarded to the engine core.
        """
        request_ids = self.output_processor.abort_requests(request_ids)
        self.engine_core.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> None:
        """Process a prompt and queue it on the engine.

        For ``SamplingParams`` with ``n > 1`` the request is fanned out into
        ``n`` child requests that share one ``ParentRequest``.

        Raises:
            TypeError: If ``request_id`` is not a ``str``.
        """
        # Validate the request_id type.
        if not isinstance(request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request_id)}")

        # Process raw inputs into the request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, priority)

        n = params.n if isinstance(params, SamplingParams) else 1

        if n == 1:
            # Make a new RequestState and queue.
            self.output_processor.add_request(request, prompt_str, None, 0)
            # Add the request to EngineCore.
            self.engine_core.add_request(request)
            return

        # Fan out child requests (for n>1).
        parent_req = ParentRequest(request_id, params)
        for idx in range(n):
            child_id, child_params = parent_req.get_child_info(idx)
            # Earlier children get shallow copies; the last one reuses the
            # original request object to save one copy.
            child_request = request if idx == n - 1 else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params

            # Make a new RequestState and queue.
            self.output_processor.add_request(child_request, prompt_str,
                                              parent_req, idx)
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)

    def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:
        """Run one engine iteration and return the finished/partial outputs.

        When ``should_execute_dummy_batch`` is set (see
        ``has_unfinished_requests_dp``), a dummy batch is executed instead
        and an empty list is returned.
        """
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        iteration_stats = IterationStats() if self.log_stats else None
        processed_outputs = self.output_processor.process_outputs(
            outputs.outputs,
            engine_core_timestamp=outputs.timestamp,
            iteration_stats=iteration_stats)

        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        # 4) Record stats
        if self.stat_logger is not None:
            assert outputs.scheduler_stats is not None
            self.stat_logger.record(scheduler_stats=outputs.scheduler_stats,
                                    iteration_stats=iteration_stats)

        return processed_outputs.request_outputs

    def get_vllm_config(self):
        """Return the ``VllmConfig`` this engine was built with."""
        return self.vllm_config

    def get_model_config(self):
        """Return the model config taken from the ``VllmConfig``."""
        return self.model_config

    def start_profile(self):
        """Ask the engine core to start profiling."""
        self.engine_core.profile(True)

    def stop_profile(self):
        """Ask the engine core to stop profiling."""
        self.engine_core.profile(False)

    def reset_mm_cache(self):
        """Clear the multimodal caches in the processor and engine core."""
        self.processor.clear_cache()
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(self, device: Optional[Device] = None):
        """Reset the engine core's prefix cache.

        ``device`` is accepted for API compatibility and ignored.
        """
        self.engine_core.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Forward a sleep request at ``level`` to the engine core."""
        self.engine_core.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Forward a wake-up request (optionally limited to ``tags``)."""
        self.engine_core.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return whether the engine core reports it is sleeping."""
        return self.engine_core.is_sleeping()

    def get_metrics(self) -> list[Metric]:
        """Return a snapshot of the collected metrics.

        Raises:
            AssertionError: If stat logging is disabled.
        """
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer.

        Raises:
            ValueError: If the engine was built with ``skip_tokenizer_init``.
        """
        if self.tokenizer is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

        return self.tokenizer

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)

    def collective_rpc(self,
                       method: Union[str, Callable[[WorkerBase], _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Forward a collective RPC to the engine core and return results."""
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Run ``func`` via the workers' ``apply_model`` collective RPC."""
        return self.collective_rpc("apply_model", args=(func, ))

    def __del__(self):
        """Destroy the DP process group this engine created, if any."""
        # Bind the group on its own: `g := a and b` binds the result of the
        # whole `and`, which would hand `True` to the destroy call.
        # getattr() keeps this safe if __init__ raised part-way through.
        dp_group = getattr(self, "dp_group", None)
        # A group borrowed from the external launcher is not ours to destroy.
        owns_group = not getattr(self, "external_launcher_dp", False)
        if dp_group is not None and owns_group:
            stateless_destroy_torch_distributed_process_group(dp_group)