Commit 7f74426a authored by Jun Siang Cheah's avatar Jun Siang Cheah
Browse files

fix: openai streaming cancellation using aiohttp

parent 4dd51bad
...@@ -153,7 +153,7 @@ async def cleanup_response( ...@@ -153,7 +153,7 @@ async def cleanup_response(
await session.close() await session.close()
async def post_streaming_url(url, payload): async def post_streaming_url(url: str, payload: str):
r = None r = None
try: try:
session = aiohttp.ClientSession() session = aiohttp.ClientSession()
......
...@@ -9,6 +9,7 @@ import json ...@@ -9,6 +9,7 @@ import json
import logging import logging
from pydantic import BaseModel from pydantic import BaseModel
from starlette.background import BackgroundTask
from apps.webui.models.models import Models from apps.webui.models.models import Models
from apps.webui.models.users import Users from apps.webui.models.users import Users
...@@ -194,6 +195,16 @@ async def fetch_url(url, key): ...@@ -194,6 +195,16 @@ async def fetch_url(url, key):
return None return None
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """Release an aiohttp response and its owning session, skipping any that are None.

    Intended as a Starlette ``BackgroundTask`` so the connection is torn down
    after a ``StreamingResponse`` finishes (or the client disconnects).
    The response must be closed before the session that owns it.
    """
    if response:
        # ClientResponse.close() is synchronous — it releases the underlying
        # connection without awaiting.
        response.close()
    if session:
        # ClientSession.close() is a coroutine and must be awaited.
        await session.close()
def merge_models_lists(model_lists): def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}") log.debug(f"merge_models_lists {model_lists}")
merged_list = [] merged_list = []
...@@ -426,40 +437,45 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -426,40 +437,45 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
r = None r = None
session = None
streaming = False
try: try:
r = requests.request( session = aiohttp.ClientSession()
method=request.method, r = await session.request(
url=target_url, method=request.method, url=target_url, data=payload, headers=headers
data=payload if payload else body,
headers=headers,
stream=True,
) )
r.raise_for_status() r.raise_for_status()
# Check if response is SSE # Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""): if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
return StreamingResponse( return StreamingResponse(
r.iter_content(chunk_size=8192), r.content,
status_code=r.status_code, status_code=r.status,
headers=dict(r.headers), headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
) )
else: else:
response_data = r.json() response_data = await r.json()
return response_data return response_data
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
res = r.json() res = await r.json()
print(res) print(res)
if "error" in res: if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except: except:
error_detail = f"External: {e}" error_detail = f"External: {e}"
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
raise HTTPException( finally:
status_code=r.status_code if r else 500, detail=error_detail if not streaming and session:
) if r:
r.close()
await session.close()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment