import asyncio
import time
from typing import AsyncIterator, Dict, Optional

from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import ray, initialize_cluster

logger = init_logger(__name__)

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds


class AsyncLLMServer:
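    """An asynchronous wrapper for LLMServer.

    This class wraps an LLMServer (optionally a Ray actor) so that it can
    be used from asyncio code. Incoming requests are added to the server's
    waiting queue, and the server is kicked via server_step() whenever it
    is idle, since the underlying server has no background loop of its own.
    New outputs are delivered to waiting coroutines through per-request
    asyncio events.
    """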

    def __init__(self, worker_use_ray: bool, server_use_ray: bool,
                 *args, **kwargs) -> None:
        self.worker_use_ray = worker_use_ray
        self.server_use_ray = server_use_ray
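        # Decide how to instantiate the underlying LLMServer: in this
        # process, or as a Ray actor. When the workers are Ray actors
        # themselves, the server actor needs no GPU of its own; otherwise
        # the server actor is placed on a GPU.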
        if not self.server_use_ray:
            server_class = LLMServer
        elif self.worker_use_ray:
            server_class = ray.remote(num_cpus=0)(LLMServer).remote
        else:
            server_class = ray.remote(num_gpus=1)(LLMServer).remote
        self.server = server_class(*args, **kwargs)
        # Request id -> request output.
        self.request_outputs: Dict[str, RequestOutput] = {}
        # Request id -> event to notify that there is new output.
        self.request_events: Dict[str, asyncio.Event] = {}
        self.is_server_running = False
        self.kicking_request_id: Optional[str] = None

    async def server_step(self, kicking_request_id: Optional[str] = None):
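        """Run one step of the server and publish the new outputs.

        kicking_request_id records which request triggered this step, so
        that abort() can reset the server state if that request is aborted
        while the step is in flight.
        """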
        self.is_server_running = True
        self.kicking_request_id = kicking_request_id
        if self.server_use_ray:
            request_outputs = await self.server.step.remote()
        else:
            # Yield to the event loop to allow other coroutines to run
            # while is_server_running is True. This lets the server add new
            # requests into the queue.
            await asyncio.sleep(0)
            request_outputs = self.server.step()
        self.is_server_running = False
        self.kicking_request_id = None

        # Notify the waiting coroutines that there are new outputs ready.
        for request_output in request_outputs:
            request_id = request_output.request_id
            self.request_outputs[request_id] = request_output
            self.request_events[request_id].set()

    async def generate(self, prompt: str, sampling_params: SamplingParams,
                       request_id: str) -> AsyncIterator[RequestOutput]:
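        """Add a request to the server and stream its outputs.

        Yields a RequestOutput each time new output is available for the
        request, until the request finishes.
        """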
        # Preprocess the request.
        arrival_time = time.time()

        # Create an event to notify us that there is new output from the
        # cacheflow server.
        request_event = asyncio.Event()
        self.request_events[request_id] = request_event

        logger.info(f"Received request {request_id}: "
                    f"prompt: {prompt!r}, "
                    f"sampling params: {sampling_params}.")

        # Add the request into the cacheflow server's waiting queue.
        if self.server_use_ray:
            await self.server.add_request.remote(
                request_id, prompt, sampling_params, arrival_time=arrival_time)
        else:
            self.server.add_request(
                request_id, prompt, sampling_params, arrival_time=arrival_time)

        # The cacheflow server does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the server to process the requests.
        while True:
            # Kick the server if the server is not running.
            if not self.is_server_running:
                await self.server_step(request_id)

            # Wait for new output. The request_event will be set in
            # server_step when there is new output available for the request.
            # Use a timeout to prevent deadlock.
            try:
                await asyncio.wait_for(request_event.wait(),
                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            except asyncio.TimeoutError:
                continue
            # Reset the event to wait for the next output.
            request_event.clear()

            # Fetch and yield the new output.
            request_output = self.request_outputs[request_id]
            yield request_output

            # Once finished, release the resources of the request.
            if request_output.finished():
                logger.info(f"Finished request {request_id}.")

                del self.request_outputs[request_id]
                del self.request_events[request_id]
                # Kick the server if the server is not running. This is to
                # make sure that any remaining requests in the server's
                # waiting queue get processed.
                if not self.is_server_running:
                    await self.server_step()
                break

    async def abort(self, request_id: str) -> None:
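        """Abort a request and release its event and cached output.

        This is a no-op if the request has already finished or has already
        been aborted.
        """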
        if request_id not in self.request_events:
            # The request has already finished or been aborted.
            return

        logger.info(f"Aborted request {request_id}.")

        if self.server_use_ray:
            await self.server.abort_request.remote(request_id)
        else:
            self.server.abort_request(request_id)

        if request_id in self.request_events:
            del self.request_events[request_id]
        if request_id in self.request_outputs:
            del self.request_outputs[request_id]

        # Reset the server state to prevent a deadlock when a request is
        # aborted while the server is running.
        if self.kicking_request_id == request_id:
            self.is_server_running = False
            self.kicking_request_id = None

    @classmethod
    def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
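        """Create an AsyncLLMServer from the async server arguments."""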
        # Create the server configs.
        server_configs = server_args.create_server_configs()
        parallel_config = server_configs[2]
        # Initialize the cluster.
        distributed_init_method, devices = initialize_cluster(
            parallel_config, server_args.server_use_ray)
        # Create the LLM server.
        server = cls(server_args.worker_use_ray,
                     server_args.server_use_ray,
                     *server_configs,
                     distributed_init_method, devices,
                     log_stats=not server_args.disable_log_stats)
        return server
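

if __name__ == "__main__":
    # A minimal usage sketch, not part of the module's API. The
    # AsyncServerArgs(model=...) constructor and the SamplingParams
    # arguments below are assumptions for illustration; adjust them to the
    # actual dataclass fields.
    import uuid

    async def _demo() -> None:
        server = AsyncLLMServer.from_server_args(
            AsyncServerArgs(model="facebook/opt-125m"))
        sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
        # generate() is an async generator: iterate to stream outputs.
        async for request_output in server.generate(
                "Hello, my name is", sampling_params,
                request_id=str(uuid.uuid4())):
            print(request_output)

    asyncio.run(_demo())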