llm_engine.py 15.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Callable, Mapping
6
from copy import copy
7
from typing import Any, cast
8

9
import torch.nn as nn
10
from typing_extensions import TypeVar
11

12
import vllm.envs as envs
13
from vllm.config import ParallelConfig, VllmConfig
14
from vllm.distributed import stateless_destroy_torch_distributed_process_group
15
from vllm.distributed.parallel_state import get_dp_group
16
from vllm.engine.arg_utils import EngineArgs
17
from vllm.inputs import PromptType
18
19
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
20
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
21
from vllm.outputs import PoolingRequestOutput, RequestOutput
22
from vllm.plugins.io_processors import get_io_processor
23
from vllm.pooling_params import PoolingParams
24
from vllm.sampling_params import SamplingParams
25
from vllm.tasks import SupportedTask
26
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
27
from vllm.tracing import init_tracer
28
from vllm.usage.usage_lib import UsageContext
29
from vllm.v1.engine import EngineCoreRequest
30
from vllm.v1.engine.core_client import EngineCoreClient
31
from vllm.v1.engine.input_processor import InputProcessor
32
from vllm.v1.engine.output_processor import OutputProcessor
33
from vllm.v1.engine.parallel_sampling import ParentRequest
34
from vllm.v1.executor import Executor
35
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
36
37
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
38
from vllm.v1.utils import record_function_or_nullcontext
39
from vllm.v1.worker.worker_base import WorkerBase
40
41
42

# Module-level logger for this file, created through vLLM's init_logger helper.
logger = init_logger(__name__)

43
# Generic result type used in the signatures of collective_rpc() and
# apply_model(); falls back to Any when no type is supplied.
_R = TypeVar("_R", default=Any)
44

45
46

class LLMEngine:
47
    """Legacy LLMEngine for backwards compatibility."""
48
49
50

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        aggregate_engine_logging: bool = False,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Wire up the processors, the engine core client and stat logging.

        Args:
            vllm_config: Combined engine configuration; its observability,
                model and cache sub-configs are also stored on the instance.
            executor_class: Executor type handed to the engine core client.
            log_stats: Whether to create a ``StatLoggerManager``.
            aggregate_engine_logging: Forwarded to ``StatLoggerManager``.
            usage_context: Accepted but not read in this constructor.
            stat_loggers: Custom stat logger factories forwarded to
                ``StatLoggerManager``.
            mm_registry: Accepted but not read in this constructor.
            use_cached_outputs: Accepted but not read in this constructor.
            multiprocess_mode: Forwarded to ``EngineCoreClient.make_client``.
                When false, a stateless DP group may be created here and
                ``model_executor`` is exposed on the instance.
        """
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        self.log_stats = log_stats

        parallel_cfg = vllm_config.parallel_config
        backend = parallel_cfg.distributed_executor_backend

        uses_dp = parallel_cfg.data_parallel_size > 1
        self.external_launcher_dp = uses_dp and backend == "external_launcher"

        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        needs_stateless_group = (
            not multiprocess_mode and uses_dp and not self.external_launcher_dp
        )
        self.dp_group = (
            parallel_cfg.stateless_init_dp_group() if needs_stateless_group else None
        )
        self.should_execute_dummy_batch = False

        tokenizer = cached_tokenizer_from_config(self.model_config)

        self.input_processor = InputProcessor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        # self.tokenizer is the property backed by self.input_processor.
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        otlp_endpoint = self.observability_config.otlp_traces_endpoint
        if otlp_endpoint is not None:
            self.output_processor.tracer = init_tracer(
                "vllm.llm_engine", otlp_endpoint
            )

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            manager = StatLoggerManager(
                vllm_config=vllm_config,
                custom_stat_loggers=stat_loggers,
                enable_default_loggers=log_stats,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager = manager
            manager.log_engine_initialized()

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        if self.external_launcher_dp:
            # If we use DP in external launcher mode, we reuse the
            # existing DP group used for data communication.
            self.dp_group = get_dp_group().cpu_group

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

137
138
139
140
141
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Build an engine from an already-constructed ``VllmConfig``.

        The executor class is looked up from the config via
        ``Executor.get_class`` and multiprocessing follows
        ``envs.VLLM_ENABLE_V1_MULTIPROCESSING``.
        """
        executor_cls = Executor.get_class(vllm_config)
        log_stats = not disable_log_stats
        use_multiprocessing = envs.VLLM_ENABLE_V1_MULTIPROCESSING
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            log_stats=log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=use_multiprocessing,
        )
