# fastapi_frontend.py
import argparse
import asyncio
import json
import time
from typing import AsyncGenerator, Dict, List

import ray
from transformers import AutoTokenizer
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.master.server import (Server, add_server_arguments,
                                     process_server_arguments,
                                     initialize_cluster)
from cacheflow.worker.controller import DeviceID
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory

# Maximum time, in seconds, that a request waits for new output before it
# stops waiting; used as the `timeout` of asyncio.wait_for in
# FastAPIFrontend.generate.
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds

# ASGI application object; routes are registered on it with decorators.
app = FastAPI()


class FastAPIFrontend:
    """Asyncio front end that feeds generation requests to a cacheflow server.

    The cacheflow ``Server`` is wrapped as a Ray actor. It only makes
    progress when its ``step`` method is called, so request coroutines call
    ``server_step`` on demand and are woken through one ``asyncio.Event``
    per sequence group when a step reports new output for them.
    """

    def __init__(
        self,
        model: str,
        model_path: str,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
        server_use_ray: bool,
    ):
        """Load the tokenizer and create the remote ``Server`` actor.

        Args:
            model: Passed to ``AutoTokenizer.from_pretrained`` and forwarded
                to the server.
            block_size: Forwarded to the server and also used for every
                ``Sequence`` created by ``generate``.
            server_use_ray: Forwarded to the server as ``use_ray``. It also
                selects the actor's resource request: ``num_cpus=0`` when
                true, ``num_gpus=1`` otherwise.

        All remaining arguments are forwarded unchanged to the ``Server``
        constructor under the same keyword names.
        """
        self.block_size = block_size

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()

        if server_use_ray:
            remote_server_class = ray.remote(num_cpus=0)(Server)
        else:
            remote_server_class = ray.remote(num_gpus=1)(Server)

        self.server = remote_server_class.remote(
            model=model,
            model_path=model_path,
            use_dummy_weights=False,
            pipeline_parallel_size=pipeline_parallel_size,
            tensor_parallel_size=tensor_parallel_size,
            block_size=block_size,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_sequences=max_num_sequences,
            num_nodes=num_nodes,
            num_devices_per_node=num_devices_per_node,
            distributed_init_method=distributed_init_method,
            all_stage_devices=all_stage_devices,
            gpu_memory=get_gpu_memory(),
            cpu_memory=get_cpu_memory(),
            use_ray=server_use_ray,
        )

        # Latest known state of every in-flight sequence group, by group id.
        self.running_seq_groups: Dict[int, SequenceGroup] = {}
        # Event per in-flight sequence group, set when new output arrives.
        self.sequence_group_events: Dict[int, asyncio.Event] = {}
        # True while a `server.step` call is awaiting its result.
        self.is_server_running = False

    async def server_step(self):
        """Run one server step and wake the requests that got new output."""
        self.is_server_running = True
        try:
            updated_seq_groups = await self.server.step.remote()
        finally:
            # Always clear the flag: if a failed step left it set, no
            # request would ever kick the server again.
            self.is_server_running = False
        # Notify the waiting coroutines that there are new outputs ready.
        for seq_group in updated_seq_groups:
            group_id = seq_group.group_id
            self.running_seq_groups[group_id] = seq_group
            self.sequence_group_events[group_id].set()

    async def generate(
        self, request_dict: Dict
    ) -> AsyncGenerator[bytes, None]:
        """Submit one generation request and stream its outputs.

        Args:
            request_dict: Request payload. Must contain ``"prompt"``; the
                whole dict is also passed to ``SamplingParams.from_dict``.

        Yields:
            UTF-8 encoded chunks, each a JSON object
            ``{"text": [...], "error": 0}`` followed by a ``"\\0"``
            delimiter. ``"text"`` holds one decoded string per sequence in
            the group, decoded from that sequence's current token ids.
        """
        # Preprocess the request.
        prompt = request_dict["prompt"]
        sampling_params = SamplingParams.from_dict(request_dict)
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
        token_ids = self.tokenizer.encode(prompt)
        seqs: List[Sequence] = [
            Sequence(next(self.seq_counter), token_ids,
                     block_size=self.block_size)
            for _ in range(sampling_params.n)
        ]

        arrival_time = time.time()
        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs, arrival_time)
        # Create an event to notify us that there is new output from the
        # cacheflow server.
        group_event = asyncio.Event()
        self.running_seq_groups[group_id] = seq_group
        self.sequence_group_events[group_id] = group_event
        # Add the request into the cacheflow server's waiting queue.
        await self.server.add_sequence_groups.remote(
            [(seq_group, sampling_params)])
        # The cacheflow server does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the server to process the requests.
        while True:
            # Kick the server if the server is not running.
            if not self.is_server_running:
                await self.server_step()
            # Wait for new output. The group_event will be set in server_step
            # when there is new output available for the sequence group.
            # The timeout prevents a deadlock: asyncio.wait_for raises
            # asyncio.TimeoutError when it expires, so catch it and go back
            # to the top of the loop to kick the server again instead of
            # letting the exception abort the stream.
            try:
                await asyncio.wait_for(group_event.wait(),
                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            except asyncio.TimeoutError:
                continue
            # Reset the event to wait for the next output.
            group_event.clear()
            # Decode and return new outputs.
            seq_group = self.running_seq_groups[group_id]
            all_outputs = [
                self.tokenizer.decode(seq.get_token_ids(),
                                      skip_special_tokens=True)
                for seq in seq_group.seqs
            ]
            ret = {
                "text": all_outputs,
                "error": 0,
            }
            yield (json.dumps(ret) + "\0").encode("utf-8")

            # Once finished, release the resources of the sequence group.
            if seq_group.is_finished():
                del self.running_seq_groups[group_id]
                del self.sequence_group_events[group_id]
                # Kick the server if the server is not running. This is to
                # prevent that there are still requests in server's waiting
                # queue to be executed.
                if not self.is_server_running:
                    await self.server_step()
                break


@app.post("/generate")
async def generate_stream(request: Request):
    """Handle POST /generate by streaming the frontend's output chunks.

    The request body is parsed as JSON and handed to ``frontend.generate``;
    the resulting async generator is wrapped in a ``StreamingResponse``.
    """
    payload = await request.json()
    output_stream = frontend.generate(payload)
    return StreamingResponse(output_stream)


if __name__ == "__main__":
    # Command-line options: HTTP bind address plus the shared server options.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    parser = add_server_arguments(parser)
    args = parser.parse_args()
    args = process_server_arguments(args)

    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    # The cluster is always initialized with use_ray=True here, independent
    # of the value of args.use_ray that is passed to the frontend below.
    cluster_info = initialize_cluster(
        use_ray=True,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
    )
    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = cluster_info

    # `frontend` must be a module-level name: the /generate route handler
    # looks it up as a global.
    frontend = FastAPIFrontend(
        model=args.model,
        model_path=args.model_path,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_sequences=args.max_num_sequences,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        server_use_ray=args.use_ray,
    )

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")