import asyncio
import time
from typing import AsyncIterator, Dict, List, Optional

from cacheflow.engine.arg_utils import AsyncEngineArgs
from cacheflow.engine.llm_engine import LLMEngine
from cacheflow.engine.ray_utils import initialize_cluster, ray
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams

logger = init_logger(__name__)

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds


class AsyncLLMEngine:
    """An asynchronous wrapper for LLMEngine.

    This class is used to wrap the LLMEngine class to make it asynchronous. It
    uses asyncio to create a background loop that keeps processing incoming
    requests. The LLMEngine is kicked by the generate method when there
    are requests in the waiting queue. The generate method yields the outputs
    from the LLMEngine to the caller.

    NOTE: For the comprehensive list of arguments, see `LLMEngine`.

    Args:
        worker_use_ray: Whether to use Ray for model workers. Required for
            distributed execution. Should be the same as
            `parallel_config.worker_use_ray`.
        engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
            async frontend will be executed in a separate process from the
            model workers.
        log_requests: Whether to log the requests.
        *args, **kwargs: Arguments for LLMEngine.
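
    Example (an illustrative usage sketch; assumes `engine_args` is an
    `AsyncEngineArgs` instance built elsewhere and that `SamplingParams()`
    provides reasonable defaults):

        engine = AsyncLLMEngine.from_engine_args(engine_args)
        async for output in engine.generate("Hello", SamplingParams(),
                                            request_id="request-0"):
            print(output)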
    """
    def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
                 log_requests: bool = True, *args, **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.engine_use_ray = engine_use_ray
        self.log_requests = log_requests
        if not self.engine_use_ray:
            engine_class = LLMEngine
        elif self.worker_use_ray:
            engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
        else:
            engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
        self.engine = engine_class(*args, **kwargs)
        # Request id -> request output.
        self.request_outputs: Dict[str, RequestOutput] = {}
        # Request id -> event to notify that there is new output.
        self.request_events: Dict[str, asyncio.Event] = {}
        self.is_engine_running = False
        self.kicking_request_id: Optional[str] = None

    async def engine_step(self, kicking_request_id: Optional[str] = None):
        """Kick the engine to process the waiting requests."""
        self.is_engine_running = True
        self.kicking_request_id = kicking_request_id
        if self.engine_use_ray:
            request_outputs = await self.engine.step.remote()
        else:
            # Yield to the event loop to allow other coroutines to run
            # while is_engine_running is True. This lets the engine add new
            # requests into the queue.
            await asyncio.sleep(0)
            request_outputs = self.engine.step()
        self.is_engine_running = False
        self.kicking_request_id = None

        # Notify the waiting coroutines that there are new outputs ready.
        for request_output in request_outputs:
            request_id = request_output.request_id
            self.request_outputs[request_id] = request_output
            self.request_events[request_id].set()

    async def generate(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        request_id: str,
        prompt_token_ids: Optional[List[int]] = None
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        This method is a coroutine. It adds the
        request into the waiting queue of the LLMEngine and streams the outputs
        from the LLMEngine to the caller.

        Args:
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters of the request.
            request_id: The unique id of the request.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompts to token IDs.

        Yields:
            The output `RequestOutput` objects from the LLMEngine for the
            request.
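
        Example (illustrative sketch; assumes `engine` is an initialized
        AsyncLLMEngine and `params` is a SamplingParams instance):

            async for output in engine.generate("Hi", params, "req-0"):
                # Each iteration yields the latest RequestOutput snapshot.
                print(output.request_id)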
        """
        # Preprocess the request.
        arrival_time = time.time()

        # Create an event to notify us that there is new output from the
        # cacheflow engine.
        request_event = asyncio.Event()
        self.request_events[request_id] = request_event

        if self.log_requests:
            logger.info(f"Received request {request_id}: "
                        f"prompt: {prompt!r}, "
                        f"sampling params: {sampling_params}, "
                        f"prompt token ids: {prompt_token_ids}.")

        # Add the request into the cacheflow engine's waiting queue.
        if self.engine_use_ray:
            await self.engine.add_request.remote(
                request_id, prompt, sampling_params,
                prompt_token_ids=prompt_token_ids,
                arrival_time=arrival_time)
        else:
            self.engine.add_request(
                request_id, prompt, sampling_params,
                prompt_token_ids=prompt_token_ids,
                arrival_time=arrival_time)

        # The cacheflow engine does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the engine to process the requests.
        while True:
            if request_id not in self.request_events:
                # The request has been aborted.
                return

            # Kick the engine if the engine is not running.
            if not self.is_engine_running:
                await self.engine_step(request_id)

            # Wait for new output. The request_event will be set in engine_step
            # when there is new output available for the sequence group.
            # Added a timeout to prevent deadlock.
            try:
                await asyncio.wait_for(request_event.wait(),
                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            except asyncio.TimeoutError:
                continue
            # Reset the event to wait for the next output.
            request_event.clear()

            # Decode and return new outputs.
            request_output = self.request_outputs[request_id]
            yield request_output

            # Once finished, release the resources of the sequence group.
            if request_output.finished():
                if self.log_requests:
                    logger.info(f"Finished request {request_id}.")

                del self.request_outputs[request_id]
                del self.request_events[request_id]
                # Kick the engine if the engine is not running. This is to
                # ensure that any requests still in the engine's waiting
                # queue get executed.
                if not self.is_engine_running:
                    await self.engine_step()
                break

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
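
        Example (illustrative sketch; assumes a request with id "req-0" was
        previously submitted via `generate`):

            await engine.abort("req-0")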
        """
        if request_id not in self.request_events:
            # The request has already finished or been aborted.
            return

        if self.log_requests:
            logger.info(f"Aborted request {request_id}.")

        if self.engine_use_ray:
            await self.engine.abort_request.remote(request_id)
        else:
            self.engine.abort_request(request_id)

        if request_id in self.request_events:
            del self.request_events[request_id]
        if request_id in self.request_outputs:
            del self.request_outputs[request_id]

        # To prevent deadlock when a request is aborted while the engine is
        # running.
        if self.kicking_request_id == request_id:
            self.is_engine_running = False
            self.kicking_request_id = None

    @classmethod
    def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        distributed_init_method, devices = initialize_cluster(
            parallel_config, engine_args.engine_use_ray)
        # Create the async LLM engine.
        engine = cls(engine_args.worker_use_ray,
                     engine_args.engine_use_ray,
                     not engine_args.disable_log_requests,
                     *engine_configs,
                     distributed_init_method, devices,
                     log_stats=not engine_args.disable_log_stats)
        return engine