llm_engine.py 14.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Callable, Mapping
6
from copy import copy
7
from typing import Any, cast
8

9
import torch.nn as nn
10
11
from typing_extensions import TypeVar

12
import vllm.envs as envs
13
from vllm.config import ParallelConfig, VllmConfig
14
from vllm.distributed import stateless_destroy_torch_distributed_process_group
15
from vllm.distributed.parallel_state import get_dp_group
16
from vllm.engine.arg_utils import EngineArgs
17
from vllm.engine.protocol import Device
18
from vllm.inputs import PromptType
19
20
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
21
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
22
from vllm.outputs import PoolingRequestOutput, RequestOutput
23
from vllm.plugins.io_processors import get_io_processor
24
from vllm.pooling_params import PoolingParams
25
from vllm.sampling_params import SamplingParams
26
from vllm.tasks import SupportedTask
27
from vllm.tracing import init_tracer
28
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
29
from vllm.usage.usage_lib import UsageContext
30
from vllm.v1.engine import EngineCoreRequest
31
from vllm.v1.engine.core_client import EngineCoreClient
32
from vllm.v1.engine.output_processor import OutputProcessor
33
from vllm.v1.engine.parallel_sampling import ParentRequest
34
from vllm.v1.engine.processor import Processor
35
from vllm.v1.executor import Executor
36
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
37
38
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
39
from vllm.v1.worker.worker_base import WorkerBase
40
41
42

# Logger for this module, created from the module's `__name__`.
logger = init_logger(__name__)

43
# Generic result type used in the signatures of `collective_rpc` and
# `apply_model`; declared with `Any` as its default.
_R = TypeVar("_R", default=Any)
44

45
46

class LLMEngine:
47
    """Legacy LLMEngine for backwards compatibility."""
48
49
50

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        aggregate_engine_logging: bool = False,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Wire up the input processor, output processor and engine core client.

        Args:
            vllm_config: Provides the observability, model, cache and
                parallel sub-configs that are stored on / read by the engine.
            executor_class: Executor type handed to the engine core client.
            log_stats: When true, a `StatLoggerManager` is created.
            aggregate_engine_logging: Passed through to `StatLoggerManager`.
            usage_context: Not read by this constructor body.
            stat_loggers: Custom stat logger factories passed to
                `StatLoggerManager`.
            mm_registry: Not read by this constructor body.
            use_cached_outputs: Not read by this constructor body.
            multiprocess_mode: Forwarded to `EngineCoreClient.make_client`;
                when false, the engine core's model executor is also exposed
                as `self.model_executor`.
        """
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.log_stats = log_stats

        parallel_config = vllm_config.parallel_config
        executor_backend = parallel_config.distributed_executor_backend
        uses_dp = parallel_config.data_parallel_size > 1

        self.external_launcher_dp = (
            uses_dp and executor_backend == "external_launcher"
        )

        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        owns_dp_group = (
            not multiprocess_mode and uses_dp and not self.external_launcher_dp
        )
        self.dp_group = (
            parallel_config.stateless_init_dp_group() if owns_dp_group else None
        )
        self.should_execute_dummy_batch = False

        # The tokenizer is optional: it is skipped when the model config asks
        # for that.
        tokenizer = (
            None
            if self.model_config.skip_tokenizer_init
            else init_tokenizer_from_configs(self.model_config)
        )
        self.processor = Processor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        # `self.tokenizer` is the property that reads `self.processor.tokenizer`.
        self.output_processor = OutputProcessor(
            self.tokenizer, log_stats=self.log_stats
        )
        traces_endpoint = self.observability_config.otlp_traces_endpoint
        if traces_endpoint is not None:
            self.output_processor.tracer = init_tracer(
                "vllm.llm_engine", traces_endpoint
            )

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                custom_stat_loggers=stat_loggers,
                enable_default_loggers=log_stats,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        if self.external_launcher_dp:
            # If we use DP in external launcher mode, we reuse the
            # existing DP group used for data communication.
            self.dp_group = get_dp_group().cpu_group

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

137
138
139
140
141
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Build an engine from an already-constructed `VllmConfig`.

        The executor class is resolved with `Executor.get_class`, stat logging
        is enabled unless `disable_log_stats` is set, and multiprocess mode
        follows `envs.VLLM_ENABLE_V1_MULTIPROCESSING`.
        """
        executor_cls = Executor.get_class(vllm_config)
        stats_enabled = not disable_log_stats
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_cls,
            log_stats=stats_enabled,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING,
        )
153

