"tests/entrypoints/openai/test_chat.py" did not exist on "67b4221a61ace91a79aff507df0a95a01978300e"
llm_engine.py 16 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Callable, Mapping
6
from copy import copy
7
from typing import Any
8

9
import torch.nn as nn
10
from typing_extensions import TypeVar
11

12
import vllm.envs as envs
13
from vllm.config import ParallelConfig, VllmConfig
14
from vllm.distributed import stateless_destroy_torch_distributed_process_group
15
from vllm.distributed.parallel_state import get_dp_group
16
from vllm.engine.arg_utils import EngineArgs
17
from vllm.inputs import PromptType
18
19
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
20
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
21
from vllm.outputs import PoolingRequestOutput, RequestOutput
22
from vllm.plugins.io_processors import get_io_processor
23
from vllm.pooling_params import PoolingParams
24
from vllm.renderers import BaseRenderer
25
26
from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import extract_prompt_components
27
from vllm.sampling_params import SamplingParams
28
from vllm.tasks import SupportedTask
29
from vllm.tokenizers import TokenizerLike
30
from vllm.tracing import init_tracer
31
from vllm.usage.usage_lib import UsageContext
32
from vllm.v1.engine import EngineCoreRequest
33
from vllm.v1.engine.core_client import EngineCoreClient
34
from vllm.v1.engine.input_processor import InputProcessor
35
from vllm.v1.engine.output_processor import OutputProcessor
36
from vllm.v1.engine.parallel_sampling import ParentRequest
37
from vllm.v1.executor import Executor
38
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
39
40
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
41
from vllm.v1.utils import record_function_or_nullcontext
42
from vllm.v1.worker.worker_base import WorkerBase
43
44
45

logger = init_logger(__name__)

# Result type of callables dispatched to workers through
# `LLMEngine.collective_rpc` / `LLMEngine.apply_model`; falls back to `Any`
# when a caller does not pin it.
_R = TypeVar("_R", default=Any)
47

48
49

