llm_engine.py 15.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Callable, Mapping
6
from copy import copy
7
from typing import Any, cast
8

9
import torch.nn as nn
10
from typing_extensions import TypeVar
11

12
import vllm.envs as envs
13
from vllm.config import ParallelConfig, VllmConfig
14
from vllm.distributed import stateless_destroy_torch_distributed_process_group
15
from vllm.distributed.parallel_state import get_dp_group
16
from vllm.engine.arg_utils import EngineArgs
17
from vllm.inputs import PromptType
18
19
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
20
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
21
from vllm.outputs import PoolingRequestOutput, RequestOutput
22
from vllm.plugins.io_processors import get_io_processor
23
from vllm.pooling_params import PoolingParams
24
from vllm.renderers import BaseRenderer
25
from vllm.sampling_params import SamplingParams
26
from vllm.tasks import SupportedTask
27
from vllm.tokenizers import TokenizerLike
28
from vllm.tracing import init_tracer
29
from vllm.usage.usage_lib import UsageContext
30
from vllm.v1.engine import EngineCoreRequest
31
from vllm.v1.engine.core_client import EngineCoreClient
32
from vllm.v1.engine.input_processor import InputProcessor
33
from vllm.v1.engine.output_processor import OutputProcessor
34
from vllm.v1.engine.parallel_sampling import ParentRequest
35
from vllm.v1.executor import Executor
36
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
37
38
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
39
from vllm.v1.utils import record_function_or_nullcontext
40
from vllm.v1.worker.worker_base import WorkerBase
41
42
43

# Return type of callables dispatched through collective_rpc()/apply_model();
# when the type parameter is not otherwise bound it defaults to Any (PEP 696).
_R = TypeVar("_R", default=Any)

# Module-scoped logger used by LLMEngine.
logger = init_logger(__name__)


