Commit c83704d6 authored by Timothy J. Baek's avatar Timothy J. Baek
Browse files
parent d0e0aba5
......@@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum):
OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature."
)
class TASKS(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda task="": f"{task if task else 'default'}"
TITLE_GENERATION = "Title Generation"
EMOJI_GENERATION = "Emoji Generation"
QUERY_GENERATION = "Query Generation"
FUNCTION_CALLING = "Function Calling"
......@@ -126,7 +126,7 @@ from config import (
WEBUI_SESSION_COOKIE_SECURE,
AppConfig,
)
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
from utils.webhook import post_webhook
if SAFE_MODE:
......@@ -311,6 +311,7 @@ async def get_function_call_response(
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
"task": TASKS.FUNCTION_CALLING,
}
try:
......@@ -323,7 +324,6 @@ async def get_function_call_response(
response = None
try:
response = await generate_chat_completions(form_data=payload, user=user)
content = None
if hasattr(response, "body_iterator"):
......@@ -833,9 +833,6 @@ def filter_pipeline(payload, user):
pass
if "pipeline" not in app.state.MODELS[model_id]:
if "title" in payload:
del payload["title"]
if "task" in payload:
del payload["task"]
......@@ -1338,7 +1335,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
"stream": False,
"max_tokens": 50,
"chat_id": form_data.get("chat_id", None),
"title": True,
"task": TASKS.TITLE_GENERATION,
}
log.debug(payload)
......@@ -1401,7 +1398,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 30,
"task": True,
"task": TASKS.QUERY_GENERATION,
}
print(payload)
......@@ -1468,7 +1465,7 @@ Message: """{{prompt}}"""
"stream": False,
"max_tokens": 4,
"chat_id": form_data.get("chat_id", None),
"task": True,
"task": TASKS.EMOJI_GENERATION,
}
log.debug(payload)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment