async_llm.py 13 KB
Newer Older
1
2
3
4
5
6
7
8
import asyncio
from typing import AsyncGenerator, Dict, List, Mapping, Optional, Type, Union

from vllm.config import ModelConfig, VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.engine.protocol import EngineClient
from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
9
from vllm.inputs.preprocess import InputPreprocessor
10
11
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
12
from vllm.outputs import RequestOutput
13
14
15
16
17
18
19
20
21
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.detokenizer import Detokenizer
from vllm.v1.engine.processor import Processor
22
from vllm.v1.executor.abstract import Executor
23
24
25
26
27
28
29
30
31

# Module-level logger, created via vLLM's `init_logger` helper and keyed by
# this module's name.
logger = init_logger(__name__)


class AsyncLLM(EngineClient):
    """Asyncio front-end for the V1 engine.

    Input processing and detokenization run in this process, while the
    ``EngineCore`` is created with ``multiprocess_mode=True`` (a background
    process). Every in-flight request owns an ``asyncio.Queue`` stored in
    ``rid_to_queue``; a background output-handler task fills these queues
    and ``generate()`` drains them.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        use_cached_outputs: bool = False,
        log_requests: bool = True,
        start_engine_loop: bool = True,
    ) -> None:
        """Build the tokenizer, processor, detokenizer and EngineCore client.

        Args:
            vllm_config: Full engine configuration.
            executor_class: Executor the EngineCore should use.
            log_stats: Forwarded to the EngineCore client.
            usage_context: Accepted for interface compatibility; not used
                by this constructor.
            stat_loggers: Stored on ``self.stat_loggers``.
            input_registry: Passed to the ``Processor``.
            use_cached_outputs: Accepted for interface compatibility; not
                used by this constructor.
            log_requests: If True, ``add_request`` logs each new request.
            start_engine_loop: Must be True (only supported mode); the
                output handler itself is started lazily by ``generate()``.
        """
        assert start_engine_loop

        self.log_requests = log_requests
        self.log_stats = log_stats
        self.stat_loggers = stat_loggers
        self.model_config = vllm_config.model_config

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)
        self.tokenizer.ping()

        # Request streams (map of request_id -> queue).
        self.rid_to_queue: Dict[str, asyncio.Queue] = {}

        # Processor (converts Inputs --> EngineCoreRequests).
        self.processor = Processor(
            model_config=vllm_config.model_config,
            cache_config=vllm_config.cache_config,
            lora_config=vllm_config.lora_config,
            tokenizer=self.tokenizer,
            input_registry=input_registry,
        )

        # Detokenizer (converts EngineCoreOutputs --> RequestOutput).
        self.detokenizer = Detokenizer(
            tokenizer_name=vllm_config.model_config.tokenizer,
            tokenizer_mode=vllm_config.model_config.tokenizer_mode,
            trust_remote_code=vllm_config.model_config.trust_remote_code,
            revision=vllm_config.model_config.tokenizer_revision,
        )

        # EngineCore (starts the engine in background process).
        self.engine_core = EngineCoreClient.make_client(
            multiprocess_mode=True,
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        # Created on the first generate() call (see generate()).
        self.output_handler: Optional[asyncio.Task] = None

    def __del__(self):
        self.shutdown()

    @classmethod
    def from_engine_args(
        cls,
        engine_args: AsyncEngineArgs,
        engine_config: Optional[VllmConfig] = None,
        start_engine_loop: bool = True,
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
    ) -> "AsyncLLM":
        """Create an AsyncLLM from the EngineArgs.

        If ``engine_config`` is given it is used as-is; otherwise a config
        is created from ``engine_args`` for ``usage_context``.
        """

        # Create the engine configs.
        if engine_config is None:
            vllm_config = engine_args.create_engine_config(usage_context)
        else:
            vllm_config = engine_config

        executor_class = cls._get_executor_cls(vllm_config)

        # Create the AsyncLLM.
        return cls(
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_requests=not engine_args.disable_log_requests,
            log_stats=not engine_args.disable_log_stats,
            start_engine_loop=start_engine_loop,
            usage_context=usage_context,
            stat_loggers=stat_loggers,
        )

    def shutdown(self):
        """Shutdown, cleaning up the background proc and IPC.

        Uses ``getattr`` so it is safe to call (e.g. from ``__del__``) on a
        partially-constructed instance whose ``__init__`` raised early.
        """

        if engine_core := getattr(self, "engine_core", None):
            engine_core.shutdown()

        if handler := getattr(self, "output_handler", None):
            handler.cancel()

    @classmethod
    def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]:
        """Select the V1 executor class for the configured backend.

        ``"mp"`` selects ``MultiprocExecutor``; ``None`` selects
        ``UniprocExecutor``.

        Raises:
            ValueError: for any other ``distributed_executor_backend``.
                (An explicit raise rather than ``assert`` so the check is
                not stripped under ``python -O``.)
        """
        executor_class: Type[Executor]
        distributed_executor_backend = (
            vllm_config.parallel_config.distributed_executor_backend)
        if distributed_executor_backend == "mp":
            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
            executor_class = MultiprocExecutor
        elif distributed_executor_backend is None:
            from vllm.v1.executor.uniproc_executor import UniprocExecutor
            executor_class = UniprocExecutor
        else:
            raise ValueError(
                "Unsupported distributed_executor_backend for V1 AsyncLLM: "
                f"{distributed_executor_backend!r}")
        return executor_class

    async def add_request(
        self,
        request_id: str,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> asyncio.Queue[RequestOutput]:
        """Add new request to the AsyncLLM.

        Returns:
            The per-request queue into which the output handler puts this
            request's ``RequestOutput`` objects.

        Raises:
            ValueError: if ``request_id`` is already in flight. Any error
                raised while processing or submitting the request is
                propagated after this request's local state is cleaned up.
        """

        # 1) Reject duplicate request ids.
        if request_id in self.rid_to_queue:
            raise ValueError(f"Request id {request_id} already running.")

        # 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
        #    This runs BEFORE the queue is registered so that a processing
        #    failure (e.g. an invalid prompt) leaves no stale entry behind;
        #    otherwise the id would be reported "already running" forever.
        #    There is no await between the check above and the
        #    registration below, so no other task can claim the id.
        detokenizer_req, engine_core_req = self.processor.process_inputs(
            request_id, prompt, params, arrival_time, lora_request,
            trace_headers, prompt_adapter_request, priority)

        # 3) Register the output queue before the request can produce
        #    outputs: the output handler drops outputs for unknown ids.
        queue: asyncio.Queue = asyncio.Queue()
        self.rid_to_queue[request_id] = queue

        try:
            # 4) Add the request to Detokenizer (this process).
            self.detokenizer.add_request(detokenizer_req)

            # 5) Add the EngineCoreRequest to EngineCore (separate process).
            await self.engine_core.add_request_async(engine_core_req)
        except Exception:
            # Roll back local state so a failed submission does not leave a
            # queue / detokenizer entry that nothing will ever finish.
            # (Cancellation is not caught here; generate() aborts on it.)
            self.rid_to_queue.pop(request_id, None)
            self.detokenizer.abort_requests([request_id])
            raise

        if self.log_requests:
            logger.info("Added request %s.", request_id)

        return queue

    # TODO: we should support multiple prompts in one call, as you
    # can do with LLM.generate. So that for multi-prompt completion
    # requests we don't need to send multiple messages to core proc,
    # and so we don't need multiple streams which then get
    # re-multiplexed in the API server anyhow.
    async def generate(
        self,
        prompt: PromptType,
        sampling_params: SamplingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> AsyncGenerator[RequestOutput, None]:
        """
        Main function called by the API server to kick off a request
            * 1) Creating an output asyncio.Queue for the request.
            * 2) Processing the Input.
            * 3) Adding the Request to the Detokenizer.
            * 4) Adding the Request to the EngineCore (separate process).

        A separate output_handler loop runs in a background asyncio task,
        pulling outputs from EngineCore and putting them into the
        per-request queue.

        The caller of generate() iterates the returned AsyncGenerator,
        receiving each RequestOutput in turn; iteration stops after the
        output marked ``finished``. If the task running the generator is
        cancelled, the request is aborted and the cancellation re-raised.
        """

        try:
            # We start the output_handler on the first call to generate() so
            # we can call __init__ before the event loop, which enables us
            # to handle startup failure gracefully in the OpenAI server.
            if self.output_handler is None:
                self.output_handler = asyncio.create_task(
                    self._run_output_handler())

            q = await self.add_request(
                request_id,
                prompt,
                sampling_params,
                lora_request=lora_request,
                trace_headers=trace_headers,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority,
            )

            # The output_handler task pushes items into the queue.
            # This task pulls from the queue and yields to caller.
            while True:
                # Note: drain queue without await if possible (avoids
                # task switching under load which helps performance).
                out = q.get_nowait() if q.qsize() > 0 else await q.get()

                # Note: both Detokenizer and EngineCore handle their
                # own request cleanup based on finished.
                if out.finished:
                    # pop() rather than del: an abort() may already have
                    # removed the entry while the final output was queued.
                    self.rid_to_queue.pop(request_id, None)
                    yield out
                    break

                yield out

        # If the request is disconnected by the client, the
        # generate() task will be canceled. So, we abort the
        # request if we end up here.
        except asyncio.CancelledError:
            await self.abort(request_id)
            raise

    def _process_request_outputs(self, request_outputs: List[RequestOutput]):
        """Process outputs by putting them into per-request queues.

        Outputs whose request id has no registered queue are dropped. For
        example, the request may have been aborted and removed due to a
        client cancellation.
        """

        for request_output in request_outputs:
            queue = self.rid_to_queue.get(request_output.request_id)
            if queue is not None:
                queue.put_nowait(request_output)

    async def _run_output_handler(self):
        """Background loop: pulls from EngineCore and pushes to queues.

        Runs until cancelled (``shutdown()`` cancels it) or until an error
        occurs, in which case the error is logged with its traceback and
        re-raised.
        """

        try:
            while True:
                # 1) Pull EngineCoreOutput from the EngineCore.
                outputs = await self.engine_core.get_output_async()

                # 2) Detokenize based on the output.
                request_outputs, reqs_to_abort = self.detokenizer.step(outputs)

                # 3) Put the RequestOutputs into the per-request queues.
                self._process_request_outputs(request_outputs)

                # 4) Abort any requests that finished due to stop strings.
                #    Skip the IPC round-trip when there is nothing to abort.
                if reqs_to_abort:
                    await self.engine_core.abort_requests_async(reqs_to_abort)

        except asyncio.CancelledError:
            # Expected on shutdown(); not an error worth logging.
            raise
        except BaseException:
            logger.exception("AsyncLLM output handler failed.")
            raise

    async def abort(self, request_id: str) -> None:
        """Abort RequestId in self, detokenizer, and engine core."""

        request_ids = [request_id]
        await self.engine_core.abort_requests_async(request_ids)
        self.detokenizer.abort_requests(request_ids)

        # If a request finishes while we await then the request_id
        # will be removed from the tracked queues before we get here.
        self.rid_to_queue.pop(request_id, None)

    def encode(
        self,
        prompt: PromptType,
        pooling_params: PoolingParams,
        request_id: str,
        lora_request: Optional[LoRARequest] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        priority: int = 0,
    ):
        """Pooling/embedding requests are not implemented in V1 yet."""
        raise ValueError("Not Supported on V1 yet.")

    async def get_model_config(self) -> ModelConfig:
        """Return the model config captured at construction time."""
        return self.model_config

    async def get_decoding_config(self):
        """Not implemented in V1 yet."""
        raise ValueError("Not Supported on V1 yet.")

    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Return the input preprocessor owned by the Processor."""
        return self.processor.input_preprocessor

    async def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the detokenizer's tokenizer (LoRA not supported here)."""
        assert lora_request is None
        return self.detokenizer.tokenizer

    async def is_tracing_enabled(self) -> bool:
        return False

    async def do_log_stats(
        self,
        scheduler_outputs=None,
        model_output=None,
    ) -> None:
        logger.debug("Called do_log_stats.")

    async def check_health(self) -> None:
        logger.debug("Called check_health.")

    async def start_profile(self) -> None:
        """Ask the EngineCore to start profiling."""
        await self.engine_core.profile_async(True)

    async def stop_profile(self) -> None:
        """Ask the EngineCore to stop profiling."""
        await self.engine_core.profile_async(False)

    @property
    def is_running(self) -> bool:
        return True

    @property
    def is_stopped(self) -> bool:
        return False

    @property
    def errored(self) -> bool:
        return False

    @property
    def dead_error(self) -> BaseException:
        return Exception()  # TODO: implement