# SPDX-License-Identifier: Apache-2.0

3
from typing import Dict, List, Mapping, Optional, Type, Union
4

5
6
from typing_extensions import TypeVar

7
import vllm.envs as envs
8
from vllm.config import ParallelConfig, VllmConfig
9
10
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
11
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
12
13
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
14
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
15
from vllm.outputs import RequestOutput
16
17
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
18
from vllm.sampling_params import SamplingParams
19
20
from vllm.transformers_utils.tokenizer_group import (
    BaseTokenizerGroup, init_tokenizer_from_configs)
21
from vllm.usage.usage_lib import UsageContext
22
from vllm.v1.engine.core_client import EngineCoreClient
23
from vllm.v1.engine.output_processor import OutputProcessor
24
from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager
25
from vllm.v1.engine.processor import Processor
26
from vllm.v1.executor.abstract import Executor
27
28
29

# Module-level logger, keyed by this module's import name.
logger = init_logger(__name__)

# Type variable for tokenizer-group classes; it is bounded by (and defaults
# to) BaseTokenizerGroup and is used by LLMEngine.get_tokenizer_group.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)

32
33

class LLMEngine:
34
    """Legacy LLMEngine for backwards compatibility."""
35
36
37

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Wire up the tokenizer, the input/output processors and EngineCore.

        Args:
            vllm_config: Aggregate configuration. Its model, cache, parallel,
                scheduler and LoRA sub-configs are read here.
            executor_class: Executor type handed to the EngineCore client.
            log_stats: Accepted for interface compatibility; this constructor
                does not read it (stat logging is passed as disabled below).
            usage_context: Accepted but not read by this constructor.
            stat_loggers: Accepted but not read by this constructor.
            input_registry: Forwarded to the Processor.
            mm_registry: Forwarded to the Processor.
            use_cached_outputs: Accepted but not read by this constructor.
            multiprocess_mode: Forwarded to EngineCoreClient.make_client.
                When false, ``self.model_executor`` is additionally exposed.
        """
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        parallel_config = vllm_config.parallel_config
        lora_config = vllm_config.lora_config

        self.vllm_config = vllm_config
        self.model_config = model_config
        self.cache_config = cache_config

        # Tracks requests that were fanned out for parallel sampling.
        self.parallel_manager = SyncParallelSamplingManager()

        # Ordering matters: the data-parallel group has to be initialised
        # before the EngineCore client is created further down.
        self.parallel_config = parallel_config
        self.dp_enabled = parallel_config.data_parallel_size > 1  # noqa
        self.should_execute_dummy_batch = False
        if self.dp_enabled:
            # NOTE: dp_group is only defined when data parallelism is on.
            self.dp_group = parallel_config.stateless_init_dp_group()

        # Tokenizer; ping() checks it is alive if it runs in another process.
        self.tokenizer = init_tokenizer_from_configs(
            model_config=model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=parallel_config,
            lora_config=lora_config)
        self.tokenizer.ping()

        # Processor: turns raw inputs into EngineCoreRequests.
        self.processor = Processor(model_config=model_config,
                                   cache_config=cache_config,
                                   lora_config=lora_config,
                                   tokenizer=self.tokenizer,
                                   input_registry=input_registry,
                                   mm_registry=mm_registry)

        # OutputProcessor: turns EngineCoreOutputs into RequestOutputs.
        self.output_processor = OutputProcessor(self.tokenizer,
                                                log_stats=False)

        # EngineCore client: receives EngineCoreRequests and yields
        # EngineCoreOutputs.
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,  # FIXME: implement
        )

        if not multiprocess_mode:
            # Kept for v0 compatibility.
            self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

95
96
97
98
99
100
    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Build an engine from parsed engine arguments.

        Args:
            engine_args: Source of the engine configuration and of the
                ``disable_log_stats`` flag.
            usage_context: Passed to ``create_engine_config`` and on to the
                constructor.
            stat_loggers: Passed through to the constructor.
            enable_multiprocessing: Requested multiprocess mode. The
                ``VLLM_ENABLE_V1_MULTIPROCESSING`` environment setting can
                only force this on; it never turns it off.

        Returns:
            A new instance of ``cls``.
        """
        # Derive the configuration and the executor type that matches it.
        config = engine_args.create_engine_config(usage_context)
        executor_cls = Executor.get_class(config)

        # The environment switch takes precedence over the caller's choice.
        use_multiprocessing = enable_multiprocessing
        if envs.VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            use_multiprocessing = True

        return cls(
            vllm_config=config,
            executor_class=executor_cls,
            log_stats=not engine_args.disable_log_stats,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
            multiprocess_mode=use_multiprocessing,
        )

    def get_num_unfinished_requests(self) -> int:
        """Return the unfinished-request count from the parallel manager.

        The output processor's own count is passed to the parallel-sampling
        manager, and whatever the manager returns is the result.
        """
        pending_in_output_processor = (
            self.output_processor.get_num_unfinished_requests())
        return self.parallel_manager.get_num_unfinished_requests(
            pending_in_output_processor)
