llm_engine.py 14.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import time
5
from collections.abc import Mapping
6
from copy import copy
7
from typing import Any, Callable, Optional, Union
8

9
import torch.nn as nn
10
11
from typing_extensions import TypeVar

12
import vllm.envs as envs
13
from vllm.config import ParallelConfig, VllmConfig
14
from vllm.distributed import stateless_destroy_torch_distributed_process_group
15
from vllm.distributed.parallel_state import get_dp_group
16
from vllm.engine.arg_utils import EngineArgs
17
from vllm.inputs import PromptType
18
19
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
20
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
21
from vllm.outputs import PoolingRequestOutput, RequestOutput
22
from vllm.pooling_params import PoolingParams
23
from vllm.sampling_params import SamplingParams
24
from vllm.tasks import SupportedTask
25
from vllm.tracing import init_tracer
26
27
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               init_tokenizer_from_configs)
28
from vllm.usage.usage_lib import UsageContext
29
from vllm.utils import Device
30
from vllm.v1.engine.core_client import EngineCoreClient
31
from vllm.v1.engine.output_processor import OutputProcessor
32
from vllm.v1.engine.parallel_sampling import ParentRequest
33
from vllm.v1.engine.processor import Processor
34
from vllm.v1.executor.abstract import Executor
35
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
36
37
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
38
from vllm.v1.worker.worker_base import WorkerBase
39
40
41

# Result type of callables dispatched through collective_rpc / apply_model;
# the type variable defaults to Any.
_R = TypeVar("_R", default=Any)

# Module-level logger for the V1 LLMEngine.
logger = init_logger(__name__)
43

44
45

