async_llm_engine.py 12.3 KB
Newer Older
1
2
import asyncio
import time
Antoni Baum's avatar
Antoni Baum committed
3
4
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union
5

6
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
9
10
11
12
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
13
14

logger = init_logger(__name__)
15

Antoni Baum's avatar
Antoni Baum committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

class AsyncStream:
    """Asynchronously iterable queue of outputs for a single request.

    A producer pushes results with `put` and closes the stream with
    `finish`; a consumer drains it with `async for`.
    """

    def __init__(self, request_id: str) -> None:
        self._finished = False
        self._queue = asyncio.Queue()
        self.request_id = request_id

    def __aiter__(self):
        # The stream is its own asynchronous iterator.
        return self

    async def __anext__(self) -> RequestOutput:
        """Wait for the next queued item; stop once the end marker shows up."""
        item = await self._queue.get()
        if item is not StopIteration:
            return item
        raise StopAsyncIteration

    @property
    def finished(self) -> bool:
        """Whether `finish` has been called on this stream."""
        return self._finished

    def put(self, item: RequestOutput) -> None:
        """Enqueue `item`; silently ignored once the stream is finished."""
        if not self._finished:
            self._queue.put_nowait(item)

    def finish(self) -> None:
        """Close the stream.

        The `StopIteration` class itself is queued as the end-of-stream
        marker that `__anext__` looks for.
        """
        self._queue.put_nowait(StopIteration)
        self._finished = True


def _raise_exception_on_finish(task: asyncio.Task) -> None:
    try:
        task.result()
    except Exception as e:
        raise RuntimeError("Task finished unexpectedly.") from e
    raise RuntimeError("Task finished unexpectedly.")