124
125

    def has_unfinished_requests(self) -> bool:
        """Report whether any request is still in flight.

        Without data parallelism this is the output processor's answer. With
        it, the local answer is handed to ``has_unfinished_requests_dp`` and
        that method's result is returned instead.
        """
        local_pending = self.output_processor.has_unfinished_requests()
        if self.dp_enabled:
            return self.has_unfinished_requests_dp(local_pending)
        return local_pending

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        """Combine the local flag with the data-parallel group's answer.

        Args:
            has_unfinished: Whether this engine has unfinished requests.

        Returns:
            The value produced by ``ParallelConfig.has_unfinished_dp`` for
            ``self.dp_group``.

        Side effect: when this engine has nothing pending but the combined
        result is true, ``should_execute_dummy_batch`` is set so that the
        next ``step()`` runs a dummy batch (presumably to stay in lockstep
        with the other ranks; confirm against ParallelConfig).
        """
        group_pending = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished)
        if group_pending and not has_unfinished:
            self.should_execute_dummy_batch = True
        return group_pending
137
138
139
140
141
142
143
144
145

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        return outputs

    def abort_request(self, request_ids: List[str]) -> None:
        """Abort the given requests in the EngineCore and output processor.

        Args:
            request_ids: IDs of the requests to abort. The same list is
                handed first to the EngineCore client, then to the output
                processor.
        """
        for component in (self.engine_core, self.output_processor):
            component.abort_requests(request_ids)
147

148
149
150
151
152
153
154
155
156
157
158
    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Queue a request, fanning it out when parallel sampling is asked.

        If ``params`` is ``None``, is a ``PoolingParams`` or has ``n == 1``,
        the request goes straight to ``_add_request``. Otherwise it is given
        to the parallel-sampling manager, which receives ``_add_request`` as
        its ``add_request`` callback.

        All arguments are forwarded by keyword, unchanged.
        """
        request_fields = {
            "request_id": request_id,
            "prompt": prompt,
            "params": params,
            "arrival_time": arrival_time,
            "lora_request": lora_request,
            "trace_headers": trace_headers,
            "prompt_adapter_request": prompt_adapter_request,
            "priority": priority,
        }

        # The checks run in this order so that ``params.n`` is only read
        # for non-None, non-pooling params.
        wants_parallel_sampling = (params is not None
                                   and not isinstance(params, PoolingParams)
                                   and params.n != 1)
        if wants_parallel_sampling:
            self.parallel_manager.add_request_parallel_sampling(
                add_request=self._add_request, **request_fields)
            return

        self._add_request(**request_fields)

    def _add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Add request, `n=1`"""
189
        # 1) Process raw inputs into the request.
190
191
192
193
194
        request = self.processor.process_inputs(request_id, prompt, params,
                                                arrival_time, lora_request,
                                                trace_headers,
                                                prompt_adapter_request,
                                                priority)
195

196
197
        # 2) Make a new RequestState and queue.
        self.output_processor.add_request(request)
198

199
        # 3) Add the request to EngineCore.
200
        self.engine_core.add_request(request)
201
202
203

    def step(self) -> List[RequestOutput]:
        """Run one engine iteration and return the resulting outputs.

        If a dummy batch was requested (``should_execute_dummy_batch``), the
        flag is cleared, the EngineCore runs a dummy batch and an empty list
        is returned. Otherwise the EngineCore's output is fetched and run
        through the output processor, the requests it lists in
        ``reqs_to_abort`` are aborted in the EngineCore, and the request
        outputs are passed through the parallel-sampling manager.
        """
        if self.should_execute_dummy_batch:
            # Clear the flag first so the dummy batch runs only once.
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # Fetch raw outputs and convert them.
        core_outputs = self.engine_core.get_output()
        processed = self.output_processor.process_outputs(
            core_outputs.outputs)

        # Requests that finished because of stop strings still need to be
        # aborted on the EngineCore side.
        self.engine_core.abort_requests(processed.reqs_to_abort)

        # The parallel-sampling manager handles its unfinished requests.
        return self.parallel_manager.step(processed.request_outputs)
223

224
    def get_model_config(self):
        """Return the model configuration stored on this engine."""
        model_config = self.model_config
        return model_config
226

227
    def start_profile(self):
        """Ask the EngineCore client to turn profiling on."""
        core = self.engine_core
        core.profile(True)
229

230
    def stop_profile(self):
        """Ask the EngineCore client to turn profiling off."""
        core = self.engine_core
        core.profile(False)
232

233
234
235
    def reset_prefix_cache(self):
        """Forward a prefix-cache reset to the EngineCore client."""
        core = self.engine_core
        core.reset_prefix_cache()

236
237
238
239
240
241
    def sleep(self, level: int = 1):
        """Forward a sleep request to the EngineCore client.

        Args:
            level: Passed through unchanged to ``engine_core.sleep``; its
                meaning is defined by the EngineCore client.
        """
        core = self.engine_core
        core.sleep(level)

    def wake_up(self):
        """Forward a wake-up request to the EngineCore client."""
        core = self.engine_core
        core.wake_up()

242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
    ) -> _G:
        """Return the engine's tokenizer group, checked against a type.

        Args:
            group_type: Class the tokenizer group must be an instance of.

        Returns:
            ``self.tokenizer``.

        Raises:
            ValueError: If ``self.tokenizer`` is ``None``.
            TypeError: If ``self.tokenizer`` is not a ``group_type``.
        """
        group = self.tokenizer

        if group is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")

        if isinstance(group, group_type):
            return group

        raise TypeError("Invalid type of tokenizer group. "
                        f"Expected type: {group_type}, but "
                        f"found type: {type(group)}")