# llm_engine.py
# SPDX-License-Identifier: Apache-2.0

3
from collections.abc import Mapping
4
from copy import copy
5
from typing import Any, Callable, Optional, Union
6

7
8
from typing_extensions import TypeVar

9
import vllm.envs as envs
10
from vllm.config import ParallelConfig, VllmConfig
11
from vllm.distributed import stateless_destroy_torch_distributed_process_group
12
13
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
14
from vllm.inputs import PromptType
15
16
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
17
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
18
from vllm.outputs import RequestOutput
19
20
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
21
from vllm.sampling_params import SamplingParams
22
23
from vllm.transformers_utils.tokenizer_group import (
    BaseTokenizerGroup, init_tokenizer_from_configs)
24
from vllm.usage.usage_lib import UsageContext
25
from vllm.utils import Device
26
from vllm.v1.engine.core_client import EngineCoreClient
27
from vllm.v1.engine.output_processor import OutputProcessor
28
from vllm.v1.engine.parallel_sampling import ParentRequest
29
from vllm.v1.engine.processor import Processor
30
from vllm.v1.executor.abstract import Executor
31
from vllm.v1.utils import report_usage_stats
32
33
34

logger = init_logger(__name__)

# Type of the tokenizer group returned by ``LLMEngine.get_tokenizer_group``;
# defaults to ``BaseTokenizerGroup`` when the caller passes no ``group_type``.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
# Per-worker result type of ``LLMEngine.collective_rpc``.
_R = TypeVar("_R", default=Any)


