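"""FastAPI frontend that streams text completions from a cacheflow server.

The app exposes a single POST endpoint, /generate, which accepts a JSON body
containing a "prompt" plus any keyword arguments accepted by
cacheflow.sampling_params.SamplingParams, and streams back null-terminated
JSON chunks with the text generated so far.

Example usage (a sketch: the --model flag and the "n" field are assumptions;
the actual flags come from add_server_arguments and the accepted sampling
fields from SamplingParams):

    python fastapi_frontend.py --model <model-name> --port 10002
    curl http://localhost:10002/generate \
        -H "Content-Type: application/json" \
        -d '{"prompt": "Hello, my name is", "n": 1}'
"""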
import argparse
import asyncio
import json
import time
from typing import List, Dict, Optional

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import ray
import uvicorn

from cacheflow.core.server import (Server, add_server_arguments,
                                   process_server_arguments,
                                   initialize_cluster)
from cacheflow.frontend.utils import get_tokenizer
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter
from cacheflow.worker.controller import DeviceID

TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
app = FastAPI()


class FastAPIServer:
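    """Asynchronous frontend around a Ray-hosted cacheflow Server.

    Each request is turned into a SequenceGroup and handed to the cacheflow
    server; request coroutines then cooperatively drive the server's step()
    loop and stream partial outputs back to their clients.
    """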
    def __init__(
        self,
        model: str,
        cache_dir: Optional[str],
        use_np_cache: bool,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        gpu_memory_utilization: float,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
        server_use_ray: bool,
        log_stats: bool,
    ):
        self.block_size = block_size

        self.tokenizer = get_tokenizer(model)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
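
        # If the cacheflow server manages its own Ray workers, the Server
        # actor itself reserves no GPU; otherwise the single Server actor
        # reserves one GPU.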
        if server_use_ray:
            remote_server_class = ray.remote(num_cpus=0)(Server)
        else:
            remote_server_class = ray.remote(num_gpus=1)(Server)
        self.server = remote_server_class.remote(
            model=model,
            cache_dir=cache_dir,
            use_dummy_weights=False,
            use_np_cache=use_np_cache,
            pipeline_parallel_size=pipeline_parallel_size,
            tensor_parallel_size=tensor_parallel_size,
            block_size=block_size,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            gpu_memory_utilization=gpu_memory_utilization,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_sequences=max_num_sequences,
            num_nodes=num_nodes,
            num_devices_per_node=num_devices_per_node,
            distributed_init_method=distributed_init_method,
            all_stage_devices=all_stage_devices,
            use_ray=server_use_ray,
            log_stats=log_stats,
        )

        self.running_seq_groups: Dict[int, SequenceGroup] = {}
        self.sequence_group_events: Dict[int, asyncio.Event] = {}
        self.is_server_running = False

    async def server_step(self):
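        """Run one step of the cacheflow server and publish its outputs."""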
        self.is_server_running = True
        updated_seq_groups = await self.server.step.remote()
        self.is_server_running = False
        # Notify the waiting coroutines that there are new outputs ready.
        for seq_group in updated_seq_groups:
            group_id = seq_group.group_id
            self.running_seq_groups[group_id] = seq_group
            self.sequence_group_events[group_id].set()

    async def generate(self, request_dict: Dict):
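        """Generate completions for one request and stream partial results.

        request_dict must contain a "prompt"; all remaining keys are passed
        to SamplingParams. Yields null-terminated JSON chunks containing the
        text decoded so far for every sequence in the group.
        """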
        # Preprocess the request.
        prompt = request_dict.pop("prompt")
        sampling_params = SamplingParams(**request_dict)
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
        token_ids = self.tokenizer.encode(prompt)
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, prompt, token_ids,
                           block_size=self.block_size)
            seqs.append(seq)

        arrival_time = time.time()
        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs, arrival_time)
        # Create an event to notify us that there is new output from the
        # cacheflow server.
        group_event = asyncio.Event()
        self.running_seq_groups[group_id] = seq_group
        self.sequence_group_events[group_id] = group_event
        # Add the request into the cacheflow server's waiting queue.
        await self.server.add_sequence_groups.remote(
            [(seq_group, sampling_params)])
        # The cacheflow server does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the server to process the requests.
        while True:
            # Kick the server if the server is not running.
            if not self.is_server_running:
                await self.server_step()
            # Wait for new output. The group_event will be set in server_step
            # when there is new output available for the sequence group.
            # Added a timeout to prevent deadlock.
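            # (If the coroutine that was driving server_step() returns while
            # we are waiting here, nothing advances the server; the timeout
            # lets this coroutine wake up and kick the server itself.)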
            try:
                await asyncio.wait_for(group_event.wait(),
                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            except asyncio.TimeoutError:
                continue
            # Reset the event to wait for the next output.
            group_event.clear()
            # Decode and return new outputs
            seq_group = self.running_seq_groups[group_id]
            all_outputs = []
            for seq in seq_group.seqs:
                token_ids = seq.get_token_ids()
                output = self.tokenizer.decode(token_ids,
                                               skip_special_tokens=True)
                all_outputs.append(output)
            ret = {
                "text": all_outputs,
                "error": 0,
            }
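            # Frame each update as a null-terminated JSON chunk so the client
            # can split the byte stream on "\0".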
            yield (json.dumps(ret) + "\0").encode("utf-8")

            # Once finished, release the resources of the sequence group.
            if seq_group.is_finished():
                del self.running_seq_groups[group_id]
                del self.sequence_group_events[group_id]
                # Kick the server again if it is not running, so that any
                # requests still sitting in the server's waiting queue get
                # processed.
                if not self.is_server_running:
                    await self.server_step()
                break


@app.post("/generate")
async def generate_stream(request: Request):
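    """Stream generation results for the JSON request body."""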
    request_dict = await request.json()
    return StreamingResponse(server.generate(request_dict))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    parser = add_server_arguments(parser)
    args = parser.parse_args()
    args = process_server_arguments(args)

    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_cluster(
            use_ray=True,
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    server = FastAPIServer(
        model=args.model,
        cache_dir=args.cache_dir,
        use_np_cache=args.use_np_cache,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        gpu_memory_utilization=args.gpu_memory_utilization,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_sequences=args.max_num_sequences,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        server_use_ray=args.use_ray,
        log_stats=args.log_stats,
    )

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")