class _AsyncLLMEngine(LLMEngine):
    """Extension of LLMEngine to add async methods."""

    async def step_async(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.
        The workers are ran asynchronously if possible.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in/out/copy. Then, it executes the model
        and updates the scheduler with the model outputs. Finally, it decodes
        the sequences and returns the newly generated results.
        """
        (seq_group_metadata_list, scheduler_outputs,
         early_return) = self._schedule()
        if early_return is not None:
            return early_return

        # Execute the model.
        output = await self._run_workers_async(
            "execute_model",
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
        )

        return self._process_worker_outputs(output, scheduler_outputs)

    async def _run_workers_async(
        self,
        method: str,
        *args,
        get_all_outputs: bool = False,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
        all_outputs = []
        for worker in self.workers:
            if self.parallel_config.worker_use_ray:
                executor = partial(worker.execute_method.remote, method)
            else:
                executor = getattr(worker, method)

            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if self.parallel_config.worker_use_ray:
            all_outputs = await asyncio.gather(*all_outputs)

        if get_all_outputs:
            return all_outputs

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output
        return output
115
116


117
118
class AsyncLLMEngine:
    """An asynchronous wrapper for LLMEngine.

    This class wraps the LLMEngine class to make it asynchronous. It uses
    asyncio to run a background loop (`run_engine_loop`) that keeps stepping
    the engine. The `generate` method adds a request to the engine and yields
    the outputs that the background loop routes to that request's stream.

    The background loop is only started when `start_engine_loop` is True or
    when `_start_background_loop` is called explicitly.

    NOTE: For the comprehensive list of arguments, see `LLMEngine`.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process from the
            model workers.
        log_requests: Whether to log the requests.
        start_engine_loop: Whether to start the background loop right away
            in the constructor.
        *args, **kwargs: Arguments for LLMEngine.
    """

    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 start_engine_loop: bool = False,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        self.engine = self._init_engine(*args, **kwargs)

        # Request id -> stream.
        self.request_streams: Dict[str, AsyncStream] = {}
        # Ids whose streams are closed (finished or aborted) and that still
        # have to be aborted in the engine and dropped from request_streams.
        self.finished_requests: Set[str] = set()
        self.background_loop = None
        if start_engine_loop:
            self._start_background_loop()

    def _start_background_loop(self) -> None:
        """Start the background loop.

        Raises:
            RuntimeError: If the background loop has already been started.
        """
        if self.background_loop is not None:
            raise RuntimeError("Background loop is already running.")
        self.background_loop = asyncio.get_event_loop().create_task(
            self.run_engine_loop())
        # The loop is meant to run forever; surface any unexpected exit.
        self.background_loop.add_done_callback(_raise_exception_on_finish)

    def _init_engine(self, *args,
                     **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
        """Instantiate the engine, in-process or as a Ray actor.

        Without `engine_use_ray` the engine class is instantiated directly.
        Otherwise it is wrapped as a Ray actor that requests no CPUs when the
        workers use Ray, and one GPU when they do not.
        """
        if not self.engine_use_ray:
            engine_class = self._engine_class
        elif self.worker_use_ray:
            engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
        else:
            engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
        return engine_class(*args, **kwargs)

    async def engine_step(self):
        """Kick the engine to process the waiting requests.

        Runs one engine step, routes each output to the stream of its
        request, and then cleans up the requests whose streams are closed.
        """
        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()
        else:
            request_outputs = await self.engine.step_async()

        # Put the outputs into the corresponding streams.
        for request_output in request_outputs:
            request_id = request_output.request_id
            stream = self.request_streams[request_id]
            stream.put(request_output)
            if request_output.finished:
                if self.log_requests:
                    logger.info(f"Finished request {request_id}.")
                stream.finish()
                self.finished_requests.add(request_id)

        # Work on a snapshot: `_abort` may add ids to `finished_requests`
        # while the abort call below is awaited. Those ids must wait for the
        # next step so that they are aborted in the engine before their
        # stream is removed.
        finished_ids = set(self.finished_requests)
        await self._engine_abort(finished_ids)
        for request_id in finished_ids:
            self.request_streams.pop(request_id, None)
        self.finished_requests.difference_update(finished_ids)

    async def _engine_abort(self, request_ids: Iterable[str]):
        """Forward an abort for `request_ids` to the underlying engine."""
        if self.engine_use_ray:
            await self.engine.abort_request.remote(request_ids)
        else:
            self.engine.abort_request(request_ids)

    async def run_engine_loop(self):
        """Step the engine forever, yielding to the event loop in between."""
        while True:
            await self.engine_step()
            # Give other coroutines a chance to run between steps.
            await asyncio.sleep(0)

    async def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> AsyncStream:
        """Register a stream for the request and submit it to the engine.

        Args:
            request_id: The unique id of the request.
            prompt: The prompt string.
            sampling_params: The sampling parameters of the request.
            prompt_token_ids: The token IDs of the prompt.
            arrival_time: The arrival time of the request, forwarded to the
                engine as-is.

        Returns:
            The stream that will receive the outputs of this request.
        """
        if self.log_requests:
            logger.info(f"Received request {request_id}: "
                        f"prompt: {prompt!r}, "
                        f"sampling params: {sampling_params}, "
                        f"prompt token ids: {prompt_token_ids}.")

        # Register the stream first so that outputs of the request always
        # find a destination.
        stream = AsyncStream(request_id)
        self.request_streams[request_id] = stream

        # Add the request into the vLLM engine's waiting queue.
        if self.engine_use_ray:
            await self.engine.add_request.remote(
                request_id,
                prompt,
                sampling_params,
                prompt_token_ids=prompt_token_ids,
                arrival_time=arrival_time)
        else:
            self.engine.add_request(request_id,
                                    prompt,
                                    sampling_params,
                                    prompt_token_ids=prompt_token_ids,
                                    arrival_time=arrival_time)

        return stream

    async def generate(
            self,
            prompt: Optional[str],
            sampling_params: SamplingParams,
            request_id: str,
            prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an asynchronous
        generator. It adds the request into the waiting queue of the
        LLMEngine and streams the outputs from the LLMEngine to the caller.
        If it fails or the consuming task is cancelled, the request is
        aborted before the exception propagates.

        Args:
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.

        Yields:
            The output `RequestOutput` objects from the LLMEngine for the
            request.
        """
        # Preprocess the request.
        arrival_time = time.time()

        try:
            stream = await self.add_request(request_id,
                                            prompt,
                                            sampling_params,
                                            prompt_token_ids=prompt_token_ids,
                                            arrival_time=arrival_time)

            async for request_output in stream:
                yield request_output
        except (Exception, asyncio.CancelledError):
            # If there is an exception or the coroutine is cancelled, abort
            # the request. CancelledError is listed explicitly because it is
            # not an Exception subclass on Python 3.8+.
            self._abort(request_id)
            raise

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        return self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        This only closes the stream and records the id in
        `finished_requests`; the abort is forwarded to the engine by
        `engine_step`.

        Args:
            request_id: The unique id of the request.
        """
        if request_id not in self.request_streams or self.request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        if self.log_requests:
            logger.info(f"Aborted request {request_id}.")

        self.request_streams[request_id].finish()
        self.finished_requests.add(request_id)

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        if self.engine_use_ray:
            return await self.engine.get_model_config.remote()
        else:
            return self.engine.get_model_config()

    @classmethod
    def from_engine_args(cls,
                         engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        distributed_init_method, placement_group = initialize_cluster(
            parallel_config, engine_args.engine_use_ray)
        # Create the async LLM engine.
        engine = cls(engine_args.worker_use_ray,
                     engine_args.engine_use_ray,
                     *engine_configs,
                     distributed_init_method,
                     placement_group,
                     log_requests=not engine_args.disable_log_requests,
                     log_stats=not engine_args.disable_log_stats)
        return engine