class LLMEngine:
    """Legacy LLMEngine for backwards compatibility.

    Exposes the older ``LLMEngine`` API on top of three V1 components that
    are created in ``__init__``:

    * ``Processor`` converts raw inputs into EngineCoreRequests,
    * ``EngineCoreClient`` receives those requests and produces
      EngineCoreOutputs,
    * ``OutputProcessor`` converts EngineCoreOutputs into RequestOutputs.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Create the tokenizer, input/output processors and engine core.

        Args:
            vllm_config: Complete engine configuration.
            executor_class: Executor class handed to the engine-core client.
            log_stats: Accepted for API compatibility; not read here (the
                output processor and engine core are built with
                ``log_stats=False``).
            usage_context: Context passed to ``report_usage_stats``.
            stat_loggers: Accepted for API compatibility; not read here.
            mm_registry: Multi-modal registry given to the ``Processor``.
            use_cached_outputs: Accepted for API compatibility; not read
                here.
            multiprocess_mode: Forwarded to ``EngineCoreClient.make_client``.
                When false, the in-process engine core's ``model_executor``
                is also exposed as ``self.model_executor``.

        Raises:
            ValueError: If ``envs.VLLM_USE_V1`` is false.
        """
        if not envs.VLLM_USE_V1:
            raise ValueError(
                "Using V1 LLMEngine, but envs.VLLM_USE_V1=False. "
                "This should not happen. As a workaround, try using "
                "LLMEngine.from_vllm_config(...) or explicitly set "
                "VLLM_USE_V1=0 or 1 and report this issue on Github.")

        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        # important: init dp group before init the engine_core
        # In the decoupled engine case this is handled in EngineCoreProc.
        parallel_config = vllm_config.parallel_config
        if not multiprocess_mode and parallel_config.data_parallel_size > 1:
            self.dp_group = parallel_config.stateless_init_dp_group()
        else:
            self.dp_group = None
        # Set by has_unfinished_requests_dp(); consumed by step().
        self.should_execute_dummy_batch = False

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)
        self.tokenizer.ping()

        # Processor (convert Inputs --> EngineCoreRequests)
        self.processor = Processor(vllm_config=vllm_config,
                                   tokenizer=self.tokenizer,
                                   mm_registry=mm_registry)

        # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=False)

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,  # FIXME: implement
        )

        if not multiprocess_mode:
            # for v0 compatibility
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

        # If usage stat is enabled, collect relevant info.
        report_usage_stats(vllm_config, usage_context)

    @classmethod
    def from_vllm_config(
        cls,
        vllm_config: VllmConfig,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
        disable_log_stats: bool = False,
    ) -> "LLMEngine":
        """Create an engine from an existing ``VllmConfig``.

        The executor class comes from ``Executor.get_class`` and the
        multiprocessing mode from ``envs.VLLM_ENABLE_V1_MULTIPROCESSING``.

        Raises:
            NotImplementedError: If ``stat_loggers`` is not None.
        """
        if stat_loggers is not None:
            raise NotImplementedError(
                "Passing StatLoggers to V1 is not yet supported. "
                "Set VLLM_USE_V1=0 and file an issue on Github.")

        return cls(vllm_config=vllm_config,
                   executor_class=Executor.get_class(vllm_config),
                   log_stats=(not disable_log_stats),
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=envs.VLLM_ENABLE_V1_MULTIPROCESSING)

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[dict[str, StatLoggerBase]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        ``envs.VLLM_ENABLE_V1_MULTIPROCESSING`` forces multiprocessing on
        regardless of ``enable_multiprocessing``. Unlike
        ``from_vllm_config``, ``stat_loggers`` is not rejected here; it is
        passed through to ``__init__``, which does not read it.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = Executor.get_class(vllm_config)

        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(vllm_config=vllm_config,
                   executor_class=executor_class,
                   log_stats=not engine_args.disable_log_stats,
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=enable_multiprocessing)

    def get_num_unfinished_requests(self) -> int:
        """Return the output processor's count of unfinished requests."""
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Return whether any request is still in flight.

        Without a data-parallel group this is a purely local answer;
        otherwise the local answer is combined across the DP group.
        """
        has_unfinished = self.output_processor.has_unfinished_requests()
        if self.dp_group is None:
            return has_unfinished
        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Combine the local ``has_unfinished`` flag across the DP group.

        When this rank has nothing left but the combined result says there
        is still work, ``should_execute_dummy_batch`` is set so the next
        ``step()`` runs a dummy batch instead of fetching outputs.
        NOTE(review): presumably this keeps idle ranks participating in the
        DP step — confirm against ``ParallelConfig.has_unfinished_dp``.
        """
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished)
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """Return ``outputs`` unchanged; ``output_type`` is ignored."""
        return outputs

    def abort_request(self, request_ids: list[str]) -> None:
        """Remove request_ids from EngineCore and Detokenizer."""
        # Forward to the engine core exactly the ids the output processor
        # reports back from its own abort.
        request_ids = self.output_processor.abort_requests(request_ids)
        self.engine_core.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Process a prompt and submit it to the engine core.

        For ``SamplingParams`` with ``n > 1`` the request is fanned out
        into ``n`` child requests tracked under one ``ParentRequest``;
        otherwise a single request is submitted.
        """
        # Process raw inputs into the request.
        request = self.processor.process_inputs(request_id, prompt, params,
                                                arrival_time, lora_request,
                                                trace_headers,
                                                prompt_adapter_request,
                                                priority)

        n = params.n if isinstance(params, SamplingParams) else 1

        if n == 1:
            # Make a new RequestState and queue.
            self.output_processor.add_request(request, None, 0)
            # Add the request to EngineCore.
            self.engine_core.add_request(request)
            return

        # Fan out child requests (for n>1).
        parent_req = ParentRequest(request_id, params)
        for idx in range(n):
            child_id, child_params = parent_req.get_child_info(idx)
            # Shallow-copy for all but the last child, which reuses the
            # original request object instead of copying it.
            child_request = request if idx == n - 1 else copy(request)
            child_request.request_id = child_id
            child_request.sampling_params = child_params

            # Make a new RequestState and queue.
            self.output_processor.add_request(child_request, parent_req, idx)
            # Add the request to EngineCore.
            self.engine_core.add_request(child_request)

    def step(self) -> list[RequestOutput]:
        """Run one engine iteration and return the resulting outputs.

        If a dummy batch was requested by ``has_unfinished_requests_dp``,
        only that dummy batch is executed and an empty list is returned.
        """
        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        outputs = self.engine_core.get_output()

        # 2) Process EngineCoreOutputs.
        processed_outputs = self.output_processor.process_outputs(
            outputs.outputs)

        # 3) Abort any reqs that finished due to stop strings.
        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

        return processed_outputs.request_outputs

    def get_vllm_config(self):
        """Return the ``VllmConfig`` this engine was built with."""
        return self.vllm_config

    def get_model_config(self):
        """Return the model config taken from ``vllm_config``."""
        return self.model_config

    def start_profile(self):
        """Ask the engine core to start profiling."""
        self.engine_core.profile(True)

    def stop_profile(self):
        """Ask the engine core to stop profiling."""
        self.engine_core.profile(False)

    def reset_prefix_cache(self, device: Optional[Device] = None):
        """Reset the engine core's prefix cache.

        ``device`` is accepted for API compatibility and is not used.
        """
        self.engine_core.reset_prefix_cache()

    def sleep(self, level: int = 1):
        """Forward a sleep request at ``level`` to the engine core."""
        self.engine_core.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """Forward a wake-up request (optionally limited to ``tags``)."""
        self.engine_core.wake_up(tags)

    def is_sleeping(self) -> bool:
        """Return the engine core's sleeping state."""
        return self.engine_core.is_sleeping()

    def get_tokenizer_group(
        self,
        group_type: type[_G] = BaseTokenizerGroup,
    ) -> _G:
        """Return the tokenizer group, checked against ``group_type``.

        Raises:
            ValueError: If no tokenizer is available.
            TypeError: If the tokenizer group is not a ``group_type``.
        """
        tokenizer_group = self.tokenizer

        # NOTE(review): __init__ in this class always assigns a tokenizer
        # (and calls ping() on it), so this branch looks defensive.
        if tokenizer_group is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
        if not isinstance(tokenizer_group, group_type):
            raise TypeError("Invalid type of tokenizer group. "
                            f"Expected type: {group_type}, but "
                            f"found type: {type(tokenizer_group)}")

        return tokenizer_group

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a new LoRA adapter into the engine for future requests."""
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove an already loaded LoRA adapter."""
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        """List all registered adapters."""
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        """Prevent an adapter from being evicted."""
        return self.engine_core.pin_lora(lora_id)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Forward a collective RPC (by name or callable) to the engine core
        and return its list of results."""
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

    def __del__(self):
        # ``dp_group`` may not exist if __init__ raised before assigning it.
        dp_group = getattr(self, "dp_group", None)
        if dp_group is not None:
            # Clear first so a repeated call does not destroy it twice.
            self.dp_group = None
            stateless_destroy_torch_distributed_process_group(dp_group)