154
155
156
157
158
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Build an engine from `EngineArgs`.

        The engine config is created from `engine_args`, and the executor
        class is derived from that config. `enable_multiprocessing` is
        overridden to true whenever `envs.VLLM_ENABLE_V1_MULTIPROCESSING`
        is set.
        """
        # Derive the config and the matching executor class.
        config = engine_args.create_engine_config(usage_context)
        executor_cls = Executor.get_class(config)

        # The environment switch wins over the caller-supplied flag.
        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        stats_enabled = not engine_args.disable_log_stats
        return cls(
            vllm_config=config,
            executor_class=executor_cls,
            log_stats=stats_enabled,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=enable_multiprocessing,
        )
181
182

    def get_num_unfinished_requests(self) -> int:
        """Return the unfinished-request count reported by the output processor."""
        pending_count = self.output_processor.get_num_unfinished_requests()
        return pending_count
184
185

    def has_unfinished_requests(self) -> bool:
        """Report whether there is still unfinished work.

        With a DP group set, the local flag is handed to
        `has_unfinished_requests_dp`. Without one, the result is the local
        flag OR-ed with the engine core's `dp_engines_running()`.
        """
        local_pending = self.output_processor.has_unfinished_requests()
        if self.dp_group is not None:
            return self.has_unfinished_requests_dp(local_pending)
        return local_pending or self.engine_core.dp_engines_running()

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Aggregate the local unfinished flag over `self.dp_group`.

        When this engine has nothing unfinished but the aggregated result is
        true, `should_execute_dummy_batch` is set; `step()` checks that flag
        and runs a dummy batch.

        Returns:
            The value produced by `ParallelConfig.has_unfinished_dp`.
        """
        any_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished
        )
        if not has_unfinished:
            if any_unfinished:
                self.should_execute_dummy_batch = True
        return any_unfinished
198
199
200
201
202

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        return outputs

203
204
205
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return self.engine_core.get_supported_tasks()

206
    def abort_request(self, request_ids: list[str]) -> None:
        """Abort requests in the output processor, then in the engine core.

        The engine core receives the list returned by the output processor's
        `abort_requests` call, not the caller's original list.
        """
        ids_for_core = self.output_processor.abort_requests(request_ids)
        self.engine_core.abort_requests(ids_for_core)

212
213
214
    def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        prompt_text: str | None = None,
    ) -> None:
        """Register a request with the output processor and the engine core.

        A `prompt` that is already an `EngineCoreRequest` is used as-is.
        Anything else is converted through `self.processor.process_inputs`
        (a one-time deprecation warning is logged), and `prompt_text` is
        derived from the prompt when it is a string or a mapping.

        For `SamplingParams` with `n != 1`, the request is fanned out into
        `n` child requests that share one `ParentRequest`.

        Raises:
            TypeError: If `request_id` is not a string.
        """
        # Validate the request_id type.
        if not isinstance(request_id, str):
            raise TypeError(f"request_id must be a string, got {type(request_id)}")

        # Process raw inputs into the request.
        if not isinstance(prompt, EngineCoreRequest):
            assert prompt_text is None
            logger.warning_once(
                "Processor has been moved under LLM and will "
                "be removed from LLMEngine in v0.13."
            )
            request = self.processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
            )
            if isinstance(prompt, str):
                prompt_text = prompt
            elif isinstance(prompt, Mapping):
                prompt_text = cast(str | None, prompt.get("prompt"))
        else:
            request = prompt

        num_samples = params.n if isinstance(params, SamplingParams) else 1

        if num_samples != 1:
            # Fan out child requests (for n>1).
            parent = ParentRequest(request_id, params)
            final_idx = num_samples - 1
            for child_idx in range(num_samples):
                child_id, child_params = parent.get_child_info(child_idx)
                # The last child reuses the original request object; the
                # others get a shallow copy.
                child = request if child_idx == final_idx else copy(request)
                child.request_id = child_id
                child.sampling_params = child_params

                # Make a new RequestState and queue.
                self.output_processor.add_request(
                    child, prompt_text, parent, child_idx
                )
                # Add the request to EngineCore.
                self.engine_core.add_request(child)
            return

        # Make a new RequestState and queue.
        self.output_processor.add_request(request, prompt_text, None, 0)
        # Add the request to EngineCore.
        self.engine_core.add_request(request)
275

276
    def step(self) -> list[RequestOutput | PoolingRequestOutput]:
277
278
279
280
281
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

282
        # 1) Get EngineCoreOutput from the EngineCore.
283
        outputs = self.engine_core.get_output()
284

285
        # 2) Process EngineCoreOutputs.
286
        iteration_stats = IterationStats() if self.log_stats else None
287
        processed_outputs = self.output_processor.process_outputs(
288
289
            outputs.outputs,
            engine_core_timestamp=outputs.timestamp,
290
291
            iteration_stats=iteration_stats,
        )
292
        self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
293

294
295
        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
296

297
        # 4) Record stats
298
        if self.logger_manager is not None and outputs.scheduler_stats is not None:
299
300
301
            self.logger_manager.record(
                scheduler_stats=outputs.scheduler_stats,
                iteration_stats=iteration_stats,
302
                mm_cache_stats=self.processor.stat_mm_cache(),
303
304
            )
            self.do_log_stats_with_interval()