class LLMEngine:
    """Legacy LLMEngine for backwards compatibility.

    Synchronous front end that wires together an ``InputProcessor`` (raw
    prompts -> ``EngineCoreRequest``), an ``EngineCoreClient`` (scheduling and
    execution), and an ``OutputProcessor`` (``EngineCoreOutputs`` ->
    ``RequestOutput`` / ``PoolingRequestOutput``).  Callers drive it by
    calling :meth:`add_request` and then :meth:`step` until
    :meth:`has_unfinished_requests` returns ``False``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        aggregate_engine_logging: bool = False,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Build the processors, the engine-core client and stat loggers.

        Args:
            vllm_config: Full engine configuration.
            executor_class: Executor implementation handed to the engine core.
            log_stats: Whether to collect stats and create a
                ``StatLoggerManager``.
            aggregate_engine_logging: Forwarded to ``StatLoggerManager``.
            usage_context: Accepted for interface compatibility; not used by
                this constructor.
            stat_loggers: Custom stat-logger factories forwarded to
                ``StatLoggerManager`` (only used when ``log_stats`` is true).
            mm_registry: Accepted for interface compatibility; not used by
                this constructor.
            use_cached_outputs: Accepted for interface compatibility; not
                used by this constructor.
            multiprocess_mode: Run the engine core in a separate process.
                When false, the DP group (if any) is created here and
                ``self.model_executor`` is exposed.
        """
        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        self.log_stats = log_stats

        parallel_config = vllm_config.parallel_config
        executor_backend = parallel_config.distributed_executor_backend

        self.external_launcher_dp = (
            parallel_config.data_parallel_size > 1
            and executor_backend == "external_launcher"
        )
        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        if (
            not multiprocess_mode
            and parallel_config.data_parallel_size > 1
            and not self.external_launcher_dp
        ):
            self.dp_group = parallel_config.stateless_init_dp_group()
        else:
            self.dp_group = None
        # Set by has_unfinished_requests_dp() when peers still have work but
        # this rank does not; consumed by step().
        self.should_execute_dummy_batch = False

        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
        )

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        # `self.tokenizer` reads from input_processor, so this must come after.
        self.output_processor = OutputProcessor(
            self.tokenizer,
            log_stats=self.log_stats,
            stream_interval=self.vllm_config.scheduler_config.stream_interval,
        )
        endpoint = self.observability_config.otlp_traces_endpoint
        if endpoint is not None:
            init_tracer("vllm.llm_engine", endpoint)
            self.output_processor.tracing_enabled = True

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        self.logger_manager: StatLoggerManager | None = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                custom_stat_loggers=stat_loggers,
                enable_default_loggers=log_stats,
                aggregate_engine_logging=aggregate_engine_logging,
            )
            self.logger_manager.log_engine_initialized()

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        if self.external_launcher_dp:
            # If we use DP in external launcher mode, we reuse the
            # existing DP group used for data communication.
            self.dp_group = get_dp_group().cpu_group

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an engine from an already-built ``VllmConfig``.

        The executor class is derived from the config, and multiprocessing
        follows ``envs.VLLM_ENABLE_V1_MULTIPROCESSING``.
        """
        return cls(
            vllm_config=vllm_config,
            executor_class=Executor.get_class(vllm_config),
            log_stats=(not disable_log_stats),
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: list[StatLoggerFactory] | None = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        ``enable_multiprocessing`` is forced to ``True`` when
        ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` is set.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=enable_multiprocessing,
        )

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests the output processor still tracks."""
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Return whether any work remains.

        Without a DP group this checks the local output processor and
        ``engine_core.dp_engines_running()``.  With a DP group the local
        answer is aggregated across ranks via
        :meth:`has_unfinished_requests_dp`.
        """
        has_unfinished = self.output_processor.has_unfinished_requests()
        if self.dp_group is None:
            return has_unfinished or self.engine_core.dp_engines_running()
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Aggregate ``has_unfinished`` across the DP group.

        If this rank is idle but another rank still has work, schedule a
        dummy batch for the next :meth:`step` so this rank keeps
        participating.
        """
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished
        )
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """No-op kept for interface compatibility; returns ``outputs`` as-is."""
        return outputs

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the engine core's supported tasks, cached after first call."""
        if not hasattr(self, "_supported_tasks"):
            # Cache the result
            self._supported_tasks = self.engine_core.get_supported_tasks()

        return self._supported_tasks

    def abort_request(self, request_ids: list[str], internal: bool = False) -> None:
        """Remove request_ids from EngineCore and Detokenizer.

        The output processor aborts first.  The ids it returns are then
        forwarded to the engine core.
        """
        request_ids = self.output_processor.abort_requests(request_ids, internal)
        self.engine_core.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: EngineCoreRequest | PromptType | DictPrompt | TokPrompt,
        params: SamplingParams | PoolingParams,
        arrival_time: float | None = None,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        trace_headers: Mapping[str, str] | None = None,
        priority: int = 0,
        prompt_text: str | None = None,
    ) -> None:
        """Queue a request in both the output processor and the engine core.

        Args:
            request_id: Caller-supplied id.  Ignored (with a one-time warning
                if it differs) when ``prompt`` is already an
                ``EngineCoreRequest``.
            prompt: Either a prebuilt ``EngineCoreRequest``, used as-is, or a
                raw prompt that is run through
                ``InputProcessor.process_inputs``.
            params: Sampling or pooling parameters.  For sampling with
                ``n > 1`` the request is fanned out into ``n`` child requests.
            arrival_time, lora_request, tokenization_kwargs, trace_headers,
            priority: Forwarded to ``process_inputs`` for raw prompts.
            prompt_text: Only allowed together with an ``EngineCoreRequest``.
                For raw prompts it is derived via
                ``extract_prompt_components``.

        Raises:
            TypeError: If ``request_id`` is not a ``str``.
        """
        # Validate the request_id type.
        if not isinstance(request_id, str):
            raise TypeError(f"request_id must be a string, got {type(request_id)}")

        # Process raw inputs into the request.
        if isinstance(prompt, EngineCoreRequest):
            request = prompt
            if request_id != request.request_id:
                logger.warning_once(
                    "LLMEngine.add_request() was passed a request_id parameter that "
                    "does not match the EngineCoreRequest.request_id attribute. The "
                    "latter will be used, and the former will be ignored."
                )
        else:
            assert prompt_text is None, (
                "prompt_text may only be passed together with an EngineCoreRequest"
            )
            request = self.input_processor.process_inputs(
                request_id,
                prompt,
                params,
                arrival_time,
                lora_request,
                tokenization_kwargs,
                trace_headers,
                priority,
                supported_tasks=self.get_supported_tasks(),
            )
            prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)

        self.input_processor.assign_request_id(request)

        # Use cloned params that may have been updated in process_inputs()
        params = request.params

        n = params.n if isinstance(params, SamplingParams) else 1

        if n == 1:
            # Make a new RequestState and queue.
            self.output_processor.add_request(request, prompt_text, None, 0)
            # Add the request to EngineCore.
            self.engine_core.add_request(request)
            return

        # Fan out child requests (for n>1).
        parent_req = ParentRequest(request)
        for idx in range(n):
            request_id, child_params = parent_req.get_child_info(idx)
            # The last child reuses the original object; the others get a
            # shallow copy so their id/params can be overwritten
            # independently.
            child_request = request if idx == n - 1 else copy(request)
            child_request.request_id = request_id
            child_request.sampling_params = child_params

            # Make a new RequestState and queue.
            self.output_processor.add_request(
                child_request, prompt_text, parent_req, idx
            )
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)

    def step(self) -> list[RequestOutput | PoolingRequestOutput]:
        """Run one engine iteration and return the outputs it produced.

        If a DP dummy batch is pending (see
        :meth:`has_unfinished_requests_dp`), only that batch is executed and
        an empty list is returned.
        """
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        with record_function_or_nullcontext("llm_engine step: get_output"):
            outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        with record_function_or_nullcontext("llm_engine step: process_outputs"):
            iteration_stats = IterationStats() if self.log_stats else None
            processed_outputs = self.output_processor.process_outputs(
                outputs.outputs,
                engine_core_timestamp=outputs.timestamp,
                iteration_stats=iteration_stats,
            )
            self.output_processor.update_scheduler_stats(outputs.scheduler_stats)

        # 3) Abort any reqs that finished due to stop strings.
        with record_function_or_nullcontext("llm_engine step: abort_requests"):
            self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        # 4) Record stats
        with record_function_or_nullcontext("llm_engine step: record_stats"):
            if self.logger_manager is not None and outputs.scheduler_stats is not None:
                self.logger_manager.record(
                    scheduler_stats=outputs.scheduler_stats,
                    iteration_stats=iteration_stats,
                    mm_cache_stats=self.input_processor.stat_mm_cache(),
                )
                self.do_log_stats_with_interval()

        return processed_outputs.request_outputs

    def start_profile(self):
        """Ask the engine core to start profiling."""
        self.engine_core.profile(True)

    def stop_profile(self):
        """Ask the engine core to stop profiling."""
        self.engine_core.profile(False)

    def reset_mm_cache(self):
        """Clear the multimodal caches in the input processor and engine core."""
        self.input_processor.clear_mm_cache()
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Forward a prefix-cache reset to the engine core and return its result."""
        return self.engine_core.reset_prefix_cache(
            reset_running_requests, reset_connector
        )

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache to invalidate all cached encoder outputs.

        This should be called when model weights are updated to ensure
        stale vision embeddings computed with old weights are not reused.
        """
        self.engine_core.reset_encoder_cache()

    def sleep(self, level: int = 1):
        """Put the engine core to sleep at ``level`` and record the state."""
        self.engine_core.sleep(level)

        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(1, level)

    def wake_up(self, tags: list[str] | None = None):
        """Wake the engine core (optionally only ``tags``) and record it."""
        self.engine_core.wake_up(tags)

        if self.logger_manager is not None:
            self.logger_manager.record_sleep_state(0, 0)

    def is_sleeping(self) -> bool:
        """Return the engine core's sleeping state."""
        return self.engine_core.is_sleeping()

    def get_metrics(self) -> list[Metric]:
        """Return a metrics snapshot; requires stat logging to be enabled."""
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The input processor's tokenizer (may be ``None``)."""
        return self.input_processor.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer via ``InputProcessor.get_tokenizer()``."""
        return self.input_processor.get_tokenizer()

    @property
    def renderer(self) -> BaseRenderer:
        """The input processor's renderer."""
        return self.input_processor.renderer

    def do_log_stats(self) -> None:
        """Log stats if logging is enabled."""
        if self.logger_manager:
            self.logger_manager.log()

    def do_log_stats_with_interval(self) -> None:
        """Log stats when ``envs.VLLM_LOG_STATS_INTERVAL`` seconds have passed.

        The first call only starts the interval timer; it logs immediately
        only if the configured interval is <= 0.
        """
        # Monotonic clock: an elapsed-time interval must not be skewed by
        # wall-clock adjustments (NTP sync, manual clock changes).
        now = time.monotonic()
        last = getattr(self, "_last_log_time", None)
        if last is None:
            self._last_log_time = last = now
        if now - last >= envs.VLLM_LOG_STATS_INTERVAL:
            self.do_log_stats()
            self._last_log_time = now

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)

    def collective_rpc(
        self,
        method: str | Callable[[WorkerBase], _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Forward a collective RPC to the engine core and return its results.

        ``method`` is either a worker method name or a callable taking a
        ``WorkerBase``.  Presumably one result per worker; confirm against
        ``EngineCoreClient.collective_rpc``.
        """
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Run ``func`` on the model through the ``apply_model`` worker RPC."""
        return self.collective_rpc("apply_model", args=(func,))

    def __del__(self):
        """Destroy the stateless DP group created in ``__init__``.

        In external-launcher mode the group is the pre-existing DP group
        reused from ``get_dp_group()``, so it is not destroyed here.
        ``getattr`` guards against ``__init__`` having failed before
        ``dp_group`` was assigned.
        """
        dp_group = getattr(self, "dp_group", None)
        if dp_group is not None and not self.external_launcher_dp:
            stateless_destroy_torch_distributed_process_group(dp_group)