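"""FastAPI server that serves text completions from a cacheflow AsyncLLMEngine.

Example launch (illustrative; `AsyncServerArgs.add_cli_args` registers the
engine flags, e.g. which model to load):
    python api_server.py --host localhost --port 8000 <engine args>
"""
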
import argparse
import json
from typing import AsyncGenerator

from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import Response, StreamingResponse
import uvicorn

from cacheflow.sampling_params import SamplingParams
from cacheflow.server.arg_utils import AsyncServerArgs
from cacheflow.server.async_llm_server import AsyncLLMEngine
from cacheflow.utils import random_uuid

TIMEOUT_KEEP_ALIVE = 5  # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
app = FastAPI()


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (see `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
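    # Hand the prompt to the engine under a fresh request id. `server` is the
    # module-level AsyncLLMEngine created in the __main__ block below, and its
    # `generate` yields incremental outputs as decoding progresses.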
    request_id = random_uuid()
    results_generator = server.generate(prompt, sampling_params, request_id)

    # Streaming case
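    # Each streamed chunk carries the prompt plus the full text generated so
    # far for every parallel output sequence (cumulative, not a delta).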
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            text_outputs = [
                prompt + output.text
                for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
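            # Frame each JSON chunk with a trailing NUL byte so clients can
            # split the byte stream on "\0".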
            yield (json.dumps(ret) + "\0").encode("utf-8")

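    # Registered as a background task below: runs once the streaming response
    # finishes (including on client disconnect) and tells the engine to stop
    # generating for this request.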
    async def abort_request() -> None:
        await server.abort(request_id)

    if stream:
        background_tasks = BackgroundTasks()
        # Abort the request if the client disconnects.
        background_tasks.add_task(abort_request)
        return StreamingResponse(stream_results(), background=background_tasks)

    # Non-streaming case
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await server.abort(request_id)
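            # 499 is the (non-standard) "client closed request" status code.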
            return Response(status_code=499)
        final_output = request_output

    assert final_output is not None
    prompt = final_output.prompt
    text_outputs = [
        prompt + output.text
        for output in final_output.outputs
    ]
    ret = {"text": text_outputs}
    return Response(content=json.dumps(ret))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser = AsyncServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    server_args = AsyncServerArgs.from_cli_args(args)
    server = AsyncLLMEngine.from_server_args(server_args)

    uvicorn.run(app, host=args.host, port=args.port, log_level="debug",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE)