"vllm/vscode:/vscode.git/clone" did not exist on "fc2dbcda8b717e6aac0794cb9b4cc86b78c36504"
llm_engine.py 15.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Callable, Mapping
6
from copy import copy
7
from typing import Any, cast
8

9
import torch.nn as nn
10
from typing_extensions import TypeVar, deprecated
11

12
import vllm.envs as envs
13
from vllm.config import ParallelConfig, VllmConfig
14
from vllm.distributed import stateless_destroy_torch_distributed_process_group
15
from vllm.distributed.parallel_state import get_dp_group
16
from vllm.engine.arg_utils import EngineArgs
17
from vllm.inputs import PromptType
18
19
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
20
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
21
from vllm.outputs import PoolingRequestOutput, RequestOutput
22
from vllm.plugins.io_processors import get_io_processor
23
from vllm.pooling_params import PoolingParams
24
from vllm.sampling_params import SamplingParams
25
from vllm.tasks import SupportedTask
26
from vllm.tokenizers import TokenizerLike
27
from vllm.tracing import init_tracer
28
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
29
from vllm.usage.usage_lib import UsageContext
30
from vllm.v1.engine import EngineCoreRequest
31
from vllm.v1.engine.core_client import EngineCoreClient
32
from vllm.v1.engine.input_processor import InputProcessor
33
from vllm.v1.engine.output_processor import OutputProcessor
34
from vllm.v1.engine.parallel_sampling import ParentRequest
35
from vllm.v1.executor import Executor
36
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
37
38
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
39
from vllm.v1.utils import record_function_or_nullcontext
40
from vllm.v1.worker.worker_base import WorkerBase
41
42
43

# Module-level logger, created through vLLM's logging helper.
logger = init_logger(__name__)

# Result type of callables dispatched through `collective_rpc` /
# `apply_model`; defaults to `Any` when left unspecified.
_R = TypeVar("_R", default=Any)
45

46
47

class LLMEngine:
48
    """Legacy LLMEngine for backwards compatibility."""
49
50
51

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        aggregate_engine_logging: bool = False,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Build the input/output processors and the engine-core client.

        Args:
            vllm_config: Aggregate configuration; its observability, model,
                cache, parallel and scheduler sub-configs are read here.
            executor_class: Executor type forwarded to the engine-core client.
            log_stats: When true, per-iteration stats are collected and a
                `StatLoggerManager` is created.
            aggregate_engine_logging: Forwarded to `StatLoggerManager`.
            usage_context: Accepted but not read by this constructor.
            stat_loggers: Custom stat-logger factories forwarded to
                `StatLoggerManager`.
            mm_registry: Accepted but not read by this constructor.
            use_cached_outputs: Accepted but not read by this constructor.
            multiprocess_mode: Forwarded to `EngineCoreClient.make_client`.
                When false, the engine core's `model_executor` is also
                exposed as `self.model_executor`.
        """
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        self.log_stats = log_stats

        parallel_config = vllm_config.parallel_config
        executor_backend = parallel_config.distributed_executor_backend

        self.external_launcher_dp = (
            parallel_config.data_parallel_size > 1
            and executor_backend == "external_launcher"
        )
        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        if (
            not multiprocess_mode
            and parallel_config.data_parallel_size > 1
            and not self.external_launcher_dp
        ):
            self.dp_group = parallel_config.stateless_init_dp_group()
        else:
            self.dp_group = None
        # Set by has_unfinished_requests_dp(); consumed by step().
        self.should_execute_dummy_batch = False

        if self.model_config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = init_tokenizer_from_configs(self.model_config)

        self.input_processor = InputProcessor(self.vllm_config, tokenizer)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        # `self.tokenizer` is the property that reads from input_processor.
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        endpoint = self.observability_config.otlp_traces_endpoint
        if endpoint is not None:
            tracer = init_tracer("vllm.llm_engine", endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                custom_stat_loggers=stat_loggers,
                enable_default_loggers=log_stats,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        if self.external_launcher_dp:
            # If we use DP in external launcher mode, we reuse the
            # existing DP group used for data communication.
            self.dp_group = get_dp_group().cpu_group

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

140
141
142
143
144
145
146
147
    @property
    @deprecated(
        "`LLMEngine.processor` has been renamed to `LLMEngine.input_processor`. "
        "The old name will be removed in v0.13."
    )
    def processor(self):
        """Deprecated alias that returns `self.input_processor`."""
        input_processor = self.input_processor
        return input_processor

148
149
150
151
152
    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an engine from an already-built `VllmConfig`.

        The executor class is resolved with `Executor.get_class(vllm_config)`
        and multiprocess mode follows `envs.VLLM_ENABLE_V1_MULTIPROCESSING`.

        Args:
            vllm_config: Configuration passed to the constructor.
            usage_context: Passed through to the constructor.
            stat_loggers: Passed through to the constructor.
            disable_log_stats: Inverted to form the constructor's `log_stats`.
        """
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            log_stats=(not disable_log_stats),
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING,
        )
