Unverified Commit 5020e1e8 authored by Zhuohan Li's avatar Zhuohan Li Committed by GitHub
Browse files

Non-streaming simple fastapi server (#144)

parent 42983742
...@@ -233,7 +233,7 @@ async def create_completion(raw_request: Request): ...@@ -233,7 +233,7 @@ async def create_completion(raw_request: Request):
async for res in result_generator: async for res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await server.abort(request_id) await abort_request()
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected") "Client disconnected")
final_res = res final_res = res
......
...@@ -3,7 +3,7 @@ import json ...@@ -3,7 +3,7 @@ import json
from typing import AsyncGenerator from typing import AsyncGenerator
from fastapi import BackgroundTasks, FastAPI, Request from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse from fastapi.responses import Response, StreamingResponse
import uvicorn import uvicorn
from cacheflow.sampling_params import SamplingParams from cacheflow.sampling_params import SamplingParams
...@@ -17,19 +17,22 @@ app = FastAPI() ...@@ -17,19 +17,22 @@ app = FastAPI()
@app.post("/generate") @app.post("/generate")
async def generate_stream(request: Request) -> StreamingResponse: async def generate(request: Request) -> Response:
""" Stream the results of the generation request. """ Stream the results of the generation request.
The request should be a JSON object with the following fields: The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation. - prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details). - other fields: the sampling parameters (See `SamplingParams` for details).
""" """
request_dict = await request.json() request_dict = await request.json()
prompt = request_dict.pop("prompt") prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict) sampling_params = SamplingParams(**request_dict)
request_id = random_uuid() request_id = random_uuid()
results_generator = server.generate(prompt, sampling_params, request_id) results_generator = server.generate(prompt, sampling_params, request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]: async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator: async for request_output in results_generator:
prompt = request_output.prompt prompt = request_output.prompt
...@@ -37,20 +40,36 @@ async def generate_stream(request: Request) -> StreamingResponse: ...@@ -37,20 +40,36 @@ async def generate_stream(request: Request) -> StreamingResponse:
prompt + output.text prompt + output.text
for output in request_output.outputs for output in request_output.outputs
] ]
ret = { ret = {"text": text_outputs}
"text": text_outputs,
"error": 0,
}
yield (json.dumps(ret) + "\0").encode("utf-8") yield (json.dumps(ret) + "\0").encode("utf-8")
async def abort_request() -> None: async def abort_request() -> None:
await server.abort(request_id) await server.abort(request_id)
if stream:
background_tasks = BackgroundTasks() background_tasks = BackgroundTasks()
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
background_tasks.add_task(abort_request) background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks) return StreamingResponse(stream_results(), background=background_tasks)
# Non-streaming case
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await server.abort(request_id)
return Response(status_code=499)
final_output = request_output
assert final_output is not None
prompt = final_output.prompt
text_outputs = [
prompt + output.text
for output in final_output.outputs
]
ret = {"text": text_outputs}
return Response(content=json.dumps(ret))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
import argparse import argparse
import requests
import json import json
import requests
from typing import Iterable, List
def clear_line(n=1): def clear_line(n:int = 1) -> None:
LINE_UP = '\033[1A' LINE_UP = '\033[1A'
LINE_CLEAR = '\x1b[2K' LINE_CLEAR = '\x1b[2K'
for i in range(n): for i in range(n):
print(LINE_UP, end=LINE_CLEAR, flush=True) print(LINE_UP, end=LINE_CLEAR, flush=True)
def http_request(prompt: str, api_url: str, n: int = 1): def post_http_request(prompt: str, api_url: str, n: int = 1,
stream: bool = False) -> requests.Response:
headers = {"User-Agent": "Test Client"} headers = {"User-Agent": "Test Client"}
pload = { pload = {
"prompt": prompt, "prompt": prompt,
...@@ -17,32 +19,52 @@ def http_request(prompt: str, api_url: str, n: int = 1): ...@@ -17,32 +19,52 @@ def http_request(prompt: str, api_url: str, n: int = 1):
"use_beam_search": True, "use_beam_search": True,
"temperature": 0.0, "temperature": 0.0,
"max_tokens": 16, "max_tokens": 16,
"stream": stream,
} }
response = requests.post(api_url, headers=headers, json=pload, stream=True) response = requests.post(api_url, headers=headers, json=pload, stream=True)
return response
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
delimiter=b"\0"):
if chunk: if chunk:
data = json.loads(chunk.decode("utf-8")) data = json.loads(chunk.decode("utf-8"))
output = data["text"] output = data["text"]
yield output yield output
def get_response(response: requests.Response) -> List[str]:
data = json.loads(response.content)
output = data["text"]
return output
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001) parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--n", type=int, default=4) parser.add_argument("--n", type=int, default=4)
parser.add_argument("--prompt", type=str, default="San Francisco is a") parser.add_argument("--prompt", type=str, default="San Francisco is a")
parser.add_argument("--stream", action="store_true")
args = parser.parse_args() args = parser.parse_args()
prompt = args.prompt prompt = args.prompt
api_url = f"http://{args.host}:{args.port}/generate" api_url = f"http://{args.host}:{args.port}/generate"
n = args.n n = args.n
stream = args.stream
print(f"Prompt: {prompt}\n", flush=True) print(f"Prompt: {prompt}\n", flush=True)
response = post_http_request(prompt, api_url, n, stream)
if stream:
num_printed_lines = 0 num_printed_lines = 0
for h in http_request(prompt, api_url, n): for h in get_streaming_response(response):
clear_line(num_printed_lines) clear_line(num_printed_lines)
num_printed_lines = 0 num_printed_lines = 0
for i, line in enumerate(h): for i, line in enumerate(h):
num_printed_lines += 1 num_printed_lines += 1
print(f"Beam candidate {i}: {line}", flush=True) print(f"Beam candidate {i}: {line}", flush=True)
else:
output = get_response(response)
for i, line in enumerate(output):
print(f"Beam candidate {i}: {line}", flush=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment