"docs/vscode:/vscode.git/clone" did not exist on "9fa5f3f072a9473fd2cf8763e85baaab294c931c"
Unverified commit 2d1e86f1, authored by Roy, committed by GitHub
Browse files

clean api code, remove redundant background task. (#1102)

parent 1ac4ccf7
...@@ -2,7 +2,7 @@ import argparse ...@@ -2,7 +2,7 @@ import argparse
import json import json
from typing import AsyncGenerator from typing import AsyncGenerator
from fastapi import BackgroundTasks, FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn import uvicorn
...@@ -44,14 +44,8 @@ async def generate(request: Request) -> Response: ...@@ -44,14 +44,8 @@ async def generate(request: Request) -> Response:
ret = {"text": text_outputs} ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8") yield (json.dumps(ret) + "\0").encode("utf-8")
async def abort_request() -> None:
await engine.abort(request_id)
if stream: if stream:
background_tasks = BackgroundTasks() return StreamingResponse(stream_results())
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks)
# Non-streaming case # Non-streaming case
final_output = None final_output = None
......
...@@ -10,7 +10,7 @@ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union ...@@ -10,7 +10,7 @@ from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
import fastapi import fastapi
import uvicorn import uvicorn
from fastapi import BackgroundTasks, Request from fastapi import Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
...@@ -229,9 +229,6 @@ async def create_chat_completion(request: ChatCompletionRequest, ...@@ -229,9 +229,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
result_generator = engine.generate(prompt, sampling_params, request_id, result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids) token_ids)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json( def create_stream_response_json(
index: int, index: int,
text: str, text: str,
...@@ -291,19 +288,15 @@ async def create_chat_completion(request: ChatCompletionRequest, ...@@ -291,19 +288,15 @@ async def create_chat_completion(request: ChatCompletionRequest,
# Streaming response # Streaming response
if request.stream: if request.stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(), return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream", media_type="text/event-stream")
background=background_tasks)
# Non-streaming response # Non-streaming response
final_res: RequestOutput = None final_res: RequestOutput = None
async for res in result_generator: async for res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await abort_request() await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected") "Client disconnected")
final_res = res final_res = res
...@@ -448,9 +441,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -448,9 +441,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
and (request.best_of is None or request.n == request.best_of) and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search) and not request.use_beam_search)
async def abort_request() -> None:
await engine.abort(request_id)
def create_stream_response_json( def create_stream_response_json(
index: int, index: int,
text: str, text: str,
...@@ -510,19 +500,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -510,19 +500,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
# Streaming response # Streaming response
if stream: if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(completion_stream_generator(), return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream", media_type="text/event-stream")
background=background_tasks)
# Non-streaming response # Non-streaming response
final_res: RequestOutput = None final_res: RequestOutput = None
async for res in result_generator: async for res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
await abort_request() await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST, return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected") "Client disconnected")
final_res = res final_res = res
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment