llm_engine.py 15.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Callable, Mapping
6
from copy import copy
7
from typing import Any
8

9
import torch.nn as nn
10
from typing_extensions import TypeVar
11

12
import vllm.envs as envs
13
from vllm.config import ParallelConfig, VllmConfig
14
from vllm.distributed import stateless_destroy_torch_distributed_process_group
15
from vllm.distributed.parallel_state import get_dp_group
16
from vllm.engine.arg_utils import EngineArgs
17
from vllm.inputs import PromptType
18
19
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
20
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
21
from vllm.outputs import PoolingRequestOutput, RequestOutput
22
from vllm.plugins.io_processors import get_io_processor
23
from vllm.pooling_params import PoolingParams
24
from vllm.renderers import BaseRenderer
25
from vllm.sampling_params import SamplingParams
26
from vllm.tasks import SupportedTask
27
from vllm.tokenizers import TokenizerLike
28
from vllm.tracing import init_tracer
29
from vllm.usage.usage_lib import UsageContext
30
from vllm.v1.engine import EngineCoreRequest
31
from vllm.v1.engine.core_client import EngineCoreClient
32
from vllm.v1.engine.input_processor import InputProcessor
33
from vllm.v1.engine.output_processor import OutputProcessor
34
from vllm.v1.engine.parallel_sampling import ParentRequest
35
from vllm.v1.engine.utils import get_prompt_text
36
from vllm.v1.executor import Executor
37
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
38
39
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
40
from vllm.v1.utils import record_function_or_nullcontext
41
from vllm.v1.worker.worker_base import WorkerBase
42
43
44

# Result type variable for `collective_rpc` / `apply_model`; when a caller
# leaves it unparameterized it falls back to `Any` (PEP 696 default).
_R = TypeVar("_R", default=Any)

# Module-level logger for this engine front-end.
logger = init_logger(__name__)
46

47
48

class LLMEngine:
    """Legacy LLMEngine for backwards compatibility.

    Synchronous front-end that ties together an ``InputProcessor`` (raw
    prompts -> ``EngineCoreRequest``), an ``EngineCoreClient`` created with
    ``asyncio_mode=False``, and an ``OutputProcessor``
    (``EngineCoreOutputs`` -> ``RequestOutput``).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        aggregate_engine_logging: bool = False,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Build the input/output processors, engine-core client and loggers.

        Args:
            vllm_config: Full engine configuration. Its observability, model,
                cache, parallel and scheduler sections are read here.
            executor_class: Executor type forwarded to
                ``EngineCoreClient.make_client``.
            log_stats: When True, per-iteration stats are collected in
                ``step()`` and a ``StatLoggerManager`` is created.
            aggregate_engine_logging: Forwarded to ``StatLoggerManager``.
            usage_context: Accepted but not read by this constructor.
            stat_loggers: Custom stat-logger factories forwarded to
                ``StatLoggerManager``.
            mm_registry: Accepted but not read by this constructor.
            use_cached_outputs: Accepted but not read by this constructor.
            multiprocess_mode: Forwarded to ``make_client``. When False, this
                constructor may also create the data-parallel group itself
                and exposes ``self.model_executor``.
        """
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        self.log_stats = log_stats

        parallel_config = vllm_config.parallel_config
        executor_backend = parallel_config.distributed_executor_backend

        # With the external launcher, the DP group already exists and is
        # reused below instead of being created (and later destroyed) here.
        self.external_launcher_dp = (
            parallel_config.data_parallel_size > 1
            and executor_backend == "external_launcher"
        )
        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        if (
            not multiprocess_mode
            and parallel_config.data_parallel_size > 1
            and not self.external_launcher_dp
        ):
            self.dp_group = parallel_config.stateless_init_dp_group()
        else:
            self.dp_group = None
        # Set by has_unfinished_requests_dp(); consumed by step().
        self.should_execute_dummy_batch = False

        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        # `self.tokenizer` is a property backed by the input processor,
        # which is why the input processor is created first.
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        endpoint = self.observability_config.otlp_traces_endpoint
        if endpoint is not None:
            tracer = init_tracer("vllm.llm_engine", endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                custom_stat_loggers=stat_loggers,
                enable_default_loggers=log_stats,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        if self.external_launcher_dp:
            # If we use DP in external launcher mode, we reuse the
            # existing DP group used for data communication.
            self.dp_group = get_dp_group().cpu_group

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an engine from an already-built ``VllmConfig``.

        The executor class is derived from the config, and multiprocess mode
        follows ``envs.VLLM_ENABLE_V1_MULTIPROCESSING``.
        """
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            log_stats=(not disable_log_stats),
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        ``enable_multiprocessing`` is forced to True when
        ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` is set.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=enable_multiprocessing,
        )

    def get_num_unfinished_requests(self) -> int:
        """Return how many requests the output processor still tracks."""
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Return whether there is still work for ``step()`` to do.

        Without a DP group, this also reports True while
        ``engine_core.dp_engines_running()`` does. With a DP group, the local
        flag is aggregated across the group.
        """
        has_unfinished = self.output_processor.has_unfinished_requests()
        if self.dp_group is None:
            return has_unfinished or self.engine_core.dp_engines_running()
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Aggregate ``has_unfinished`` across the DP group.

        If this rank has nothing left but another rank does, the next
        ``step()`` is told to run a dummy batch instead of fetching outputs.
        """
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished
        )
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """Return ``outputs`` unchanged; ``output_type`` is ignored."""
        return outputs

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks reported by the engine core."""
        return self.engine_core.get_supported_tasks()

    def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
        """Remove request_ids from the OutputProcessor and EngineCore.

        The ids returned by ``output_processor.abort_requests`` are the ones
        forwarded to the engine core.
        """
        request_ids = self.output_processor.abort_requests(request_ids, internal)
        self.engine_core.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        prompt_text: str | None = None,
    ) -> None:
        """Register a request with the output processor and the engine core.

        Args:
            request_id: Caller-chosen id. Ignored (with a one-time warning)
                if ``prompt`` is an ``EngineCoreRequest`` with a different id.
            prompt: Either a raw prompt, which is run through
                ``input_processor.process_inputs``, or a pre-built
                ``EngineCoreRequest`` used as-is.
            params: Sampling or pooling parameters. For raw prompts, the
                (possibly updated) copy stored on the processed request is
                what is actually used.
            prompt_text: Only valid together with a pre-built
                ``EngineCoreRequest``. For raw prompts it is derived via
                ``get_prompt_text``.

        Raises:
            TypeError: If ``request_id`` is not a ``str``.

        With ``SamplingParams.n > 1``, the request is fanned out into ``n``
        child requests managed by a shared ``ParentRequest``.
        """
        # Validate the request_id type.
        if not isinstance(request_id, str):
            raise TypeError(f"request_id must be a string, got {type(request_id)}")

        # Process raw inputs into the request.
        if isinstance(prompt, EngineCoreRequest):
            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "LLMEngine.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )
        else:
            assert prompt_text is None
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
            )
            prompt_text = get_prompt_text(prompt)

        self.input_processor.assign_request_id(request)

        # Use cloned params that may have been updated in process_inputs()
        params = request.params

        n = params.n if isinstance(params, SamplingParams) else 1

        if n == 1:
            # Make a new RequestState and queue.
            self.output_processor.add_request(request, prompt_text, None, 0)
            # Add the request to EngineCore.
            self.engine_core.add_request(request)
            return

        # Fan out child requests (for n>1).
        parent_req = ParentRequest(request)
        for idx in range(n):
            request_id, child_params = parent_req.get_child_info(idx)
            # The last child reuses the original request object; the others
            # get a shallow copy so their ids/params can be set independently.
            child_request = request if idx == n - 1 else copy(request)
            child_request.request_id = request_id
            child_request.sampling_params = child_params

            # Make a new RequestState and queue.
            self.output_processor.add_request(
                child_request, prompt_text, parent_req, idx
            )
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)

    def step(self) -> list[RequestOutput | PoolingRequestOutput]:
        """Run one engine iteration and return the outputs it produced.

        If a dummy batch was scheduled by ``has_unfinished_requests_dp``, it
        is executed instead and an empty list is returned. Otherwise, this
        fetches outputs from the engine core and processes them, aborts
        requests that finished on stop strings, and records stats when a
        stat logger manager exists.
        """
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        with record_function_or_nullcontext("llm_engine step: get_output"):
            outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        with record_function_or_nullcontext("llm_engine step: process_outputs"):
            iteration_stats = IterationStats() if self.log_stats else None
            processed_outputs = self.output_processor.process_outputs(
                outputs.outputs,
                engine_core_timestamp=outputs.timestamp,
                iteration_stats=iteration_stats,
            )
            self.output_processor.update_scheduler_stats(outputs.scheduler_stats)

        # 3) Abort any reqs that finished due to stop strings.
        with record_function_or_nullcontext("llm_engine step: abort_requests"):
            self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        # 4) Record stats
        with record_function_or_nullcontext("llm_engine step: record_stats"):
            if self.logger_manager is not None and outputs.scheduler_stats is not None:
                self.logger_manager.record(
                    scheduler_stats=outputs.scheduler_stats,
                    iteration_stats=iteration_stats,
                    mm_cache_stats=self.input_processor.stat_mm_cache(),
                )
                self.do_log_stats_with_interval()

        return processed_outputs.request_outputs

    def start_profile(self):
        """Ask the engine core to start profiling."""
        self.engine_core.profile(True)

    def stop_profile(self):
        """Ask the engine core to stop profiling."""
        self.engine_core.profile(False)

    def reset_mm_cache(self):
        """Clear the multimodal caches on both the input processor and core."""
        self.input_processor.clear_mm_cache()
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine core's prefix cache and return its result."""
        return self.engine_core.reset_prefix_cache(
            reset_running_requests, reset_connector
        )

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache to invalidate all cached encoder outputs.

        This should be called when model weights are updated to ensure
        stale vision embeddings computed with old weights are not reused.
        """
        self.engine_core.reset_encoder_cache()

    def sleep(self, level: int = 1):
        """Put the engine core to sleep at ``level``.

        When stat logging is enabled, the sleep state ``(1, level)`` is also
        recorded.
        """
        self.engine_core.sleep(level)

        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(1, level)

    def wake_up(self, tags: list[str] | None = None):
        """Wake the engine core (optionally only ``tags``).

        When stat logging is enabled, the sleep state ``(0, 0)`` is also
        recorded.
        """
        self.engine_core.wake_up(tags)

        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

    def is_sleeping(self) -> bool:
        """Return whether the engine core reports being asleep."""
        return self.engine_core.is_sleeping()

    def get_metrics(self) -> list[Metric]:
        """Return a snapshot of the current metrics.

        Requires ``log_stats``. NOTE(review): this is enforced with
        ``assert``, which is stripped under ``python -O``.
        """
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Tokenizer held by the input processor (may be ``None``)."""
        return self.input_processor.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer via ``input_processor.get_tokenizer()``."""
        return self.input_processor.get_tokenizer()

    @property
    def renderer(self) -> BaseRenderer:
        """Renderer held by the input processor."""
        return self.input_processor.renderer

    def do_log_stats(self) -> None:
        """Log stats if logging is enabled."""
        if self.logger_manager:
            self.logger_manager.log()

    def do_log_stats_with_interval(self) -> None:
        """Log stats when the time interval has passed.

        The first call only starts the timer. Later calls log once at least
        ``envs.VLLM_LOG_STATS_INTERVAL`` seconds have elapsed since the
        previous log.
        """
        # A monotonic clock is immune to wall-clock adjustments (e.g. NTP
        # stepping the clock backwards), which would otherwise stall or
        # burst the periodic logging when measured with time.time().
        now = time.monotonic()
        last = getattr(self, "_last_log_time", None)
        if last is None:
            self._last_log_time = last = now
        if now - last >= envs.VLLM_LOG_STATS_INTERVAL:
            self.do_log_stats()
            self._last_log_time = now

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)

    def collective_rpc(
        self,
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Invoke ``method`` on the workers through the engine core.

        ``method`` is either a method name or a callable taking a
        ``WorkerBase``. The list of results is returned as given by
        ``engine_core.collective_rpc``.
        """
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Run ``func`` through the workers' ``apply_model`` RPC."""
        return self.collective_rpc("apply_model", args=(func,))

    def __del__(self):
        """Destroy the DP process group this engine created, if any.

        A group borrowed from the external launcher (``get_dp_group()``) is
        left alone. ``getattr`` is used because ``__init__`` may have failed
        before these attributes were set.
        """
        dp_group = getattr(self, "dp_group", None)
        if dp_group is not None and not getattr(self, "external_launcher_dp", False):
            stateless_destroy_torch_distributed_process_group(dp_group)