305

306
        return processed_outputs.request_outputs
307

308
    def start_profile(self):
        """Ask the engine core to start profiling (`profile(True)`)."""
        core = self.engine_core
        core.profile(True)
310

311
    def stop_profile(self):
        """Ask the engine core to stop profiling (`profile(False)`)."""
        core = self.engine_core
        core.profile(False)
313

314
    def reset_mm_cache(self):
        """Clear the processor's multimodal cache, then reset the engine core's."""
        input_side = self.processor
        input_side.clear_mm_cache()
        core_side = self.engine_core
        core_side.reset_mm_cache()

318
    def reset_prefix_cache(self, device: Device | None = None):
319
320
        self.engine_core.reset_prefix_cache()

321
322
323
    def sleep(self, level: int = 1):
        """Put the engine core to sleep at `level`.

        When a stat logger manager exists, the sleep state is recorded
        afterwards as `record_sleep_state(1, level)`.
        """
        self.engine_core.sleep(level)

        manager = self.logger_manager
        if manager is None:
            return
        manager.record_sleep_state(1, level)

327
    def wake_up(self, tags: list[str] | None = None):
328
        self.engine_core.wake_up(tags)
329

330
331
332
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

333
334
335
    def is_sleeping(self) -> bool:
        """Return the engine core client's `is_sleeping()` result."""
        asleep = self.engine_core.is_sleeping()
        return asleep

336
337
338
339
    def get_metrics(self) -> list[Metric]:
        """Return the result of `get_metrics_snapshot()`.

        Asserts that the engine was created with stat logging enabled.
        """
        assert self.log_stats, "Stat logging disabled"
        snapshot = get_metrics_snapshot()
        return snapshot

340
    @property
341
    def tokenizer(self) -> AnyTokenizer | None:
342
343
344
        return self.processor.tokenizer

    @tokenizer.setter
345
    def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
346
347
        self.processor.tokenizer = tokenizer

348
    def get_tokenizer(self) -> AnyTokenizer:
349
        if self.tokenizer is None:
350
351
352
            raise ValueError(
                "Unable to get tokenizer because skip_tokenizer_init is True"
            )
353

354
        return self.tokenizer
355

356
357
358
359
360
361
362
363
364
365
366
367
368
369
    def do_log_stats(self) -> None:
        """Emit a stats log through the logger manager, if there is one."""
        manager = self.logger_manager
        if not manager:
            return
        manager.log()

    def do_log_stats_with_interval(self) -> None:
        """Log stats once at least `envs.VLLM_LOG_STATS_INTERVAL` has elapsed.

        The time of the last log is kept in `self._last_log_time`, which is
        created on the first call. Elapsed time is measured with
        `time.monotonic()` rather than wall-clock time, so system clock
        adjustments can neither suppress logging nor trigger a spurious log.
        """
        now = time.monotonic()
        last_logged = getattr(self, "_last_log_time", None)
        if last_logged is None:
            # First call: start the interval from now.
            self._last_log_time = last_logged = now
        if now - last_logged >= envs.VLLM_LOG_STATS_INTERVAL:
            self.do_log_stats()
            self._last_log_time = now

370
371
372
373
374
375
376
377
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward removal of the adapter `lora_id` to the engine core."""
        removed = self.engine_core.remove_lora(lora_id)
        return removed

378
    def list_loras(self) -> set[int]:
        """Return the adapter ids reported by the engine core's `list_loras`."""
        registered = self.engine_core.list_loras()
        return registered

    def pin_lora(self, lora_id: int) -> bool:
        """Forward a pin request for the adapter `lora_id` to the engine core."""
        pinned = self.engine_core.pin_lora(lora_id)
        return pinned
385

386
387
    def collective_rpc(
        self,
388
389
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
390
        args: tuple = (),
391
        kwargs: dict[str, Any] | None = None,
392
    ) -> list[_R]:
393
394
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

395
    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
396
        return self.collective_rpc("apply_model", args=(func,))
397

398
    def __del__(self):
        """Destroy the stateless DP process group this engine created, if any.

        Nothing is destroyed when no DP group is set (including the case
        where `__init__` failed before assigning `dp_group`), or when the
        group belongs to the external launcher (`external_launcher_dp`).
        In the latter case `dp_group` is the shared group obtained from
        `get_dp_group()`, not one created by this engine.
        """
        # Bind the group first and test it separately. In the previous form
        # `dp_group := getattr(...) and not self.external_launcher_dp`, the
        # walrus captured the boolean result of the whole `and` expression,
        # so the destroy call received `True` instead of the group.
        dp_group = getattr(self, "dp_group", None)
        if dp_group is None or self.external_launcher_dp:
            return
        stateless_destroy_torch_distributed_process_group(dp_group)