async_llm_engine.py 21.5 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
from functools import partial
4
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
5
                    Union, AsyncIterator)
6

7
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
8
9
10
11
12
13
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
14
15

# Module-level logger used by the request tracker and the async engine below.
logger = init_logger(__name__)


class AsyncEngineDeadError(RuntimeError):
    """Raised when the background engine loop has stopped or is not running."""


def _raise_exception_on_finish(task: asyncio.Task,
                               request_tracker: "RequestTracker") -> None:
    """Done-callback for the background engine task.

    The engine loop is written to run forever, so any completion other than
    cancellation is treated as fatal. The outcome is wrapped in an
    ``AsyncEngineDeadError``, pushed to every stream known to
    ``request_tracker``, and re-raised.

    Args:
        task: The background task that just completed.
        request_tracker: Tracker whose request streams receive the error.
    """
    base_msg = ("Task finished unexpectedly. This should never happen! "
                "Please open an issue on Github.")
    try:
        try:
            task.result()
        except asyncio.CancelledError:
            # Cancellation is a deliberate shutdown; nothing to report.
            return
        except Exception as cause:
            raise AsyncEngineDeadError(
                f"{base_msg} See stack trace above for the actual cause."
            ) from cause
        else:
            # The loop returned normally, which it is never supposed to do.
            raise AsyncEngineDeadError(base_msg)
    except Exception as dead_err:
        request_tracker.propagate_exception(dead_err)
        raise dead_err


class AsyncStream:
    """A stream of RequestOutputs for a request that can be
    iterated over asynchronously.

    Items are delivered in the order they were ``put``. An ``Exception``
    instance put into the stream is raised to the consumer instead of being
    yielded. After :meth:`finish`, further ``put`` calls are ignored.
    Iteration ends once the already-queued items have been consumed. An
    exhausted stream keeps signalling exhaustion on every later
    ``__anext__`` call.
    """

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False

    def put(self, item: Union["RequestOutput", Exception]) -> None:
        """Enqueue an output, or an exception to raise to the consumer.

        This is a no-op once the stream has been finished.
        """
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self) -> None:
        """Mark the stream as finished and wake up the consumer.

        Idempotent: only the first call enqueues the end-of-stream sentinel.
        """
        if self._finished:
            return
        # The StopIteration *class* is the end-of-stream sentinel. It is
        # checked by identity in __anext__, so it cannot be mistaken for a
        # real output or an exception instance.
        self._queue.put_nowait(StopIteration)
        self._finished = True

    @property
    def finished(self) -> bool:
        """Whether :meth:`finish` has been called."""
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> "RequestOutput":
        # finish() sets the flag and enqueues the sentinel together, so
        # "finished and empty" means the sentinel was already consumed.
        # Stop instead of blocking forever on the empty queue.
        if self._finished and self._queue.empty():
            raise StopAsyncIteration
        result = await self._queue.get()
        if result is StopIteration:
            raise StopAsyncIteration
        if isinstance(result, Exception):
            raise result
        return result


class RequestTracker:
    """Synchronous abstraction for tracking requests.

    ``add_request`` and ``abort_request`` only enqueue work. The background
    loop drains both queues once per iteration through
    ``get_new_and_finished_requests`` and hands the result to the engine.
    """

    def __init__(self) -> None:
        # Streams of requests that have been handed to the engine, by id.
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids of requests finished or aborted since the last drain.
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, kwargs for the engine's add_request) awaiting hand-off.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        # Created by init_event(); None until then.
        self.new_requests_event = None

    def __contains__(self, item):
        return item in self._request_streams

    def init_event(self):
        """Create the event that wakes the loop when new requests arrive."""
        self.new_requests_event = asyncio.Event()

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None).

        With ``request_id=None``, the path taken by the background-loop done
        callback, the streams of requests that were added but not yet handed
        to the engine also receive the exception. Without this, their
        consumers would wait forever. An unknown ``request_id`` is ignored.
        """
        if request_id is not None:
            stream = self._request_streams.get(request_id)
            if stream is not None:
                stream.put(exc)
            return

        for stream in self._request_streams.values():
            stream.put(exc)
        # Requests still waiting for hand-off would otherwise never get a
        # result; fail them the same way as the in-flight ones.
        while not self._new_requests.empty():
            pending_stream, _ = self._new_requests.get_nowait()
            pending_stream.put(exc)

    def process_request_output(self,
                               request_output: RequestOutput,
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        Forwards the output to the request's stream. Once the request is
        finished, schedules it for removal through ``abort_request``.
        """
        request_id = request_output.request_id

        self._request_streams[request_id].put(request_output)
        if request_output.finished:
            if verbose:
                logger.info(f"Finished request {request_id}.")
            self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Returns:
            The stream on which the request's outputs will be delivered.

        Raises:
            KeyError: If a request with this id has already been handed to
                the engine. Ids still pending hand-off are not checked.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        # Wake up the background loop if it is waiting for work.
        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The id is always queued for the engine. The stream, if tracked and not
        yet finished, is finished immediately so its consumer stops.
        """
        if verbose:
            logger.info(f"Aborted request {request_id}.")

        self._finished_requests.put_nowait(request_id)

        stream = self._request_streams.get(request_id)
        if stream is None or stream.finished:
            # The request has already finished or been aborted.
            return

        stream.finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Finished ids are drained first. A request aborted before it was ever
        handed to the engine is therefore dropped, and its stream finished,
        instead of being added. Also clears ``new_requests_event``.
        """
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        self.new_requests_event.clear()

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        """Block until ``add_request`` signals that work is available."""
        await self.new_requests_event.wait()

class _AsyncLLMEngine(LLMEngine):
    """LLMEngine subclass that adds coroutine-based stepping."""

    async def step_async(self) -> List[RequestOutput]:
        """Run one decoding iteration and return the newly generated results.

        The workers are run asynchronously if possible. The iteration:

        1. Asks the scheduler which sequences run next and which token blocks
           must be swapped in, swapped out, or copied.
        2. Executes the model when the scheduler produced any work.
        3. Hands the model output back through ``_process_model_outputs``,
           which updates the scheduler and decodes the sequences into
           ``RequestOutput`` objects.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

        output = []
        if not scheduler_outputs.is_empty():
            execute_kwargs = {
                "seq_group_metadata_list": seq_group_metadata_list,
                "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
                "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                "blocks_to_copy": scheduler_outputs.blocks_to_copy,
            }
            results = await self._run_workers_async(
                "execute_model", driver_kwargs=execute_kwargs)
            # Only the driver worker (first entry) returns sampling results.
            output = results[0]

        return self._process_model_outputs(output, scheduler_outputs)

    async def _run_workers_async(
        self,
        method: str,
        *args,
        driver_args: Optional[List[Any]] = None,
        driver_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> Any:
        """Run ``method`` on the driver worker and on every Ray worker.

        Args:
            method: Name of the worker method to invoke.
            *args: Positional arguments for the Ray workers. They are also
                used for the driver when ``driver_args`` is None.
            driver_args: Positional arguments for the driver only.
            driver_kwargs: Keyword arguments for the driver only.
            **kwargs: Keyword arguments for the Ray workers. They are also
                used for the driver when ``driver_kwargs`` is None.

        Returns:
            A list with the driver's result first, followed by one result per
            entry of ``self.workers`` in order.
        """
        call_args = args if driver_args is None else driver_args
        call_kwargs = kwargs if driver_kwargs is None else driver_kwargs

        loop = asyncio.get_running_loop()
        driver_fn = getattr(self.driver_worker, method)
        # The driver call runs in the loop's default executor so that it does
        # not block the event loop.
        pending = [
            loop.run_in_executor(
                None, partial(driver_fn, *call_args, **call_kwargs))
        ]
        # Ray workers always receive the original *args / **kwargs.
        pending.extend(
            worker.execute_method.remote(method, *args, **kwargs)
            for worker in self.workers)

        return await asyncio.gather(*pending)


class AsyncLLMEngine:
    """An asynchronous wrapper for LLMEngine.

    This class is used to wrap the LLMEngine class to make it asynchronous. It
    uses asyncio to create a background loop that keeps processing incoming
    requests. The LLMEngine is kicked by the generate method when there
    are requests in the waiting queue. The generate method yields the outputs
    from the LLMEngine to the caller.

    NOTE: For the comprehensive list of arguments, see `LLMEngine`.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process from the
            model workers.
        log_requests: Whether to log the requests.
        max_log_len: If not None, the maximum number of prompt characters
            and prompt token ids included when a request is logged.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
        *args: Arguments for LLMEngine.
        **kwargs: Arguments for LLMEngine.
    """

    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 max_log_len: Optional[int] = None,
                 start_engine_loop: bool = True,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.max_log_len = max_log_len
        self.engine = self._init_engine(*args, **kwargs)

        self.background_loop = None
        # We need to keep a reference to unshielded
        # task as well to prevent it from being garbage
        # collected
        self._background_loop_unshielded = None
        self.start_engine_loop = start_engine_loop
        self._request_tracker = RequestTracker()

    @property
    def is_running(self) -> bool:
        """Whether the background loop task exists and has not completed."""
        return (self.background_loop is not None
                and not self.background_loop.done())

    def start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            RuntimeError: If the background loop is already running.
        """
        if self.is_running:
            raise RuntimeError("Background loop is already running.")
        self._request_tracker.init_event()

        self._background_loop_unshielded = asyncio.get_event_loop(
        ).create_task(self.run_engine_loop())
        # If the loop ever stops, fail every tracked request instead of
        # leaving its consumer waiting.
        self._background_loop_unshielded.add_done_callback(
            partial(_raise_exception_on_finish,
                    request_tracker=self._request_tracker))
        self.background_loop = asyncio.shield(self._background_loop_unshielded)

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Instantiate the engine, either in-process or as a Ray actor.

        Returns the engine object when ``engine_use_ray`` is False. Otherwise
        returns the handle produced by the Ray ``remote`` constructor.
        """
        if not self.engine_use_ray:
            engine_class = self._engine_class
        elif self.worker_use_ray:
            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
        else:
            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
            # order of the arguments.
            cache_config = args[1]
            parallel_config = args[2]
            if parallel_config.tensor_parallel_size == 1:
                num_gpus = cache_config.gpu_memory_utilization
            else:
                num_gpus = 1
            engine_class = ray.remote(num_gpus=num_gpus)(
                self._engine_class).remote
        return engine_class(*args, **kwargs)

    async def engine_step(self) -> bool:
        """Kick the engine to process the waiting requests.

        Returns True if there are in-progress requests."""

        new_requests, finished_requests = (
            self._request_tracker.get_new_and_finished_requests())

        for new_request in new_requests:
            # Add the request into the vLLM engine's waiting queue.
            # TODO: Maybe add add_request_batch to reduce Ray overhead
            if self.engine_use_ray:
                await self.engine.add_request.remote(**new_request)
            else:
                self.engine.add_request(**new_request)

        if finished_requests:
            await self._engine_abort(finished_requests)

        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()
        else:
            request_outputs = await self.engine.step_async()

        # Put the outputs into the corresponding streams.
        for request_output in request_outputs:
            self._request_tracker.process_request_output(
                request_output, verbose=self.log_requests)

        return len(request_outputs) > 0

    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward aborted/finished request ids to the underlying engine."""
        if self.engine_use_ray:
            await self.engine.abort_request.remote(request_ids)
        else:
            self.engine.abort_request(request_ids)

    async def run_engine_loop(self):
        """Drive the engine forever: wait for work while idle, then step."""
        has_requests_in_progress = False
        while True:
            # Only block on the event while the engine has nothing to do.
            if not has_requests_in_progress:
                await self._request_tracker.wait_for_new_requests()
            has_requests_in_progress = await self.engine_step()
            # Yield control so other coroutines can run between steps.
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        prefix_pos: Optional[int] = None,
    ) -> AsyncStream:
        """Register a request and return the stream for its outputs.

        Starts the background loop if it is not running and
        ``start_engine_loop`` is set.

        Raises:
            AsyncEngineDeadError: If the background loop is not running and
                may not be started automatically.
            KeyError: If the request id is already being tracked.
        """
        if self.log_requests:
            shortened_prompt = prompt
            shortened_token_ids = prompt_token_ids
            if self.max_log_len is not None:
                if shortened_prompt is not None:
                    shortened_prompt = shortened_prompt[:self.max_log_len]
                if shortened_token_ids is not None:
                    shortened_token_ids = shortened_token_ids[:self.
                                                              max_log_len]
            logger.info(f"Received request {request_id}: "
                        f"prompt: {shortened_prompt!r}, "
                        f"prefix_pos: {prefix_pos}, "
                        f"sampling params: {sampling_params}, "
                        f"prompt token ids: {shortened_token_ids}.")

        if not self.is_running:
            if self.start_engine_loop:
                self.start_background_loop()
            else:
                raise AsyncEngineDeadError(
                    "Background loop is not running. If it was running, "
                    "inspect the output to find the stacktrace of the "
                    "error that caused the background loop to stop "
                    "(AsyncEngineDeadError).")

        stream = self._request_tracker.add_request(
            request_id,
            prompt=prompt,
            sampling_params=sampling_params,
            prompt_token_ids=prompt_token_ids,
            arrival_time=arrival_time,
            prefix_pos=prefix_pos)

        return stream

    async def generate(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        request_id: str,
        prompt_token_ids: Optional[List[int]] = None,
        prefix_pos: Optional[int] = None,
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.
            prefix_pos: If not None, we use the given position as the prefix
                position for each prompt. We will cache the prefix's KV
                cache and reuse it for the next request with the same prefix.
                This is an experimental feature, and may be replaced with
                automatic prefix caching in the future.

        Yields:
            The output `RequestOutput` objects from the LLMEngine for the
            request.

        Details:
            - If the engine is not running, start the background loop,
              which iteratively invokes
              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
              to process the waiting requests.
            - Add the request to the engine's `RequestTracker`.
              On the next background loop, this request will be sent to
              the underlying engine.
              Also, a corresponding `AsyncStream` will be created.
            - Wait for the request outputs from `AsyncStream` and yield them.
            - If streaming fails, the consuming task is cancelled, or this
              generator is closed before completion, the request is aborted.

        Example:
            >>> # Please refer to entrypoints/api_server.py for
            >>> # the complete example.
            >>>
            >>> # initialize the engine and the example input
            >>> engine = AsyncLLMEngine.from_engine_args(engine_args)
            >>> example_input = {
            >>>     "prompt": "What is LLM?",
            >>>     "stream": False, # assume the non-streaming case
            >>>     "temperature": 0.0,
            >>>     "request_id": "0",
            >>> }
            >>>
            >>> # start the generation
            >>> results_generator = engine.generate(
            >>>    example_input["prompt"],
            >>>    SamplingParams(temperature=example_input["temperature"]),
            >>>    example_input["request_id"])
            >>>
            >>> # get the results
            >>> final_output = None
            >>> async for request_output in results_generator:
            >>>     # `request` stands for the client's HTTP request object.
            >>>     if await request.is_disconnected():
            >>>         # Abort the request if the client disconnects.
            >>>         await engine.abort(example_input["request_id"])
            >>>         # Return or raise an error
            >>>         ...
            >>>     final_output = request_output
            >>>
            >>> # Process and return the final output
            >>> ...
        """
        # Preprocess the request.
        # This should not be used for logging, as it is monotonic time.
        arrival_time = time.monotonic()

        # add_request registers nothing when it raises (dead engine or
        # duplicate id), so it stays outside the abort handler. Aborting
        # there would cancel a different, already-running request that
        # shares the same id.
        stream = await self.add_request(request_id,
                                        prompt,
                                        sampling_params,
                                        prompt_token_ids=prompt_token_ids,
                                        arrival_time=arrival_time,
                                        prefix_pos=prefix_pos)

        try:
            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError, GeneratorExit):
            # The engine raised, the consumer was cancelled, or this
            # generator was closed before completion (GeneratorExit): stop
            # spending engine work on the request.
            self._abort(request_id)
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.

        Raises:
            AsyncEngineDeadError: If the background loop is not running.
        """
        if not self.is_running:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

        return self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            verbose=self.log_requests)

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        if self.engine_use_ray:
            return await self.engine.get_model_config.remote()
        else:
            return self.engine.get_model_config()

    @classmethod
    def from_engine_args(cls,
                         engine_args: AsyncEngineArgs,
                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        placement_group = initialize_cluster(parallel_config,
                                             engine_args.engine_use_ray)
        # Create the async LLM engine.
        engine = cls(parallel_config.worker_use_ray,
                     engine_args.engine_use_ray,
                     *engine_configs,
                     placement_group,
                     log_requests=not engine_args.disable_log_requests,
                     log_stats=not engine_args.disable_log_stats,
                     max_log_len=engine_args.max_log_len,
                     start_engine_loop=start_engine_loop)
        return engine

    async def do_log_stats(self) -> None:
        """Ask the underlying engine (local or Ray actor) to log its stats."""
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote()
        else:
            self.engine.do_log_stats()