# llm_engine.py (13.1 KB)
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Mapping
5
from copy import copy
6
from typing import Any, Callable, Optional, Union
7

8
import torch.nn as nn
9
10
from typing_extensions import TypeVar

11
import vllm.envs as envs
12
from vllm.config import ParallelConfig, VllmConfig
13
from vllm.distributed import stateless_destroy_torch_distributed_process_group
14
from vllm.engine.arg_utils import EngineArgs
15
from vllm.inputs import PromptType
16
17
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
18
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
19
from vllm.outputs import PoolingRequestOutput, RequestOutput
20
from vllm.pooling_params import PoolingParams
21
from vllm.sampling_params import SamplingParams
22
from vllm.tasks import SupportedTask
23
from vllm.tracing import init_tracer
24
25
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               init_tokenizer_from_configs)
26
from vllm.usage.usage_lib import UsageContext
27
from vllm.utils import Device
28
from vllm.v1.engine.core_client import EngineCoreClient
29
from vllm.v1.engine.output_processor import OutputProcessor
30
from vllm.v1.engine.parallel_sampling import ParentRequest
31
from vllm.v1.engine.processor import Processor
32
from vllm.v1.executor.abstract import Executor
33
34
35
36
from vllm.v1.metrics.loggers import (PrometheusStatLogger, StatLoggerBase,
                                     StatLoggerFactory)
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
37
from vllm.v1.worker.worker_base import WorkerBase
38
39
40

# Module-level logger for this file.
logger = init_logger(__name__)

# Return type of callables passed to ``LLMEngine.collective_rpc`` /
# ``LLMEngine.apply_model``; falls back to ``Any`` (PEP 696 TypeVar default)
# when no type argument is supplied.
_R = TypeVar("_R", default=Any)


