Commit e8202731 authored by Timothy J. Baek's avatar Timothy J. Baek
Browse files

fix

parent c6c0bc19
...@@ -196,7 +196,11 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, ...@@ -196,7 +196,11 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
"stream": False, "stream": False,
} }
try:
payload = filter_pipeline(payload, user) payload = filter_pipeline(payload, user)
except Exception as e:
raise e
model = app.state.MODELS[task_model_id] model = app.state.MODELS[task_model_id]
response = None response = None
...@@ -326,6 +330,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -326,6 +330,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
print(data["tool_ids"]) print(data["tool_ids"])
for tool_id in data["tool_ids"]: for tool_id in data["tool_ids"]:
print(tool_id) print(tool_id)
try:
response = await get_function_call_response( response = await get_function_call_response(
messages=data["messages"], messages=data["messages"],
tool_id=tool_id, tool_id=tool_id,
...@@ -336,6 +341,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -336,6 +341,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if response: if response:
context += ("\n" if context != "" else "") + response context += ("\n" if context != "" else "") + response
except Exception as e:
print(f"Error: {e}")
del data["tool_ids"] del data["tool_ids"]
print(f"tool_context: {context}") print(f"tool_context: {context}")
...@@ -767,7 +774,14 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): ...@@ -767,7 +774,14 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
} }
print(payload) print(payload)
try:
payload = filter_pipeline(payload, user) payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion( return await generate_ollama_chat_completion(
...@@ -824,7 +838,14 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) ...@@ -824,7 +838,14 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
} }
print(payload) print(payload)
try:
payload = filter_pipeline(payload, user) payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion( return await generate_ollama_chat_completion(
...@@ -861,9 +882,16 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ ...@@ -861,9 +882,16 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
print(model_id) print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
return await get_function_call_response( try:
context = await get_function_call_response(
form_data["messages"], form_data["tool_id"], template, model_id, user form_data["messages"], form_data["tool_id"], template, model_id, user
) )
return context
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
@app.post("/api/chat/completions") @app.post("/api/chat/completions")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment