# llm_engine.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Callable, Mapping
6
from copy import copy
7
from typing import Any
8

9
import torch.nn as nn
10
11
from typing_extensions import TypeVar

12
import vllm.envs as envs
13
from vllm.config import ParallelConfig, VllmConfig
14
from vllm.distributed import stateless_destroy_torch_distributed_process_group
15
from vllm.distributed.parallel_state import get_dp_group
16
from vllm.engine.arg_utils import EngineArgs
17
from vllm.engine.protocol import Device
18
from vllm.inputs import PromptType
19
20
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
21
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
22
from vllm.outputs import PoolingRequestOutput, RequestOutput
23
from vllm.plugins.io_processors import get_io_processor
24
from vllm.pooling_params import PoolingParams
25
from vllm.sampling_params import SamplingParams
26
from vllm.tasks import SupportedTask
27
from vllm.tracing import init_tracer
28
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
29
from vllm.usage.usage_lib import UsageContext
30
from vllm.v1.engine import EngineCoreRequest
31
from vllm.v1.engine.core_client import EngineCoreClient
32
from vllm.v1.engine.output_processor import OutputProcessor
33
from vllm.v1.engine.parallel_sampling import ParentRequest
34
from vllm.v1.engine.processor import Processor
35
from vllm.v1.executor import Executor
36
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
37
38
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
39
from vllm.v1.worker.worker_base import WorkerBase
40
41
42

# Module-level logger for the legacy LLMEngine wrapper.
logger = init_logger(__name__)

# Result type of callables dispatched through ``collective_rpc`` /
# ``apply_model``; the TypeVar ``default`` makes it Any when unparameterized.
_R = TypeVar("_R", default=Any)
44

45
46

class LLMEngine:
    """Legacy LLMEngine for backwards compatibility.

    A synchronous facade over the V1 engine pieces:
    ``Processor`` turns prompts into ``EngineCoreRequest``s,
    ``EngineCoreClient`` executes them, and ``OutputProcessor`` converts the
    returned ``EngineCoreOutputs`` into ``RequestOutput`` /
    ``PoolingRequestOutput`` objects.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        aggregate_engine_logging: bool = False,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Build the processor, output processor, engine core and loggers.

        Args:
            vllm_config: Full engine configuration.
            executor_class: Executor class forwarded to the engine core client.
            log_stats: Whether to collect per-iteration stats and create a
                ``StatLoggerManager``.
            aggregate_engine_logging: Forwarded to ``StatLoggerManager``.
            usage_context: Accepted for interface compatibility; unused here.
            stat_loggers: Must be ``None``; custom stat loggers are rejected.
            mm_registry: Accepted for interface compatibility; unused here.
            use_cached_outputs: Accepted for interface compatibility; unused
                here.
            multiprocess_mode: Forwarded to ``EngineCoreClient.make_client``.
                When ``False``, the DP group (if any) is created here and the
                in-process executor is exposed as ``self.model_executor``.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is false.
            NotImplementedError: If ``stat_loggers`` is not ``None``.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github."
            )

        if stat_loggers is not None:
            raise NotImplementedError(
                "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
                "Set VLLM_USE_V1=0 and file an issue on Github."
            )

        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        self.log_stats = log_stats

        parallel_config = vllm_config.parallel_config
        executor_backend = parallel_config.distributed_executor_backend

        self.external_launcher_dp = (
            parallel_config.data_parallel_size > 1
            and executor_backend == "external_launcher"
        )
        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        if (
            not multiprocess_mode
            and parallel_config.data_parallel_size > 1
            and not self.external_launcher_dp
        ):
            self.dp_group = parallel_config.stateless_init_dp_group()
        else:
            self.dp_group = None
        # Set by has_unfinished_requests_dp() when other DP ranks still have
        # work but this rank has none; consumed by the next step() call.
        self.should_execute_dummy_batch = False

        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = init_tokenizer_from_configs(self.model_config)

        self.processor = Processor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(
            self.tokenizer, log_stats=self.log_stats
        )
        if self.observability_config.otlp_traces_endpoint is not None:
            tracer = init_tracer(
                "vllm.llm_engine", self.observability_config.otlp_traces_endpoint
            )
            self.output_processor.tracer = tracer

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                # Always None here: non-None values were rejected above.
                custom_stat_loggers=stat_loggers,
                enable_default_loggers=log_stats,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        if self.external_launcher_dp:
            # If we use DP in external launcher mode, we reuse the
            # existing DP group used for data communication.
            self.dp_group = get_dp_group().cpu_group

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an engine from an already-built ``VllmConfig``.

        The executor class is derived from the config, and multiprocess mode
        follows ``envs.VLLM_ENABLE_V1_MULTIPROCESSING``.
        """
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            log_stats=(not disable_log_stats),
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` forces multiprocessing on,
        regardless of ``enable_multiprocessing``.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=enable_multiprocessing,
        )

    def get_num_unfinished_requests(self) -> int:
        """Number of requests the output processor still tracks."""
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Whether any work remains, locally or on other DP ranks.

        Without a DP group, the engine core is asked whether DP engines are
        still running. With a DP group, the local flag is aggregated across
        ranks (see ``has_unfinished_requests_dp``).
        """
        has_unfinished = self.output_processor.has_unfinished_requests()
        if self.dp_group is None:
            return has_unfinished or self.engine_core.dp_engines_running()
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Aggregate ``has_unfinished`` across the DP group.

        If this rank is idle but another rank still has work, a dummy batch
        is scheduled for the next ``step()``. Presumably this keeps this rank
        participating in collective ops — confirm against EngineCore.
        """
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished
        )
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """No-op: returns ``outputs`` unchanged (``output_type`` is ignored)."""
        return outputs

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Tasks the engine core reports it supports."""
        return self.engine_core.get_supported_tasks()

    def abort_request(self, request_ids: list[str]) -> None:
        """Remove request_ids from EngineCore and Detokenizer."""
        # The output processor returns the ids that must also be aborted in
        # the engine core (e.g. expanded child ids).
        request_ids = self.output_processor.abort_requests(request_ids)
        self.engine_core.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        prompt_text: str | None = None,
    ) -> None:
        """Register a request with the output processor and the engine core.

        ``prompt`` may be a pre-built ``EngineCoreRequest`` (``prompt_text``
        is then used as-is) or a raw prompt that is processed here (in which
        case ``prompt_text`` must be ``None``). Sampling requests with
        ``n > 1`` are fanned out into ``n`` child requests under one
        ``ParentRequest``.

        Raises:
            TypeError: If ``request_id`` is not a ``str``.
        """
        # Validate the request_id type.
        if not isinstance(request_id, str):
            raise TypeError(f"request_id must be a string, got {type(request_id)}")

        # Process raw inputs into the request.
        if isinstance(prompt, EngineCoreRequest):
            request = prompt
        else:
            assert prompt_text is None
            logger.warning_once(
                "Processor has been moved under LLM and will "
                "be removed from LLMEngine in v0.13."
            )
            request = self.processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
            )
            prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt")

        n = params.n if isinstance(params, SamplingParams) else 1

        if n == 1:
            # Make a new RequestState and queue.
            self.output_processor.add_request(request, prompt_text, None, 0)
            # Add the request to EngineCore.
            self.engine_core.add_request(request)
            return

        # Fan out child requests (for n>1).
        parent_req = ParentRequest(request_id, params)
        for idx in range(n):
            child_id, child_params = parent_req.get_child_info(idx)
            # The last child reuses the original request object; earlier
            # children get shallow copies.
            child_request = request if idx == n - 1 else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params

            # Make a new RequestState and queue.
            self.output_processor.add_request(
                child_request, prompt_text, parent_req, idx
            )
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)

    def step(self) -> list[RequestOutput] | list[PoolingRequestOutput]:
        """Run one engine iteration and return the finished/streamed outputs.

        If a DP dummy batch is pending, it is executed instead and ``[]`` is
        returned.
        """
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        iteration_stats = IterationStats() if self.log_stats else None
        processed_outputs = self.output_processor.process_outputs(
            outputs.outputs,
            engine_core_timestamp=outputs.timestamp,
            iteration_stats=iteration_stats,
        )

        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        # 4) Record stats
        if self.logger_manager is not None and outputs.scheduler_stats is not None:
            self.logger_manager.record(
                scheduler_stats=outputs.scheduler_stats,
                iteration_stats=iteration_stats,
                mm_cache_stats=self.processor.stat_mm_cache(),
            )
            self.do_log_stats_with_interval()

        return processed_outputs.request_outputs

    def start_profile(self):
        """Ask the engine core to start profiling."""
        self.engine_core.profile(True)

    def stop_profile(self):
        """Ask the engine core to stop profiling."""
        self.engine_core.profile(False)

    def reset_mm_cache(self):
        """Clear the multimodal caches in both the processor and engine core."""
        self.processor.clear_mm_cache()
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(self, device: Device | None = None):
        """Reset the engine core's prefix cache.

        ``device`` is accepted for interface compatibility but is ignored.
        """
        self.engine_core.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the engine core to sleep at the given level."""
        self.engine_core.sleep(level)

    def wake_up(self, tags: list[str] | None = None):
        """Wake the engine core, optionally restricted to ``tags``."""
        self.engine_core.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Whether the engine core reports that it is sleeping."""
        return self.engine_core.is_sleeping()

    def get_metrics(self) -> list[Metric]:
        """Return a snapshot of metrics; requires ``log_stats``."""
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()

    @property
    def tokenizer(self) -> AnyTokenizer | None:
        """Tokenizer owned by the processor (``None`` if init was skipped)."""
        return self.processor.tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
        self.processor.tokenizer = tokenizer

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer, raising ``ValueError`` if there is none."""
        if self.tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because skip_tokenizer_init is True"
            )

        return self.tokenizer

    def do_log_stats(self) -> None:
        """Log stats if logging is enabled."""
        if self.logger_manager:
            self.logger_manager.log()

    def do_log_stats_with_interval(self) -> None:
        """Log stats when the time interval has passed.

        The interval is measured with a monotonic clock, so wall-clock
        adjustments cannot stall or burst logging. The first call only starts
        the timer.
        """
        now = time.monotonic()
        if not hasattr(self, "_last_log_time"):
            self._last_log_time = now
        if now - self._last_log_time >= envs.VLLM_LOG_STATS_INTERVAL:
            self.do_log_stats()
            self._last_log_time = now

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)

    def collective_rpc(
        self,
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Invoke ``method`` on the workers via the engine core."""
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Run ``func`` through the workers' ``apply_model`` RPC."""
        return self.collective_rpc("apply_model", args=(func,))

    def __del__(self):
        # Bind the group first and test it explicitly. The previous form
        # `dp_group := getattr(...) and not ...` bound the *boolean* result,
        # because `:=` has the lowest precedence. That passed True, not the
        # process group, to the destroy call. getattr also covers instances
        # whose __init__ raised before these attributes were set.
        dp_group = getattr(self, "dp_group", None)
        if dp_group is not None and not getattr(
            self, "external_launcher_dp", False
        ):
            # In external-launcher mode the group belongs to the global DP
            # group and must not be destroyed here.
            stateless_destroy_torch_distributed_process_group(dp_group)