llm_engine.py 8.03 KB
Newer Older
1
from typing import Dict, List, Mapping, Optional, Type, Union
2

3
4
from typing_extensions import TypeVar

5
from vllm.config import VllmConfig
6
7
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
8
9
from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
10
11
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
12
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
13
from vllm.outputs import RequestOutput
14
15
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
16
from vllm.sampling_params import SamplingParams
17
18
from vllm.transformers_utils.tokenizer_group import (
    BaseTokenizerGroup, init_tokenizer_from_configs)
19
from vllm.usage.usage_lib import UsageContext
20
21
22
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
23
from vllm.v1.executor.abstract import Executor
24
25
26

# Module-level logger obtained from vLLM's logging helper.
logger = init_logger(__name__)

# Type variable used by LLMEngine.get_tokenizer_group; it is bounded by, and
# defaults to, BaseTokenizerGroup.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)

29
30

class LLMEngine:
    """Legacy LLMEngine for backwards compatibility."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        use_cached_outputs: bool = False,
        multiprocess_mode: bool = False,
    ) -> None:
        """Build the tokenizer, processor, detokenizer and engine-core client.

        Args:
            vllm_config: Aggregated configuration; its model, scheduler,
                parallel, LoRA and cache sub-configs are read here.
            executor_class: Executor type handed to the engine-core client.
            log_stats: Accepted but not read by this constructor (the
                engine-core client is always created with
                ``log_stats=False``).
            usage_context: Accepted but not read by this constructor.
            stat_loggers: Accepted but not read by this constructor.
            input_registry: Input registry passed to the ``Processor``.
            mm_registry: Multi-modal registry passed to the ``Processor``.
            use_cached_outputs: Accepted but not read by this constructor.
            multiprocess_mode: Forwarded to ``EngineCoreClient.make_client``.
        """
        # TODO: Can we avoid this?
        self.model_config = vllm_config.model_config

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)
        self.tokenizer.ping()

        # Processor (convert Inputs --> EngineCoreRequests)
        self.processor = Processor(model_config=vllm_config.model_config,
                                   cache_config=vllm_config.cache_config,
                                   lora_config=vllm_config.lora_config,
                                   tokenizer=self.tokenizer,
                                   input_registry=input_registry,
                                   mm_registry=mm_registry)

        # Detokenizer (converts EngineCoreOutputs --> RequestOutput)
        self.detokenizer = Detokenizer(
            tokenizer_name=vllm_config.model_config.tokenizer,
            tokenizer_mode=vllm_config.model_config.tokenizer_mode,
            trust_remote_code=vllm_config.model_config.trust_remote_code,
            revision=vllm_config.model_config.tokenizer_revision,
        )

        # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
        # NOTE: ``log_stats`` is hard-coded to False here; the constructor's
        # ``log_stats`` argument is not forwarded.
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=multiprocess_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,
        )

    @classmethod
    def from_engine_args(
        cls,
        engine_args: EngineArgs,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        enable_multiprocessing: bool = False,
    ) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments.

        Args:
            engine_args: Source of the engine configuration and of the
                ``disable_log_stats`` flag.
            usage_context: Passed to ``create_engine_config`` and to the
                constructor.
            stat_loggers: Passed through to the constructor.
            enable_multiprocessing: Requested multiprocess mode. It is
                overridden to True when ``VLLM_ENABLE_V1_MULTIPROCESSING``
                is truthy.

        Returns:
            A newly constructed engine instance.
        """
        # Create the engine configs.
        vllm_config = engine_args.create_engine_config(usage_context)
        executor_class = cls._get_executor_cls(vllm_config)

        if VLLM_ENABLE_V1_MULTIPROCESSING:
            logger.debug("Enabling multiprocessing for LLMEngine.")
            enable_multiprocessing = True

        # Create the LLMEngine.
        return cls(vllm_config=vllm_config,
                   executor_class=executor_class,
                   log_stats=not engine_args.disable_log_stats,
                   usage_context=usage_context,
                   stat_loggers=stat_loggers,
                   multiprocess_mode=enable_multiprocessing)

    @classmethod
    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
        """Select the executor class for the configured distributed backend.

        ``"ray"`` maps to ``RayExecutor``, ``"mp"`` to ``MultiprocExecutor``
        and ``None`` to ``UniprocExecutor``. The executor modules are
        imported only inside the branch that needs them.

        Raises:
            ValueError: If the backend is none of ``"ray"``, ``"mp"`` or
                ``None``.
        """
        executor_class: Type[Executor]
        distributed_executor_backend = (
            vllm_config.parallel_config.distributed_executor_backend)
        if distributed_executor_backend == "ray":
            from vllm.v1.executor.ray_executor import RayExecutor
            executor_class = RayExecutor
        elif distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        elif distributed_executor_backend is None:
            from vllm.v1.executor.uniproc_executor import UniprocExecutor
            executor_class = UniprocExecutor
        else:
            # Raise explicitly instead of using ``assert``: assertions are
            # removed under ``python -O``, which would silently select the
            # single-process executor for an unsupported backend.
            raise ValueError("Unknown distributed executor backend: "
                             f"{distributed_executor_backend!r}")

        return executor_class

    def get_num_unfinished_requests(self) -> int:
        """Return the detokenizer's count of unfinished requests."""
        return self.detokenizer.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
        """Return whether the detokenizer reports unfinished requests."""
        return self.detokenizer.has_unfinished_requests()

    @classmethod
    def validate_outputs(cls, outputs, output_type):
        """Return ``outputs`` unchanged; ``output_type`` is not inspected."""
        return outputs

    def abort_request(self, request_ids: List[str]) -> None:
        """Remove request_ids from EngineCore and Detokenizer."""

        self.engine_core.abort_requests(request_ids)
        self.detokenizer.abort_requests(request_ids)

    def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Process a raw prompt and register it for execution.

        All arguments are forwarded, in order, to
        ``self.processor.process_inputs``; the resulting request is then
        added to the detokenizer and to the engine core.
        """
        # 1) Process raw inputs into the request.
        request = self.processor.process_inputs(request_id, prompt, params,
                                                arrival_time, lora_request,
                                                trace_headers,
                                                prompt_adapter_request,
                                                priority)

        # 2) Add the request to Detokenizer.
        self.detokenizer.add_request(request)

        # 3) Add the request to EngineCore.
        self.engine_core.add_request(request)

    def step(self) -> List[RequestOutput]:
        """Fetch engine-core output, detokenize it and return the results.

        Requests that the detokenizer reports as needing an abort are
        aborted before returning.
        """
        # 1) Get EngineCoreOutput from the EngineCore.
        engine_core_outputs = self.engine_core.get_output()

        # 2) Detokenize the EngineCoreOutput.
        request_outputs, requests_to_abort = self.detokenizer.step(
            engine_core_outputs)

        # 3) Abort requests that finished due to stopping criteria.
        if requests_to_abort:
            self.abort_request(requests_to_abort)

        return request_outputs

    # TODO(rob): Can we get rid of these?

    def get_model_config(self):
        """Return the model config captured at construction time."""
        return self.model_config

    def start_profile(self):
        """Ask the engine core to start profiling (``profile(True)``)."""
        self.engine_core.profile(True)

    def stop_profile(self):
        """Ask the engine core to stop profiling (``profile(False)``)."""
        self.engine_core.profile(False)

    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
    ) -> _G:
        """Return the engine's tokenizer group, checked against a type.

        Args:
            group_type: Class the tokenizer group must be an instance of.

        Raises:
            ValueError: If ``self.tokenizer`` is None.
            TypeError: If the tokenizer group is not a ``group_type``
                instance.
        """
        tokenizer_group = self.tokenizer

        if tokenizer_group is None:
            raise ValueError("Unable to get tokenizer because "
                             "skip_tokenizer_init is True")
        if not isinstance(tokenizer_group, group_type):
            raise TypeError("Invalid type of tokenizer group. "
                            f"Expected type: {group_type}, but "
                            f"found type: {type(tokenizer_group)}")

        return tokenizer_group

    def __del__(self):
        """Shut the engine down when this object is finalized."""
        self.shutdown()

    def shutdown(self):
        """Shut down the engine-core client if one is present.

        ``getattr`` with a ``None`` default keeps this safe to call when
        ``engine_core`` was never assigned (for example when ``__init__``
        did not run to completion before ``__del__`` fires).
        """
        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown()