# simple_fastapi_frontend.py: a minimal FastAPI HTTP frontend that streams
# generations from cacheflow's AsyncLLMServer.
import argparse
import json
from typing import AsyncGenerator

from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
import uvicorn

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.utils import random_uuid

# Keep-alive timeout handed to uvicorn.run() as `timeout_keep_alive`, in
# seconds.
TIMEOUT_KEEP_ALIVE: int = 5
# In seconds. NOTE(review): not referenced anywhere in this file; confirm
# whether anything still depends on it before removing.
TIMEOUT_TO_PREVENT_DEADLOCK: int = 1

# The ASGI application; request handlers are registered on it below.
app = FastAPI()


@app.post("/generate")
async def generate_stream(request: Request) -> StreamingResponse:
    """Stream generated text for the prompt in the JSON request body.

    The body must be a JSON object containing a ``"prompt"`` key. Every other
    key is forwarded to ``SamplingParams`` as a keyword argument. The response
    body is a sequence of UTF-8 encoded JSON objects ``{"text": [...],
    "error": 0}``, each terminated by a NUL byte. Each ``text`` entry is the
    request's prompt concatenated with the text of one of its outputs.

    Raises:
        HTTPException: status 400 if the body is not valid JSON, is not a
            JSON object, has no ``"prompt"`` key, or carries arguments that
            ``SamplingParams`` rejects.
    """
    try:
        request_dict = await request.json()
    except ValueError as err:
        # Starlette parses the body with json.loads; its JSONDecodeError (and
        # UnicodeDecodeError) are ValueError subclasses.
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON.") from err
    if not isinstance(request_dict, dict) or "prompt" not in request_dict:
        raise HTTPException(
            status_code=400,
            detail='Request body must be a JSON object with a "prompt" key.')
    prompt = request_dict.pop("prompt")
    try:
        sampling_params = SamplingParams(**request_dict)
    except (TypeError, ValueError) as err:
        # TypeError: an unknown keyword argument. ValueError: presumably
        # raised by SamplingParams' own validation (TODO confirm).
        raise HTTPException(status_code=400, detail=str(err)) from err

    request_id = random_uuid()
    results_generator = server.generate(prompt, sampling_params, request_id)

    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            # Distinct name so the enclosing `prompt` is not shadowed.
            output_prompt = request_output.prompt
            ret = {
                "text": [
                    output_prompt + output.text
                    for output in request_output.outputs
                ],
                "error": 0,
            }
            # NUL-terminate each JSON chunk (presumably the delimiter clients
            # split the stream on).
            yield (json.dumps(ret) + "\0").encode("utf-8")

    async def abort_request() -> None:
        await server.abort(request_id)

    # Starlette runs a response's background tasks once the response ends.
    # In Starlette versions that watch for client disconnects while
    # streaming, that includes a disconnect mid-stream, which is how this
    # aborts generation for clients that went away. The task also runs after
    # a normal completion; presumably server.abort tolerates an
    # already-finished request id (TODO confirm).
    background_tasks = BackgroundTasks()
    background_tasks.add_task(abort_request)
    return StreamingResponse(stream_results(), background=background_tasks)


if __name__ == "__main__":
    # This frontend's own flags come first; AsyncServerArgs then registers
    # its flags on the same parser.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
    parser = AsyncServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    # `generate_stream` looks up the module-level name `server` at request
    # time, so it must be bound before uvicorn starts serving.
    # NOTE(review): because it is bound only here, serving this module any
    # other way (e.g. `uvicorn module:app`) leaves `server` undefined.
    server = AsyncLLMServer.from_server_args(
        AsyncServerArgs.from_cli_args(args))

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="debug",
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
    )