class LLMEngine:
    """Legacy LLMEngine for backwards compatibility.

    Synchronous (``asyncio_mode=False``) facade over the V1 engine pieces:
    a ``Processor`` converts raw prompts into EngineCore requests, an
    ``EngineCoreClient`` executes them, and an ``OutputProcessor`` converts
    EngineCore outputs back into ``RequestOutput`` / ``PoolingRequestOutput``
    objects.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Build the processor, output processor and EngineCore client.

        Args:
            vllm_config: Full engine configuration; its model, cache,
                observability and parallel sub-configs are read here.
            executor_class: Executor implementation handed to
                ``EngineCoreClient.make_client``.
            log_stats: When True, a ``PrometheusStatLogger`` is created and
                ``IterationStats`` are collected on every :meth:`step`.
            usage_context: Accepted for API compatibility; unused here.
            stat_loggers: Custom stat loggers are not supported yet; must
                be None.
            mm_registry: Multimodal registry forwarded to the ``Processor``.
            use_cached_outputs: Accepted for API compatibility; unused here.
            multiprocess_mode: Forwarded to ``EngineCoreClient.make_client``.
                When False, the data-parallel group (if any) is created here
                and ``model_executor`` is exposed on this object.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is falsy.
            NotImplementedError: If ``stat_loggers`` is not None.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        if stat_loggers is not None:
            raise NotImplementedError(
                "Passing StatLoggers to LLMEngine in V1 is not yet supported. "
                "Set VLLM_USE_V1=0 and file an issue on Github.")

        self.vllm_config = vllm_config
        self.observability_config = vllm_config.observability_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        self.log_stats = log_stats
        self.stat_logger: Optional[StatLoggerBase] = None
        if self.log_stats:
            self.stat_logger = PrometheusStatLogger(vllm_config)

        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        parallel_config = vllm_config.parallel_config
        if not multiprocess_mode and parallel_config.data_parallel_size > 1:
            self.dp_group = parallel_config.stateless_init_dp_group()
        else:
            self.dp_group = None
        # Set by has_unfinished_requests_dp(); consumed by step().
        self.should_execute_dummy_batch = False

        if self.model_config.skip_tokenizer_init:
            self.tokenizer = None
        else:
            # Tokenizer (+ ensure liveness if running in another process).
            self.tokenizer = init_tokenizer_from_configs(
                model_config=vllm_config.model_config)

        # Processor (convert Inputs --> EngineCoreRequests)
        self.processor = Processor(vllm_config=vllm_config,
                                   tokenizer=self.tokenizer,
                                   mm_registry=mm_registry)

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=self.log_stats)
        if self.observability_config.otlp_traces_endpoint is not None:
            tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)
            self.output_processor.tracer = tracer

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        # Don't keep the dummy data in memory
        self.reset_mm_cache()

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an engine from an already-built ``VllmConfig``.

        The executor class is derived from the config, and multiprocessing
        follows ``envs.VLLM_ENABLE_V1_MULTIPROCESSING``.
        """
        return cls(vllm_config=vllm_config,
                   executor_class=Executor.get_class(vllm_config),
                   log_stats=(not disable_log_stats),
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` forces multiprocessing on,
        regardless of ``enable_multiprocessing``.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(vllm_config=vllm_config,
                   executor_class=executor_class,
                   log_stats=not engine_args.disable_log_stats,
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=enable_multiprocessing)

    def get_num_unfinished_requests(self) -> int:
        """Return the number of requests the output processor still tracks."""
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Return whether any work remains.

        Without a local DP group this is the local state OR'ed with
        ``engine_core.dp_engines_running()``. With a DP group the answer is
        aggregated across ranks via :meth:`has_unfinished_requests_dp`.
        """
        has_unfinished = self.output_processor.has_unfinished_requests()
        if self.dp_group is None:
            return has_unfinished or self.engine_core.dp_engines_running()
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Aggregate ``has_unfinished`` across the data-parallel group.

        If this rank is idle but another rank still has work, the next
        :meth:`step` runs a dummy batch instead of fetching outputs
        (presumably to keep DP ranks in lockstep -- confirm).
        """
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished)
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """Return ``outputs`` unchanged; kept for legacy API compatibility."""
        return outputs

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks reported as supported by EngineCore."""
        return self.engine_core.get_supported_tasks()

    def abort_request(self, request_ids: list[str]) -> None:
        """Remove request_ids from EngineCore and Detokenizer.

        The ids returned by the output processor's ``abort_requests`` are
        the ones forwarded to EngineCore.
        """
        request_ids = self.output_processor.abort_requests(request_ids)
        self.engine_core.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ) -> None:
        """Process a prompt and submit it to the output processor and EngineCore.

        ``SamplingParams`` with ``n > 1`` are fanned out into ``n`` child
        requests tracked by a shared ``ParentRequest``; everything else is
        submitted as a single request.

        Raises:
            TypeError: If ``request_id`` is not a ``str``.
        """
        # Validate the request_id type.
        if not isinstance(request_id, str):
            raise TypeError(
                f"request_id must be a string, got {type(request_id)}")

        # Process raw inputs into the request.
        prompt_str, request = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            tokenization_kwargs, trace_headers, priority)

        n = params.n if isinstance(params, SamplingParams) else 1

        if n == 1:
            # Make a new RequestState and queue.
            self.output_processor.add_request(request, prompt_str, None, 0)
            # Add the request to EngineCore.
            self.engine_core.add_request(request)
            return

        # Fan out child requests (for n>1). Every child except the last is a
        # shallow copy of the processed request; the last one reuses it.
        parent_req = ParentRequest(request_id, params)
        for idx in range(n):
            child_id, child_params = parent_req.get_child_info(idx)
            child_request = request if idx == n - 1 else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params

            # Make a new RequestState and queue.
            self.output_processor.add_request(child_request, prompt_str,
                                              parent_req, idx)
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)

    def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:
        """Run one engine iteration and return the processed outputs.

        When a dummy batch was requested by :meth:`has_unfinished_requests_dp`,
        that batch is executed instead and ``[]`` is returned.
        """
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        iteration_stats = IterationStats() if self.log_stats else None
        processed_outputs = self.output_processor.process_outputs(
            outputs.outputs,
            engine_core_timestamp=outputs.timestamp,
            iteration_stats=iteration_stats)

        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        # 4) Record stats
        if self.stat_logger is not None:
            assert outputs.scheduler_stats is not None
            self.stat_logger.record(scheduler_stats=outputs.scheduler_stats,
                                    iteration_stats=iteration_stats)

        return processed_outputs.request_outputs

    def get_vllm_config(self):
        """Return the ``VllmConfig`` this engine was built with."""
        return self.vllm_config

    def get_model_config(self):
        """Return the model sub-config of the engine's ``VllmConfig``."""
        return self.model_config

    def start_profile(self):
        """Ask EngineCore to start profiling."""
        self.engine_core.profile(True)

    def stop_profile(self):
        """Ask EngineCore to stop profiling."""
        self.engine_core.profile(False)

    def reset_mm_cache(self):
        """Clear the multimodal caches in the processor and EngineCore."""
        self.processor.clear_cache()
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(self, device: Optional[Device] = None):
        """Reset EngineCore's prefix cache.

        ``device`` is accepted for API compatibility and is ignored.
        """
        self.engine_core.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Forward a sleep request at ``level`` to EngineCore."""
        self.engine_core.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Forward a wake-up request (optionally limited to ``tags``)."""
        self.engine_core.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return EngineCore's sleeping state."""
        return self.engine_core.is_sleeping()

    def get_metrics(self) -> list[Metric]:
        """Return a snapshot of the collected metrics.

        Requires ``log_stats`` to be enabled (checked with ``assert``).
        """
        assert self.log_stats, "Stat logging disabled"
        return get_metrics_snapshot()

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer.

        Raises:
            ValueError: If the engine was created with
                ``skip_tokenizer_init``.
        """
        if self.tokenizer is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

        return self.tokenizer

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)

    def collective_rpc(self,
                       method: Union[str, Callable[[WorkerBase], _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Invoke ``method`` on the workers via EngineCore; return results."""
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Run ``func`` through the workers' ``apply_model`` RPC."""
        return self.collective_rpc("apply_model", args=(func, ))

    def __del__(self):
        """Destroy the stateless data-parallel group, if one was created.

        ``getattr`` is used because ``__init__`` may have raised before
        ``dp_group`` was assigned.
        """
        dp_group = getattr(self, "dp_group", None)
        # Explicit None check: don't rely on the group object's truthiness.
        if dp_group is not None:
            stateless_destroy_torch_distributed_process_group(dp_group)