async_llm_engine.py

import asyncio
import time
from typing import AsyncIterator, Dict, List, Optional

from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

logger = init_logger(__name__)

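# Waiters in generate() poll with this timeout so that a request event that
# is never set (e.g. after an abort) cannot block a coroutine forever.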
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds


class AsyncLLMEngine:
    """An asynchronous wrapper for LLMEngine.

    This class is used to wrap the LLMEngine class to make it asynchronous. It
    uses asyncio to create a background loop that keeps processing incoming
    requests. The LLMEngine is kicked by the generate method when there
    are requests in the waiting queue. The generate method yields the outputs
    from the LLMEngine to the caller.

    NOTE: For the comprehensive list of arguments, see `LLMEngine`.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process from the
            model workers.
        log_requests: Whether to log the requests.
        *args, **kwargs: Arguments for LLMEngine.
    """

    def __init__(self,
                 worker_use_ray: bool,
                 engine_use_ray: bool,
                 *args,
                 log_requests: bool = True,
                 **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
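        # Choose how to instantiate the underlying LLMEngine: in-process by
        # default, or as a Ray actor when engine_use_ray is set (the actor
        # reserves a GPU only when the workers themselves do not use Ray).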
        if not self.engine_use_ray:
            engine_class = LLMEngine
        elif self.worker_use_ray:
            engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
        else:
            engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
        self.engine = engine_class(*args, **kwargs)
        # Request id -> request output.
        self.request_outputs: Dict[str, RequestOutput] = {}
        # Request id -> event to notify that there is new output.
        self.request_events: Dict[str, asyncio.Event] = {}
        self.is_engine_running = False
        self.kicking_request_id: Optional[str] = None

    async def engine_step(self, kicking_request_id: Optional[str] = None):
        """Kick the engine to process the waiting requests."""
        self.is_engine_running = True
        self.kicking_request_id = kicking_request_id
        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()
        else:
            # Yield to the event loop to allow other coroutines to run
            # while is_engine_running is True. This lets the engine add new
            # requests into the queue.
            await asyncio.sleep(0)
            request_outputs = self.engine.step()
        self.is_engine_running = False
        self.kicking_request_id = None

        # Notify the waiting coroutines that there are new outputs ready.
        for request_output in request_outputs:
            request_id = request_output.request_id
            self.request_outputs[request_id] = request_output
            self.request_events[request_id].set()

    async def generate(
            self,
            prompt: Optional[str],
            sampling_params: SamplingParams,
            request_id: str,
            prompt_token_ids: Optional[List[int]] = None
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is an async generator. It
        adds the request into the waiting queue of the LLMEngine and streams
        the outputs from the LLMEngine to the caller.

        Args:
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.

        Yields:
            The output `RequestOutput` objects from the LLMEngine for the
            request.
        """
        # Preprocess the request.
        arrival_time = time.time()

        # Create an event to notify us that there is new output from the
        # vLLM engine.
        request_event = asyncio.Event()
        self.request_events[request_id] = request_event

        if self.log_requests:
            logger.info(f"Received request {request_id}: "
                        f"prompt: {prompt!r}, "
                        f"sampling params: {sampling_params}, "
                        f"prompt token ids: {prompt_token_ids}.")

        # Add the request into the vLLM engine's waiting queue.
        if self.engine_use_ray:
            await self.engine.add_request.remote(
                request_id,
                prompt,
                sampling_params,
                prompt_token_ids=prompt_token_ids,
                arrival_time=arrival_time)
        else:
            self.engine.add_request(request_id,
                                    prompt,
                                    sampling_params,
                                    prompt_token_ids=prompt_token_ids,
                                    arrival_time=arrival_time)

        # The vLLM engine does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the engine to process the requests.
        while True:
            if request_id not in self.request_events:
                # The request has been aborted.
                return

            # Kick the engine if the engine is not running.
            if not self.is_engine_running:
                await self.engine_step(request_id)

            # Wait for new output. The request_event will be set in
            # engine_step when there is new output available for the request.
            # A timeout is added to prevent deadlock.
            try:
                await asyncio.wait_for(request_event.wait(),
                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            except asyncio.TimeoutError:
                continue
            # Reset the event to wait for the next output.
            request_event.clear()

            # Yield the new output to the caller.
            request_output = self.request_outputs[request_id]
            yield request_output

            # Once finished, release the resources of the sequence group.
            if request_output.finished:
                if self.log_requests:
                    logger.info(f"Finished request {request_id}.")

                del self.request_outputs[request_id]
                del self.request_events[request_id]
                # Kick the engine if the engine is not running. This is to
                # handle the case where there are still requests in the
                # engine's waiting queue to be executed.
                if not self.is_engine_running:
                    await self.engine_step()
                break

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        if request_id not in self.request_events:
            # The request has already finished or been aborted.
            return

        if self.log_requests:
            logger.info(f"Aborted request {request_id}.")

        if self.engine_use_ray:
            await self.engine.abort_request.remote(request_id)
        else:
            self.engine.abort_request(request_id)

        if request_id in self.request_events:
            del self.request_events[request_id]
        if request_id in self.request_outputs:
            del self.request_outputs[request_id]

        # Reset the kicking state to prevent deadlock when a request is
        # aborted while the engine is running.
        if self.kicking_request_id == request_id:
            self.is_engine_running = False
            self.kicking_request_id = None

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        if self.engine_use_ray:
            return await self.engine.get_model_config.remote()
        else:
            return self.engine.get_model_config()

    @classmethod
    def from_engine_args(cls,
                         engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
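        # create_engine_configs() returns the model, cache, parallel, and
        # scheduler configs in order; index 2 is the ParallelConfig.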
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        distributed_init_method, devices = initialize_cluster(
            parallel_config, engine_args.engine_use_ray)
        # Create the async LLM engine.
        engine = cls(engine_args.worker_use_ray,
                     engine_args.engine_use_ray,
                     *engine_configs,
                     distributed_init_method,
                     devices,
                     log_requests=not engine_args.disable_log_requests,
                     log_stats=not engine_args.disable_log_stats)
        return engine
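

# A minimal usage sketch, not part of the original module: it shows how a
# caller might build the engine from AsyncEngineArgs and consume the
# streaming generate() API. The model name, request id, and sampling values
# below are illustrative placeholders, not values taken from this file.
if __name__ == "__main__":

    async def _demo() -> None:
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m"))
        sampling_params = SamplingParams(temperature=0.8, max_tokens=64)
        # generate() is an async generator: each iteration yields the
        # cumulative RequestOutput for the request so far.
        async for request_output in engine.generate(
                "Hello, my name is", sampling_params, request_id="demo-0"):
            print(request_output.outputs[0].text)

    asyncio.run(_demo())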