class LLMEngine:
    """Legacy LLMEngine for backwards compatibility.

    Synchronous V1 engine facade. Prompts are converted into engine-core
    requests by a ``Processor``, submitted through an ``EngineCoreClient``,
    and the resulting engine-core outputs are turned into ``RequestOutput`` /
    ``PoolingRequestOutput`` objects by an ``OutputProcessor``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Build the processor, output processor and engine-core client.

        Args:
            vllm_config: Full engine configuration.
            executor_class: Executor implementation handed to the engine core.
            log_stats: Collect iteration stats and create a
                ``StatLoggerManager``.
            usage_context: Accepted for interface compatibility; not used by
                this constructor.
            stat_loggers: Must be ``None``; custom stat loggers are not yet
                supported by this class.
            mm_registry: Multimodal registry handed to the ``Processor``.
            use_cached_outputs: Accepted for interface compatibility; not used
                by this constructor.
            multiprocess_mode: Run the engine core behind a multiprocess
                client instead of in-process.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is false.
            NotImplementedError: If ``stat_loggers`` is not ``None``.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        if stat_loggers is not None:
            raise NotImplementedError(
                "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
                "Set VLLM_USE_V1=0 and file and issue on Github.")

        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        self.log_stats = log_stats

        parallel_config = vllm_config.parallel_config
        executor_backend = parallel_config.distributed_executor_backend
        self.external_launcher_dp = (parallel_config.data_parallel_size > 1
                                     and executor_backend
                                     == "external_launcher")
        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        if (not multiprocess_mode and parallel_config.data_parallel_size > 1
                and not self.external_launcher_dp):
            self.dp_group = parallel_config.stateless_init_dp_group()
        else:
            self.dp_group = None
        # Set when other DP ranks still have work but this one does not; the
        # next step() then runs a dummy batch instead of fetching outputs.
        self.should_execute_dummy_batch = False

        if self.model_config.skip_tokenizer_init:
            self.tokenizer = None
        else:
            # Tokenizer (+ ensure liveness if running in another process).
            self.tokenizer = init_tokenizer_from_configs(
                model_config=vllm_config.model_config)

        # Processor (convert Inputs --> EngineCoreRequests)
        self.processor = Processor(vllm_config=vllm_config,
                                   tokenizer=self.tokenizer,
                                   mm_registry=mm_registry)

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)
        if self.observability_config.otlp_traces_endpoint is not None:
            tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        self.logger_manager: Optional[StatLoggerManager] = None
        if self.log_stats:
            self.logger_manager = StatLoggerManager(
                vllm_config=vllm_config,
                custom_stat_loggers=stat_loggers,
                enable_default_loggers=log_stats,
            )
            self.logger_manager.log_engine_initialized()

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        if self.external_launcher_dp:
            # If we use DP in external launcher mode, we reuse the
            # existing DP group used for data communication.
            self.dp_group = get_dp_group().cpu_group

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an engine from a ready-made ``VllmConfig``.

        The executor class is derived from the config, and multiprocess mode
        follows ``envs.VLLM_ENABLE_V1_MULTIPROCESSING``.
        """
        return cls(vllm_config=vllm_config,
                   executor_class=Executor.get_class(vllm_config),
                   log_stats=(not disable_log_stats),
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` forces multiprocessing on,
        regardless of ``enable_multiprocessing``.
        """

        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(vllm_config=vllm_config,
                   executor_class=executor_class,
                   log_stats=not engine_args.disable_log_stats,
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=enable_multiprocessing)

    def get_num_unfinished_requests(self) -> int:
        """Number of requests still tracked by the output processor."""
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Whether any work remains, locally or (with DP) on any rank."""
        has_unfinished = self.output_processor.has_unfinished_requests()
        if self.dp_group is None:
            return has_unfinished or self.engine_core.dp_engines_running()
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Aggregate ``has_unfinished`` across the DP group.

        If this rank is idle but another rank is not, schedule a dummy batch
        so this rank keeps participating in the collective steps.
        """
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished)
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """Compatibility shim: returns ``outputs`` unchanged."""
        return outputs

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Tasks supported by the loaded model, as reported by the core."""
        return self.engine_core.get_supported_tasks()

    def abort_request(self, request_ids: list[str]) -> None:
        """Remove request_ids from EngineCore and Detokenizer."""

        request_ids = self.output_processor.abort_requests(request_ids)
        self.engine_core.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> None:
        """Process a prompt and submit it to the engine core.

        For ``SamplingParams`` with ``n > 1`` the request is fanned out into
        ``n`` child requests grouped under a ``ParentRequest``.

        Raises:
            TypeError: If ``request_id`` is not a ``str``.
        """
        # Validate the request_id type.
        if not isinstance(request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request_id)}")

        # Process raw inputs into the request.
        request = self.processor.process_inputs(request_id, prompt, params,
                                                arrival_time, lora_request,
                                                tokenization_kwargs,
                                                trace_headers, priority)
        prompt_text = prompt if isinstance(prompt,
                                           str) else prompt.get("prompt")

        n = params.n if isinstance(params, SamplingParams) else 1

        if n == 1:
            # Make a new RequestState and queue.
            self.output_processor.add_request(request, prompt_text, None, 0)
            # Add the request to EngineCore.
            self.engine_core.add_request(request)
            return

        # Fan out child requests (for n>1).
        parent_req = ParentRequest(request_id, params)
        for idx in range(n):
            child_id, child_params = parent_req.get_child_info(idx)
            # Every child except the last gets a shallow copy; the last one
            # reuses `request` itself, which is only mutated after all copies
            # have been taken.
            child_request = request if idx == n - 1 else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params

            # Make a new RequestState and queue.
            self.output_processor.add_request(child_request, prompt_text,
                                              parent_req, idx)
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)

    def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:
        """Run one engine iteration and return the finished/updated outputs.

        Returns an empty list when a pending DP dummy batch is executed
        instead of a real step.
        """
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        iteration_stats = IterationStats() if self.log_stats else None
        processed_outputs = self.output_processor.process_outputs(
            outputs.outputs,
            engine_core_timestamp=outputs.timestamp,
            iteration_stats=iteration_stats)

        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        # 4) Record stats
        if self.logger_manager is not None:
            assert outputs.scheduler_stats is not None
            self.logger_manager.record(
                scheduler_stats=outputs.scheduler_stats,
                iteration_stats=iteration_stats,
            )
            self.do_log_stats_with_interval()

        return processed_outputs.request_outputs

    def get_vllm_config(self):
        """Return the engine's ``VllmConfig``."""
        return self.vllm_config

    def get_model_config(self):
        """Return the engine's model config."""
        return self.model_config

    def start_profile(self):
        """Start profiling in the engine core."""
        self.engine_core.profile(True)

    def stop_profile(self):
        """Stop profiling in the engine core."""
        self.engine_core.profile(False)

    def reset_mm_cache(self):
        """Clear the multimodal caches in the processor and engine core."""
        self.processor.clear_cache()
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(self, device: Optional[Device] = None):
        """Reset the engine core's prefix cache.

        ``device`` is accepted for interface compatibility and is not
        forwarded.
        """
        self.engine_core.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Put the engine core to sleep at the given ``level``."""
        self.engine_core.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Wake the engine core, optionally only for the given ``tags``."""
        self.engine_core.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Whether the engine core reports itself as sleeping."""
        return self.engine_core.is_sleeping()

    def get_metrics(self) -> list[Metric]:
        """Snapshot of collected metrics; requires ``log_stats``."""
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer.

        Raises:
            ValueError: If the tokenizer was skipped at init.
        """
        if self.tokenizer is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

        return self.tokenizer

    def do_log_stats(self) -> None:
        """Log stats if logging is enabled."""
        if self.logger_manager:
            self.logger_manager.log()

    def do_log_stats_with_interval(self) -> None:
        """Log stats when the time interval has passed."""
        now = time.time()
        # The first call only starts the clock.
        if not hasattr(self, "_last_log_time"):
            self._last_log_time = now
        if now - self._last_log_time >= envs.VLLM_LOG_STATS_INTERVAL:
            self.do_log_stats()
            self._last_log_time = now

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)

    def collective_rpc(self,
                       method: Union[str, Callable[[WorkerBase], _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Invoke ``method`` on the workers via the engine core."""
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Run ``func`` through the workers' ``apply_model`` RPC."""
        return self.collective_rpc("apply_model", args=(func, ))

    def __del__(self):
        """Destroy the stateless DP process group created in ``__init__``.

        A group borrowed from an external launcher is not owned by this
        engine and is left untouched.
        """
        # The walrus must be parenthesised. Without the parentheses, `:=`
        # binds the whole `and` expression, so `dp_group` became a bool and
        # that bool (not the group) was passed to the destroy call.
        # getattr keeps this safe if __init__ failed before the attribute
        # was set.
        if ((dp_group := getattr(self, "dp_group", None))
                and not getattr(self, "external_launcher_dp", False)):
            stateless_destroy_torch_distributed_process_group(dp_group)