Commit c44fc82e authored by Timothy J. Baek's avatar Timothy J. Baek
Browse files

refac: openai

parent 8b6f422d
...@@ -345,108 +345,156 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use ...@@ -345,108 +345,156 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
) )
@app.post("/chat/completions")
@app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """Proxy a chat-completion request to the configured OpenAI-compatible backend.

    Applies per-model overrides (temperature, top_p, max_tokens, frequency_penalty,
    seed, stop, system prompt) from the model registry onto the incoming payload,
    then forwards it and relays either an SSE stream or a JSON response.

    Args:
        form_data: Raw chat-completion request body from the client.
        url_idx: Optional explicit backend index (route param); otherwise the
            backend is resolved from the model's `urlIdx`.
        user: Verified user injected by FastAPI.

    Raises:
        HTTPException: On any upstream/connection failure, with the backend's
            error message when one is available.
    """
    idx = 0
    payload = {**form_data}

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        # Resolve aliases: a model entry may point at an underlying base model.
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        model_info.params = model_info.params.model_dump()

        if model_info.params:
            # `is not None` (not truthiness) so an explicit temperature of 0 is honored.
            if model_info.params.get("temperature", None) is not None:
                payload["temperature"] = float(model_info.params.get("temperature"))

            # top_p and frequency_penalty are floats in the OpenAI API;
            # coercing with int() would silently truncate values like 0.9 to 0.
            if model_info.params.get("top_p", None):
                payload["top_p"] = float(model_info.params.get("top_p", None))

            if model_info.params.get("max_tokens", None):
                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))

            if model_info.params.get("frequency_penalty", None):
                payload["frequency_penalty"] = float(
                    model_info.params.get("frequency_penalty", None)
                )

            if model_info.params.get("seed", None):
                payload["seed"] = model_info.params.get("seed", None)

            if model_info.params.get("stop", None):
                # Decode escape sequences (e.g. "\\n") the user typed literally.
                payload["stop"] = (
                    [
                        bytes(stop, "utf-8").decode("unicode_escape")
                        for stop in model_info.params["stop"]
                    ]
                    if model_info.params.get("stop", None)
                    else None
                )

            if model_info.params.get("system", None):
                # Prepend the configured system prompt to an existing system
                # message, or insert a new one at the head of the conversation.
                if payload.get("messages"):
                    for message in payload["messages"]:
                        if message.get("role") == "system":
                            message["content"] = (
                                model_info.params.get("system", None)
                                + message["content"]
                            )
                            break
                    else:
                        # for/else: no system message was found — add one.
                        payload["messages"].insert(
                            0,
                            {
                                "role": "system",
                                "content": model_info.params.get("system", None),
                            },
                        )
    else:
        pass

    model = app.state.MODELS[payload.get("model")]
    idx = model["urlIdx"]

    # Pipeline backends receive the requesting user's identity.
    if "pipeline" in model and model.get("pipeline"):
        payload["user"] = {"name": user.name, "id": user.id}

    # Workaround: gpt-4-vision-preview misbehaves without an explicit
    # max_tokens, so default it to 4000 until OpenAI fixes the model.
    if payload.get("model") == "gpt-4-vision-preview":
        if "max_tokens" not in payload:
            payload["max_tokens"] = 4000
        # Lazy %-formatting: passing payload as a bare second argument would
        # be a logging format error ("not all arguments converted").
        log.debug("Modified payload: %s", payload)

    # Serialize the modified body for the upstream request.
    payload = json.dumps(payload)
    log.debug("Forwarding payload: %s", payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    headers = {}
    headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None
    session = None
    streaming = False

    try:
        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method="POST",
            url=f"{url}/chat/completions",
            data=payload,
            headers=headers,
        )

        r.raise_for_status()

        # SSE responses are streamed through; cleanup is deferred to a
        # background task so the connection stays open for the client.
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except Exception:
                # Upstream body was not JSON; fall back to the exception text.
                error_detail = f"External: {e}"

        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
    finally:
        # For streamed responses the BackgroundTask owns cleanup; otherwise
        # close the response and session here.
        if not streaming and session:
            if r:
                r.close()
            await session.close()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
idx = 0
body = await request.body()
url = app.state.config.OPENAI_API_BASE_URLS[idx] url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx] key = app.state.config.OPENAI_API_KEYS[idx]
...@@ -466,7 +514,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -466,7 +514,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
r = await session.request( r = await session.request(
method=request.method, method=request.method,
url=target_url, url=target_url,
data=payload if payload else body, data=body,
headers=headers, headers=headers,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment