import asyncio
import time
from typing import AsyncIterator, Dict, List, Optional

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

logger = init_logger(__name__)
13

14
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds
15
16


17
18
class AsyncLLMEngine:
    """An asynchronous wrapper for LLMEngine.

    This class is used to wrap the LLMEngine class to make it asynchronous. It
    uses asyncio to create a background loop that keeps processing incoming
    requests. The LLMEngine is kicked by the generate method when there
    are requests in the waiting queue. The generate method yields the outputs
    from the LLMEngine to the caller.

    NOTE: For the comprehensive list of arguments, see `LLMEngine`.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process as the
            model workers.
        log_requests: Whether to log the requests.
        *args, *kwargs: Arguments for LLMEngine.
    """

    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        if not self.engine_use_ray:
            # Run the engine in-process.
            engine_class = LLMEngine
        elif self.worker_use_ray:
            # The workers hold the GPUs, so the engine actor itself needs no
            # GPU (and no dedicated CPU) resources.
            engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
        else:
            # Single-worker case: the engine actor owns the GPU.
            engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
        self.engine = engine_class(*args, **kwargs)
        # Request id -> request output.
        self.request_outputs: Dict[str, RequestOutput] = {}
        # Request id -> event to notify that there is new output.
        self.request_events: Dict[str, asyncio.Event] = {}
        # True while an engine step is in flight (prevents concurrent kicks).
        self.is_engine_running = False
        # Id of the request currently driving `engine_step`, if any. Used by
        # `abort` to reset the running flag when that request goes away.
        self.kicking_request_id: Optional[str] = None

    async def engine_step(self, kicking_request_id: Optional[str] = None):
        """Kick the engine to process the waiting requests.

        Args:
            kicking_request_id: The id of the request that triggered this
                step, if any. Recorded so that `abort` can detect that the
                aborted request is the one currently driving the engine.
        """
        self.is_engine_running = True
        self.kicking_request_id = kicking_request_id
        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()
        else:
            # Yield to the event loop to allow other coroutines to run
            # while is_engine_running is True. This let the engine to add new
            # requests into the queue.
            await asyncio.sleep(0)
            request_outputs = self.engine.step()
        self.is_engine_running = False
        self.kicking_request_id = None

        # Notify the waiting coroutines that there are new outputs ready.
        for request_output in request_outputs:
            request_id = request_output.request_id
            if request_id not in self.request_events:
                # The request was aborted (or finished and cleaned up) while
                # this step was in flight. Drop its output instead of raising
                # KeyError, which would kill the stepping coroutine.
                continue
            self.request_outputs[request_id] = request_output
            self.request_events[request_id].set()

    async def generate(
            self,
            prompt: Optional[str],
            sampling_params: SamplingParams,
            request_id: str,
            prompt_token_ids: Optional[List[int]] = None
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.

        Yields:
            The output `RequestOutput` objects from the LLMEngine for the
            request.
        """
        # Preprocess the request.
        arrival_time = time.time()

        # Create an event to notify us that there is new output from the
        # vLLM engine.
        request_event = asyncio.Event()
        self.request_events[request_id] = request_event

        if self.log_requests:
            logger.info(f"Received request {request_id}: "
                        f"prompt: {prompt!r}, "
                        f"sampling params: {sampling_params}, "
                        f"prompt token ids: {prompt_token_ids}.")

        # Add the request into the vLLM engine's waiting queue.
        if self.engine_use_ray:
            await self.engine.add_request.remote(
                request_id,
                prompt,
                sampling_params,
                prompt_token_ids=prompt_token_ids,
                arrival_time=arrival_time)
        else:
            self.engine.add_request(request_id,
                                    prompt,
                                    sampling_params,
                                    prompt_token_ids=prompt_token_ids,
                                    arrival_time=arrival_time)

        # The vLLM engine does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the engine to process the requests.
        while True:
            if request_id not in self.request_events:
                # The request has been aborted.
                return

            # Kick the engine if the engine is not running.
            if not self.is_engine_running:
                await self.engine_step(request_id)

            # Wait for new output. The group_event will be set in engine_step
            # when there is new output available for the sequence group.
            # Added a timeout to prevent deadlock.
            try:
                await asyncio.wait_for(request_event.wait(),
                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            except asyncio.TimeoutError:
                continue
            # Reset the event to wait for the next output.
            request_event.clear()

            # Decode and return new outputs.
            request_output = self.request_outputs[request_id]
            yield request_output

            # Once finished, release the resources of the sequence group.
            if request_output.finished:
                if self.log_requests:
                    logger.info(f"Finished request {request_id}.")

                del self.request_outputs[request_id]
                del self.request_events[request_id]
                # Kick the engine if the engine is not running. This is to
                # prevent that there are still requests in engine's waiting
                # queue to be executed.
                if not self.is_engine_running:
                    await self.engine_step()
                break

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        if request_id not in self.request_events:
            # The request has already finished or been aborted.
            return

        if self.log_requests:
            logger.info(f"Aborted request {request_id}.")

        if self.engine_use_ray:
            await self.engine.abort_request.remote(request_id)
        else:
            self.engine.abort_request(request_id)

        # The request may have finished while awaiting the engine above, so
        # re-check membership before deleting.
        if request_id in self.request_events:
            del self.request_events[request_id]
        if request_id in self.request_outputs:
            del self.request_outputs[request_id]

        # To prevent deadlock when a request is aborted while the engine is
        # running.
        if self.kicking_request_id == request_id:
            self.is_engine_running = False
            self.kicking_request_id = None

    @classmethod
    def from_engine_args(cls,
                         engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        distributed_init_method, devices = initialize_cluster(
            parallel_config, engine_args.engine_use_ray)
        # Create the async LLM engine.
        engine = cls(engine_args.worker_use_ray,
                     engine_args.engine_use_ray,
                     *engine_configs,
                     distributed_init_method,
                     devices,
                     log_requests=not engine_args.disable_log_requests,
                     log_stats=not engine_args.disable_log_stats)
        return engine