"tests/nn/vscode:/vscode.git/clone" did not exist on "c5e471bc8cfe8bb8500de57437c1bbc2b034514a"
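"""FastAPI frontend for the CacheFlow server.

Exposes a /generate endpoint that tokenizes incoming prompts, submits them to
a Ray-managed CacheFlow Server, and streams the decoded outputs back to the
client as null-terminated JSON chunks.
"""
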
import argparse
import asyncio
import json
import time
from typing import List, Dict, Optional

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import ray
from transformers import AutoTokenizer
import uvicorn

from cacheflow.core.server import (Server, add_server_arguments,
                                   process_server_arguments,
                                   initialize_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
from cacheflow.worker.controller import DeviceID

TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
app = FastAPI()


class FastAPIServer:
    def __init__(
        self,
        model: str,
        cache_dir: Optional[str],
        use_np_cache: bool,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
        server_use_ray: bool,
    ):
        self.block_size = block_size

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        if server_use_ray:
            remote_server_class = ray.remote(num_cpus=0)(Server)
        else:
            remote_server_class = ray.remote(num_gpus=1)(Server)
        self.server = remote_server_class.remote(
            model=model,
            cache_dir=cache_dir,
            use_dummy_weights=False,
            use_np_cache=use_np_cache,
            pipeline_parallel_size=pipeline_parallel_size,
            tensor_parallel_size=tensor_parallel_size,
            block_size=block_size,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_sequences=max_num_sequences,
            num_nodes=num_nodes,
            num_devices_per_node=num_devices_per_node,
            distributed_init_method=distributed_init_method,
            all_stage_devices=all_stage_devices,
            gpu_memory=get_gpu_memory(),
            cpu_memory=get_cpu_memory(),
            use_ray=server_use_ray,
        )

        self.running_seq_groups: Dict[int, SequenceGroup] = {}
        self.sequence_group_events: Dict[int, asyncio.Event] = {}
        self.is_server_running = False

    async def server_step(self):
        self.is_server_running = True
        updated_seq_groups = await self.server.step.remote()
        self.is_server_running = False
        # Notify the waiting coroutines that there are new outputs ready.
        for seq_group in updated_seq_groups:
            group_id = seq_group.group_id
            self.running_seq_groups[group_id] = seq_group
            self.sequence_group_events[group_id].set()

    async def generate(self, request_dict: Dict):
        # Preprocess the request.
        prompt = request_dict["prompt"]
        sampling_params = SamplingParams.from_dict(request_dict)
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
        token_ids = self.tokenizer.encode(prompt)
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
            seqs.append(seq)

        arrival_time = time.time()
        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs, arrival_time)
        # Create an event to notify us that there is new output from the
        # cacheflow server.
        group_event = asyncio.Event()
        self.running_seq_groups[group_id] = seq_group
        self.sequence_group_events[group_id] = group_event
        # Add the request into the cacheflow server's waiting queue.
        await self.server.add_sequence_groups.remote(
            [(seq_group, sampling_params)])
        # The cacheflow server does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the server to process the requests.
        while True:
            # Kick the server if the server is not running.
            if not self.is_server_running:
                await self.server_step()
            # Wait for new output. The group_event will be set in server_step
            # when there is new output for this sequence group. On timeout,
            # loop back and kick the server again to prevent deadlock.
            try:
                await asyncio.wait_for(group_event.wait(),
                                       timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            except asyncio.TimeoutError:
                continue
            # Reset the event to wait for the next output.
            group_event.clear()
            # Decode and return new outputs
            seq_group = self.running_seq_groups[group_id]
            all_outputs = []
            for seq in seq_group.seqs:
                token_ids = seq.get_token_ids()
                output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
                all_outputs.append(output)
            ret = {
                "text": all_outputs,
                "error": 0,
            }
            yield (json.dumps(ret) + "\0").encode("utf-8")

            # Once finished, release the resources of the sequence group.
            if seq_group.is_finished():
                del self.running_seq_groups[group_id]
                del self.sequence_group_events[group_id]
                # Kick the server if the server is not running, so that any
                # requests still left in the server's waiting queue get
                # executed.
                if not self.is_server_running:
                    await self.server_step()
                break


@app.post("/generate")
async def generate_stream(request: Request):
    request_dict = await request.json()
    return StreamingResponse(server.generate(request_dict))
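
# Example client usage (an illustrative sketch, not part of the original
# code): the endpoint streams null-terminated JSON chunks, so a client can
# read results incrementally. The host/port assume the defaults from the
# argument parser below; "prompt" and "n" are the fields read in
# FastAPIServer.generate, and any other sampling fields follow SamplingParams.
#
#     import requests
#     response = requests.post(
#         "http://localhost:10002/generate",
#         json={"prompt": "Hello, my name is", "n": 1},
#         stream=True,
#     )
#     for chunk in response.iter_lines(delimiter=b"\0"):
#         if chunk:
#             print(json.loads(chunk)["text"])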


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    parser = add_server_arguments(parser)
    args = parser.parse_args()
    args = process_server_arguments(args)

    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_cluster(
            use_ray=True,
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    server = FastAPIServer(
        model=args.model,
        cache_dir=args.cache_dir,
        use_np_cache=args.use_np_cache,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_num_batched_tokens=args.max_num_batched_tokens,
        max_num_sequences=args.max_num_sequences,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        server_use_ray=args.use_ray,
    )

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
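
# Launch example (illustrative: the exact server flags are defined by
# add_server_arguments in cacheflow.core.server, so --model is an assumption
# based on args.model above):
#
#     python fastapi_frontend.py --model <model-name-or-path> --port 10002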