153

154
155
156
157
158
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        ``enable_multiprocessing`` is forced to ``True`` whenever
        ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` is set, regardless of the
        value passed by the caller.
        """
        # Derive the engine config and the matching executor class.
        config = engine_args.create_engine_config(usage_context)
        executor_cls = Executor.get_class(config)

        use_multiprocessing = enable_multiprocessing
        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            use_multiprocessing = True

        # Instantiate the engine with the resolved settings.
        log_stats = not engine_args.disable_log_stats
        return cls(
            vllm_config=config,
            executor_class=executor_cls,
            log_stats=log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=use_multiprocessing,
        )
181
182

    def get_num_unfinished_requests(self) -> int:
        """Return the unfinished-request count reported by the output processor."""
        pending = self.output_processor.get_num_unfinished_requests()
        return pending
184
185

    def has_unfinished_requests(self) -> bool:
        """Report whether there is still work outstanding.

        With a DP group set, the decision is delegated to
        ``has_unfinished_requests_dp``. Otherwise the output processor's
        answer is combined with ``engine_core.dp_engines_running()``.
        """
        local_pending = self.output_processor.has_unfinished_requests()
        if self.dp_group is not None:
            return self.has_unfinished_requests_dp(local_pending)
        return local_pending or self.engine_core.dp_engines_running()

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Combine the local unfinished flag across the DP group.

        Returns the result of ``ParallelConfig.has_unfinished_dp``. When this
        engine has nothing unfinished but the combined result is truthy,
        ``should_execute_dummy_batch`` is set so that the next ``step()``
        runs a dummy batch instead of fetching outputs.
        """
        any_pending = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished
        )
        idle_locally = not has_unfinished
        if idle_locally and any_pending:
            self.should_execute_dummy_batch = True
        return any_pending
198
199
200
201
202

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """Return ``outputs`` unchanged; ``output_type`` is not inspected."""
        return outputs

203
204
205
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the supported tasks as reported by the engine core client."""
        tasks = self.engine_core.get_supported_tasks()
        return tasks

206
    def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
        """Abort requests in the output processor, then in the engine core.

        The value returned by the output processor's ``abort_requests`` (not
        the caller's original list) is what gets forwarded to the engine core.
        """
        ids_to_abort = self.output_processor.abort_requests(request_ids, internal)
        self.engine_core.abort_requests(ids_to_abort)

212
213
214
    def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        prompt_text: str | None = None,
    ) -> None:
        """Queue a request with the output processor and the engine core.

        Args:
            request_id: Caller-supplied identifier; must be a ``str``.
            prompt: Either a ready-made ``EngineCoreRequest`` (used as-is) or
                a raw prompt that is run through
                ``input_processor.process_inputs``.
            params: Sampling or pooling parameters. Only forwarded for raw
                prompts; afterwards ``request.params`` is used instead.
            arrival_time: Forwarded to ``process_inputs`` for raw prompts.
            lora_request: Forwarded to ``process_inputs`` for raw prompts.
            tokenization_kwargs: Forwarded to ``process_inputs`` for raw
                prompts.
            trace_headers: Forwarded to ``process_inputs`` for raw prompts.
            priority: Forwarded to ``process_inputs`` for raw prompts.
            prompt_text: Prompt text passed on to the output processor. May
                only be supplied together with an ``EngineCoreRequest``; for
                raw prompts it is derived from the prompt itself.

        Raises:
            TypeError: If ``request_id`` is not a string.
            AssertionError: If ``prompt_text`` is given with a raw prompt.
        """
        # Validate the request_id type.
        if not isinstance(request_id, str):
            raise TypeError(f"request_id must be a string, got {type(request_id)}")

        # Process raw inputs into the request.
        if isinstance(prompt, EngineCoreRequest):
            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "LLMEngine.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )
        else:
            assert prompt_text is None, (
                "prompt_text may only be passed together with an EngineCoreRequest"
            )
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
            )
            # Recover the textual prompt from a plain string or from the
            # "prompt" entry of a mapping-style prompt.
            if isinstance(prompt, str):
                prompt_text = prompt
            elif isinstance(prompt, Mapping):
                prompt_text = cast(str | None, prompt.get("prompt"))

        self.input_processor.assign_request_id(request)

        # Use cloned params that may have been updated in process_inputs()
        params = request.params

        n = params.n if isinstance(params, SamplingParams) else 1

        if n == 1:
            # Make a new RequestState and queue.
            self.output_processor.add_request(request, prompt_text, None, 0)
            # Add the request to EngineCore.
            self.engine_core.add_request(request)
            return

        # Fan out child requests (for n>1).
        parent_req = ParentRequest(request)
        for idx in range(n):
            # Use a distinct local name so the request_id parameter is not
            # silently rebound inside the loop.
            child_request_id, child_params = parent_req.get_child_info(idx)
            # The final child reuses the original request object; the others
            # get a shallow copy so each can carry its own id and params.
            child_request = request if idx == n - 1 else copy(request)
            child_request.request_id = child_request_id
            child_request.sampling_params = child_params

            # Make a new RequestState and queue.
            self.output_processor.add_request(
                child_request, prompt_text, parent_req, idx
            )
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)
282

283
    def step(self) -> list[RequestOutput | PoolingRequestOutput]:
        """Run one engine iteration and return the processed request outputs.

        If a dummy batch was flagged by ``has_unfinished_requests_dp``, it is
        executed instead and an empty list is returned.
        """
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        with record_function_or_nullcontext("llm_engine step: get_output"):
            core_outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        with record_function_or_nullcontext("llm_engine step: process_outputs"):
            if self.log_stats:
                iteration_stats = IterationStats()
            else:
                iteration_stats = None
            processed = self.output_processor.process_outputs(
                core_outputs.outputs,
                engine_core_timestamp=core_outputs.timestamp,
                iteration_stats=iteration_stats,
            )
            scheduler_stats = core_outputs.scheduler_stats
            self.output_processor.update_scheduler_stats(scheduler_stats)

        # 3) Abort any reqs that finished due to stop strings.
        with record_function_or_nullcontext("llm_engine step: abort_requests"):
            self.engine_core.abort_requests(processed.reqs_to_abort)

        # 4) Record stats
        with record_function_or_nullcontext("llm_engine step: record_stats"):
            manager = self.logger_manager
            if manager is not None and scheduler_stats is not None:
                manager.record(
                    scheduler_stats=scheduler_stats,
                    iteration_stats=iteration_stats,
                    mm_cache_stats=self.input_processor.stat_mm_cache(),
                )
                self.do_log_stats_with_interval()

        return processed.request_outputs
318

319
    def start_profile(self):
        """Ask the engine core to turn profiling on."""
        core = self.engine_core
        core.profile(True)
321

322
    def stop_profile(self):
        """Ask the engine core to turn profiling off."""
        core = self.engine_core
        core.profile(False)
324

325
    def reset_mm_cache(self):
        """Clear the input processor's multimodal cache, then the engine core's."""
        # Input-processor side first, matching the order used at construction.
        processor = self.input_processor
        processor.clear_mm_cache()
        # Then the engine core side.
        core = self.engine_core
        core.reset_mm_cache()

329
330
331
332
333
334
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Forward both reset flags to the engine core and return its result."""
        outcome = self.engine_core.reset_prefix_cache(
            reset_running_requests, reset_connector
        )
        return outcome
335

336
337
338
    def sleep(self, level: int = 1):
        """Put the engine core to sleep at ``level`` and record the new state.

        The sleep state is only recorded when a logger manager exists.
        """
        self.engine_core.sleep(level)

        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

342
    def wake_up(self, tags: list[str] | None = None):
        """Wake the engine core (passing ``tags`` through) and record the state.

        The awake state ``(0, 0)`` is only recorded when a logger manager
        exists.
        """
        self.engine_core.wake_up(tags)

        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(0, 0)

348
349
350
    def is_sleeping(self) -> bool:
        """Return the sleeping flag reported by the engine core."""
        asleep = self.engine_core.is_sleeping()
        return asleep

351
352
353
354
    def get_metrics(self) -> list[Metric]:
        """Return a metrics snapshot; asserts that stat logging is enabled."""
        assert self.log_stats, "Stat logging disabled"
        snapshot = get_metrics_snapshot()
        return snapshot

355
    @property
    def tokenizer(self) -> TokenizerLike | None:
        """Tokenizer held by the input processor (may be ``None``)."""
        processor = self.input_processor
        return processor.tokenizer
358

359
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer, raising if none is available.

        Raises:
            ValueError: If the ``tokenizer`` property is ``None``.
        """
        tok = self.tokenizer
        if tok is not None:
            return tok
        raise ValueError(
            "Unable to get tokenizer because `skip_tokenizer_init=True`"
        )
366

367
368
369
370
371
372
373
374
375
376
377
378
379
380
    def do_log_stats(self) -> None:
        """Log stats if logging is enabled.

        Does nothing when no logger manager was created (``log_stats`` was
        false at construction time).
        """
        # Explicit None check, consistent with step(), sleep() and wake_up();
        # avoids depending on the manager object's truthiness.
        if self.logger_manager is not None:
            self.logger_manager.log()

    def do_log_stats_with_interval(self) -> None:
        """Log stats when the time interval has passed.

        The interval is ``envs.VLLM_LOG_STATS_INTERVAL`` seconds, measured
        from the first call or from the last time stats were logged. Elapsed
        time comes from a monotonic clock, so system clock adjustments can
        neither suppress nor trigger a log.
        """
        now = time.monotonic()
        last = getattr(self, "_last_log_time", None)
        if last is None:
            # First call: start the timer now.
            last = self._last_log_time = now
        if now - last >= envs.VLLM_LOG_STATS_INTERVAL:
            self.do_log_stats()
            self._last_log_time = now

381
382
383
384
385
386
387
388
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        added = self.engine_core.add_lora(lora_request)
        return added

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        removed = self.engine_core.remove_lora(lora_id)
        return removed

389
    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        adapter_ids = self.engine_core.list_loras()
        return adapter_ids

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        pinned = self.engine_core.pin_lora(lora_id)
        return pinned
396

397
398
    def collective_rpc(
        self,
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Forward a collective RPC to the engine core and return its results.

        All four arguments are passed through positionally, unchanged.
        """
        results = self.engine_core.collective_rpc(method, timeout, args, kwargs)
        return results

406
    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Invoke the ``"apply_model"`` collective RPC with ``func`` as its argument."""
        rpc_args = (func,)
        return self.collective_rpc("apply_model", args=rpc_args)
408

409
    def __del__(self):
        """Destroy the DP process group held by this engine, if any.

        ``dp_group`` is read with ``getattr`` so a partially constructed
        instance is tolerated. Nothing is destroyed when
        ``external_launcher_dp`` is set.
        """
        group = getattr(self, "dp_group", None)
        if group is None:
            return
        if self.external_launcher_dp:
            return
        stateless_destroy_torch_distributed_process_group(group)