class LLMEngine:
    """Legacy LLMEngine for backwards compatibility.

    Synchronous facade over an ``EngineCoreClient``:

    * raw prompts are turned into ``EngineCoreRequest`` objects by an
      ``InputProcessor`` and submitted to the engine core;
    * on every :meth:`step`, ``EngineCoreOutputs`` fetched from the engine core
      are converted into ``RequestOutput`` / ``PoolingRequestOutput`` objects
      by an ``OutputProcessor``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        aggregate_engine_logging: bool = False,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Build the input/output processors, the engine core client and,
        optionally, the stat-logger manager.

        Args:
            vllm_config: Full engine configuration.
            executor_class: Executor implementation passed to the engine core.
            log_stats: When True, a ``StatLoggerManager`` is created and
                per-step statistics are recorded in :meth:`step`.
            aggregate_engine_logging: Forwarded to ``StatLoggerManager``.
            usage_context: Accepted for API compatibility; unused here.
            stat_loggers: Extra stat-logger factories for ``StatLoggerManager``.
            mm_registry: Accepted for API compatibility; unused here.
            use_cached_outputs: Accepted for API compatibility; unused here.
            multiprocess_mode: Passed to ``EngineCoreClient.make_client``. When
                False, the in-process engine core's model executor is exposed
                as ``self.model_executor`` and the DP group (if any) is set up
                here rather than by the engine core process.
        """
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        self.log_stats = log_stats

        parallel_config = vllm_config.parallel_config
        executor_backend = parallel_config.distributed_executor_backend

        self.external_launcher_dp = (
            parallel_config.data_parallel_size > 1
            and executor_backend == "external_launcher"
        )
        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        if (
            not multiprocess_mode
            and parallel_config.data_parallel_size > 1
            and not self.external_launcher_dp
        ):
            self.dp_group = parallel_config.stateless_init_dp_group()
        else:
            self.dp_group = None
        # Set by has_unfinished_requests_dp() when this rank is idle but other
        # DP ranks still have work; consumed by step().
        self.should_execute_dummy_batch = False

        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        endpoint = self.observability_config.otlp_traces_endpoint
        if endpoint is not None:
            tracer = init_tracer("vllm.llm_engine", endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                custom_stat_loggers=stat_loggers,
                enable_default_loggers=log_stats,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        if self.external_launcher_dp:
            # If we use DP in external launcher mode, we reuse the
            # existing DP group used for data communication.
            self.dp_group = get_dp_group().cpu_group

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an engine from an existing ``VllmConfig``.

        The executor class is derived from the config, and multiprocess mode
        follows ``envs.VLLM_ENABLE_V1_MULTIPROCESSING``.
        """
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            log_stats=(not disable_log_stats),
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        ``enable_multiprocessing`` is forced to True when
        ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` is set.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=enable_multiprocessing,
        )

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests the output processor still tracks."""
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Return True while this engine (or, with DP, any rank) has work.

        Without a DP group, the engine core is also asked whether any DP
        engines are still running. With a DP group, the local flag is
        aggregated across ranks by :meth:`has_unfinished_requests_dp`.
        """
        has_unfinished = self.output_processor.has_unfinished_requests()
        if self.dp_group is None:
            return has_unfinished or self.engine_core.dp_engines_running()
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Aggregate ``has_unfinished`` across the DP group.

        If this rank has no work but another rank does, schedule a dummy batch
        for the next :meth:`step` so this rank keeps participating.
        """
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished
        )
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """Return ``outputs`` unchanged; no validation is performed here."""
        return outputs

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks supported by the engine core."""
        return self.engine_core.get_supported_tasks()

    def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
        """Abort requests in the output processor and the engine core.

        The ids returned by ``OutputProcessor.abort_requests`` (not necessarily
        the ids passed in) are the ones forwarded to the engine core.
        """
        request_ids = self.output_processor.abort_requests(request_ids, internal)
        self.engine_core.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        prompt_text: str | None = None,
    ) -> None:
        """Queue a request for generation or pooling.

        ``prompt`` may be:

        * an already-built ``EngineCoreRequest``. Its own ``request_id`` wins
          over the ``request_id`` argument, and the processing arguments
          (``params``, ``arrival_time``, ...) are not used.
        * a raw prompt. It is processed by ``InputProcessor.process_inputs``
          with the remaining arguments.

        ``prompt_text`` may only be given together with an
        ``EngineCoreRequest``. For raw prompts it is derived from a ``str``
        prompt or from the ``"prompt"`` key of a mapping prompt.

        When the request's (possibly updated) params are ``SamplingParams``
        with ``n > 1``, the request is fanned out into ``n`` child requests
        sharing one ``ParentRequest``.

        Raises:
            TypeError: If ``request_id`` is not a ``str``.
        """
        # Validate the request_id type.
        if not isinstance(request_id, str):
            raise TypeError(f"request_id must be a string, got {type(request_id)}")

        # Process raw inputs into the request.
        if isinstance(prompt, EngineCoreRequest):
            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "LLMEngine.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )
        else:
            assert prompt_text is None
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
            )
            if isinstance(prompt, str):
                prompt_text = prompt
            elif isinstance(prompt, Mapping):
                prompt_text = cast(str | None, prompt.get("prompt"))

        self.input_processor.assign_request_id(request)

        # Use cloned params that may have been updated in process_inputs()
        params = request.params

        n = params.n if isinstance(params, SamplingParams) else 1

        if n == 1:
            # Make a new RequestState and queue.
            self.output_processor.add_request(request, prompt_text, None, 0)
            # Add the request to EngineCore.
            self.engine_core.add_request(request)
            return

        # Fan out child requests (for n>1).
        parent_req = ParentRequest(request)
        for idx in range(n):
            request_id, child_params = parent_req.get_child_info(idx)
            # The last child reuses the original request object; earlier
            # children get shallow copies so their ids/params can differ.
            child_request = request if idx == n - 1 else copy(request)
            child_request.request_id = request_id
            child_request.sampling_params = child_params

            # Make a new RequestState and queue.
            self.output_processor.add_request(
                child_request, prompt_text, parent_req, idx
            )
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)

    def step(self) -> list[RequestOutput | PoolingRequestOutput]:
        """Run one engine iteration and return the outputs it produced.

        If a previous :meth:`has_unfinished_requests` call found this DP rank
        idle while others still had work, a dummy batch is executed instead
        and an empty list is returned.
        """
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        with record_function_or_nullcontext("llm_engine step: get_output"):
            outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        with record_function_or_nullcontext("llm_engine step: process_outputs"):
            iteration_stats = IterationStats() if self.log_stats else None
            processed_outputs = self.output_processor.process_outputs(
                outputs.outputs,
                engine_core_timestamp=outputs.timestamp,
                iteration_stats=iteration_stats,
            )
            self.output_processor.update_scheduler_stats(outputs.scheduler_stats)

        # 3) Abort any reqs that finished due to stop strings.
        with record_function_or_nullcontext("llm_engine step: abort_requests"):
            self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        # 4) Record stats
        with record_function_or_nullcontext("llm_engine step: record_stats"):
            if self.logger_manager is not None and outputs.scheduler_stats is not None:
                self.logger_manager.record(
                    scheduler_stats=outputs.scheduler_stats,
                    iteration_stats=iteration_stats,
                    mm_cache_stats=self.input_processor.stat_mm_cache(),
                )
                self.do_log_stats_with_interval()

        return processed_outputs.request_outputs

    def start_profile(self):
        """Ask the engine core to start profiling."""
        self.engine_core.profile(True)

    def stop_profile(self):
        """Ask the engine core to stop profiling."""
        self.engine_core.profile(False)

    def reset_mm_cache(self):
        """Clear the multimodal cache on both the input side and engine core."""
        self.input_processor.clear_mm_cache()
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine core's prefix cache and return its result flag."""
        return self.engine_core.reset_prefix_cache(
            reset_running_requests, reset_connector
        )

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache to invalidate all cached encoder outputs.

        This should be called when model weights are updated to ensure
        stale vision embeddings computed with old weights are not reused.
        """
        self.engine_core.reset_encoder_cache()

    def sleep(self, level: int = 1):
        """Put the engine core to sleep at ``level`` and record it in stats."""
        self.engine_core.sleep(level)

        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(1, level)

    def wake_up(self, tags: list[str] | None = None):
        """Wake the engine core (optionally only ``tags``) and record it."""
        self.engine_core.wake_up(tags)

        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

    def is_sleeping(self) -> bool:
        """Return whether the engine core reports itself as sleeping."""
        return self.engine_core.is_sleeping()

    def get_metrics(self) -> list[Metric]:
        """Return a metrics snapshot; only valid when stat logging is on."""
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The input processor's tokenizer (may be None)."""
        return self.input_processor.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer via ``InputProcessor.get_tokenizer``."""
        return self.input_processor.get_tokenizer()

    @property
    def renderer(self) -> BaseRenderer:
        """The input processor's renderer."""
        return self.input_processor.renderer

    def do_log_stats(self) -> None:
        """Log stats if logging is enabled."""
        if self.logger_manager:
            self.logger_manager.log()

    def do_log_stats_with_interval(self) -> None:
        """Log stats when the time interval has passed.

        The interval is ``envs.VLLM_LOG_STATS_INTERVAL`` seconds. The first
        call only starts the timer, so nothing is logged then unless the
        interval is non-positive. A monotonic clock is used so that wall-clock
        adjustments (NTP, DST, manual changes) cannot suppress or trigger
        logging.
        """
        now = time.monotonic()
        if not hasattr(self, "_last_log_time"):
            self._last_log_time = now
        if now - self._last_log_time >= envs.VLLM_LOG_STATS_INTERVAL:
            self.do_log_stats()
            self._last_log_time = now

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)

    def collective_rpc(
        self,
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Invoke ``method`` on the workers via the engine core."""
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Run ``func`` through the workers' ``apply_model`` RPC."""
        return self.collective_rpc("apply_model", args=(func,))

    def __del__(self):
        # Only tear down a DP group this engine created itself. In external
        # launcher mode the group comes from get_dp_group() and is reused, not
        # owned. getattr keeps the finalizer safe if __init__ did not finish.
        dp_group = getattr(self, "dp_group", None)
        if dp_group is not None and not getattr(self, "external_launcher_dp", False):
            stateless_destroy_torch_distributed_process_group(dp_group)