164

165
166
167
168
169
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        Args:
            engine_args: Source of the engine config and of
                `disable_log_stats`.
            usage_context: Passed to `create_engine_config` and to the
                constructor.
            stat_loggers: Passed through to the constructor.
            enable_multiprocessing: Requested multiprocess mode; it is forced
                to true when `envs.VLLM_ENABLE_V1_MULTIPROCESSING` is set.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        # The environment switch overrides the caller's argument.
        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=enable_multiprocessing,
        )
192
193

    def get_num_unfinished_requests(self) -> int:
194
        return self.output_processor.get_num_unfinished_requests()
195
196

    def has_unfinished_requests(self) -> bool:
197
        has_unfinished = self.output_processor.has_unfinished_requests()
198
        if self.dp_group is None:
199
            return has_unfinished or self.engine_core.dp_engines_running()
200
201
202
203
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Combine the local unfinished flag with the DP group's.

        Args:
            has_unfinished: Whether this engine has unfinished requests.

        Returns:
            The value returned by `ParallelConfig.has_unfinished_dp` for
            `self.dp_group`.
        """
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished
        )
        # Nothing to do locally but the aggregate still reports work:
        # flag the next step() to run a dummy batch instead.
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished
209
210
211
212
213

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        return outputs

214
215
216
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return self.engine_core.get_supported_tasks()

217
    def abort_request(self, request_ids: list[str]) -> None:
218
219
        """Remove request_ids from EngineCore and Detokenizer."""

220
        request_ids = self.output_processor.abort_requests(request_ids)
221
222
        self.engine_core.abort_requests(request_ids)

223
224
225
    def add_request(
        self,
        request_id: str,
226
227
228
229
230
231
        prompt: EngineCoreRequest | PromptType,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
232
        priority: int = 0,
233
        prompt_text: str | None = None,
234
    ) -> None:
235
236
        # Validate the request_id type.
        if not isinstance(request_id, str):
237
            raise TypeError(f"request_id must be a string, got {type(request_id)}")
238

239
        # Process raw inputs into the request.
240
241
242
243
        if isinstance(prompt, EngineCoreRequest):
            request = prompt
        else:
            assert prompt_text is None
244
            request = self.input_processor.process_inputs(
245
246
247
248
249
250
251
252
253
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
            )
254
255
256
257
            if isinstance(prompt, str):
                prompt_text = prompt
            elif isinstance(prompt, Mapping):
                prompt_text = cast(str | None, prompt.get("prompt"))
258

259
260
261
        # Use cloned params that may have been updated in process_inputs()
        params = request.params

262
        n = params.n if isinstance(params, SamplingParams) else 1
263

264
265
        if n == 1:
            # Make a new RequestState and queue.
266
            self.output_processor.add_request(request, prompt_text, None, 0)
267
            # Add the request to EngineCore.
268
            self.engine_core.add_request(request)
269
270
271
272
273
            return

        # Fan out child requests (for n>1).
        parent_req = ParentRequest(request_id, params)
        for idx in range(n):
274
            request_id, child_params = parent_req.get_child_info(idx)
275
276
            child_request = request if idx == n - 1 else copy(request)
            child_request.request_id = request_id
277
            child_request.sampling_params = child_params
278
279

            # Make a new RequestState and queue.
280
281
282
            self.output_processor.add_request(
                child_request, prompt_text, parent_req, idx
            )
283
284
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)
285

286
    def step(self) -> list[RequestOutput | PoolingRequestOutput]:
287
288
289
290
291
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

292
        # 1) Get EngineCoreOutput from the EngineCore.
293
        with record_function_or_nullcontext("llm_engine step: get_output"):
294
            outputs = self.engine_core.get_output()
295

296
        # 2) Process EngineCoreOutputs.
297
        with record_function_or_nullcontext("llm_engine step: process_outputs"):
298
299
300
301
302
303
304
            iteration_stats = IterationStats() if self.log_stats else None
            processed_outputs = self.output_processor.process_outputs(
                outputs.outputs,
                engine_core_timestamp=outputs.timestamp,
                iteration_stats=iteration_stats,
            )
            self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
305

306
        # 3) Abort any reqs that finished due to stop strings.
307
        with record_function_or_nullcontext("llm_engine step: abort_requests"):
308
            self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
309

310
        # 4) Record stats
311
        with record_function_or_nullcontext("llm_engine step: record_stats"):
312
313
314
315
            if self.logger_manager is not None and outputs.scheduler_stats is not None:
                self.logger_manager.record(
                    scheduler_stats=outputs.scheduler_stats,
                    iteration_stats=iteration_stats,
316
                    mm_cache_stats=self.input_processor.stat_mm_cache(),
317
318
                )
                self.do_log_stats_with_interval()
319

320
        return processed_outputs.request_outputs
321

322
    def start_profile(self):
323
        self.engine_core.profile(True)
324

325
    def stop_profile(self):
326
        self.engine_core.profile(False)
327

328
    def reset_mm_cache(self):
329
        self.input_processor.clear_mm_cache()
330
331
        self.engine_core.reset_mm_cache()

332
    def reset_prefix_cache(self):
333
334
        self.engine_core.reset_prefix_cache()

335
336
337
    def sleep(self, level: int = 1):
        self.engine_core.sleep(level)

338
339
340
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(1, level)

341
    def wake_up(self, tags: list[str] | None = None):
342
        self.engine_core.wake_up(tags)
343

344
345
346
        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

347
348
349
    def is_sleeping(self) -> bool:
        return self.engine_core.is_sleeping()

350
351
352
353
    def get_metrics(self) -> list[Metric]:
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()

354
    @property
355
    def tokenizer(self) -> TokenizerLike | None:
356
        return self.input_processor.tokenizer
357
358

    @tokenizer.setter
359
    def tokenizer(self, tokenizer: TokenizerLike | None) -> None:
360
        self.input_processor.tokenizer = tokenizer
361

362
    def get_tokenizer(self) -> TokenizerLike:
363
        if self.tokenizer is None:
364
            raise ValueError(
365
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
366
            )
367

368
        return self.tokenizer
369

370
371
372
373
374
375
376
377
378
379
380
381
382
383
    def do_log_stats(self) -> None:
        """Log stats if logging is enabled."""
        if self.logger_manager:
            self.logger_manager.log()

    def do_log_stats_with_interval(self) -> None:
        """Call `do_log_stats()` once `VLLM_LOG_STATS_INTERVAL` has elapsed.

        The first call only starts the clock: `_last_log_time` is created
        lazily and initialised to the current time.
        """
        current = time.time()
        try:
            previous = self._last_log_time
        except AttributeError:
            # First invocation: start measuring from now.
            previous = self._last_log_time = current
        if current - previous >= envs.VLLM_LOG_STATS_INTERVAL:
            self.do_log_stats()
            self._last_log_time = current

384
385
386
387
388
389
390
391
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

392
    def list_loras(self) -> set[int]:
393
394
395
396
397
398
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)
399

400
401
    def collective_rpc(
        self,
402
403
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
404
        args: tuple = (),
405
        kwargs: dict[str, Any] | None = None,
406
    ) -> list[_R]:
407
408
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

409
    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
410
        return self.collective_rpc("apply_model", args=(func,))
411

412
    def __del__(self):
413
414
415
416
        if (
            dp_group := getattr(self, "dp_group", None)
            and not self.external_launcher_dp
        ):
417
            stateless_destroy_torch_distributed_process_group(dp_group)