Commit e3e02e04 authored by Michael Poluektov's avatar Michael Poluektov
Browse files

refac: backend/main.py

parent f9e3c47d
import base64 import base64
import uuid import uuid
import subprocess
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo from authlib.oidc.core import UserInfo
from bs4 import BeautifulSoup
import json import json
import markdown
import time import time
import os import os
import sys import sys
...@@ -19,14 +16,11 @@ import shutil ...@@ -19,14 +16,11 @@ import shutil
import os import os
import uuid import uuid
import inspect import inspect
import asyncio
from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi import HTTPException from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
...@@ -38,7 +32,6 @@ from starlette.responses import StreamingResponse, Response, RedirectResponse ...@@ -38,7 +32,6 @@ from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import sio, app as socket_app from apps.socket.main import sio, app as socket_app
from apps.ollama.main import ( from apps.ollama.main import (
app as ollama_app, app as ollama_app,
OpenAIChatCompletionForm,
get_all_models as get_ollama_models, get_all_models as get_ollama_models,
generate_openai_chat_completion as generate_ollama_chat_completion, generate_openai_chat_completion as generate_ollama_chat_completion,
) )
...@@ -56,14 +49,14 @@ from apps.webui.main import ( ...@@ -56,14 +49,14 @@ from apps.webui.main import (
get_pipe_models, get_pipe_models,
generate_function_chat_completion, generate_function_chat_completion,
) )
from apps.webui.internal.db import Session, SessionLocal from apps.webui.internal.db import Session
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional, Iterator, Generator, Union from typing import List, Optional
from apps.webui.models.auths import Auths from apps.webui.models.auths import Auths
from apps.webui.models.models import Models, ModelModel from apps.webui.models.models import Models
from apps.webui.models.tools import Tools from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions from apps.webui.models.functions import Functions
from apps.webui.models.users import Users from apps.webui.models.users import Users
...@@ -86,14 +79,12 @@ from utils.task import ( ...@@ -86,14 +79,12 @@ from utils.task import (
from utils.misc import ( from utils.misc import (
get_last_user_message, get_last_user_message,
add_or_update_system_message, add_or_update_system_message,
stream_message_template,
parse_duration, parse_duration,
) )
from apps.rag.utils import get_rag_context, rag_template from apps.rag.utils import get_rag_context, rag_template
from config import ( from config import (
CONFIG_DATA,
WEBUI_NAME, WEBUI_NAME,
WEBUI_URL, WEBUI_URL,
WEBUI_AUTH, WEBUI_AUTH,
...@@ -101,7 +92,6 @@ from config import ( ...@@ -101,7 +92,6 @@ from config import (
VERSION, VERSION,
CHANGELOG, CHANGELOG,
FRONTEND_BUILD_DIR, FRONTEND_BUILD_DIR,
UPLOAD_DIR,
CACHE_DIR, CACHE_DIR,
STATIC_DIR, STATIC_DIR,
DEFAULT_LOCALE, DEFAULT_LOCALE,
...@@ -128,9 +118,8 @@ from config import ( ...@@ -128,9 +118,8 @@ from config import (
WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE, WEBUI_SESSION_COOKIE_SECURE,
AppConfig, AppConfig,
BACKEND_DIR,
DATABASE_URL,
) )
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
from utils.webhook import post_webhook from utils.webhook import post_webhook
...@@ -355,121 +344,94 @@ async def get_function_call_response( ...@@ -355,121 +344,94 @@ async def get_function_call_response(
else: else:
content = response["choices"][0]["message"]["content"] content = response["choices"][0]["message"]["content"]
if content is None:
return None, None, False
# Parse the function response # Parse the function response
if content is not None: print(f"content: {content}")
print(f"content: {content}") result = json.loads(content)
result = json.loads(content) print(result)
print(result)
citation = None
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False citation = None
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr( if "name" not in result:
toolkit_module, "Valves" return None, None, False
):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(
**(valves if valves else {})
)
function = getattr(toolkit_module, result["name"]) # Call the function
function_result = None if tool_id in webui_app.state.TOOLS:
try: toolkit_module = webui_app.state.TOOLS[tool_id]
# Get the signature of the function else:
sig = inspect.signature(function) toolkit_module, _ = load_toolkit_module_by_id(tool_id)
params = result["parameters"] webui_app.state.TOOLS[tool_id] = toolkit_module
if "__user__" in sig.parameters: file_handler = False
# Call the function with the '__user__' parameter included # check if toolkit_module has file_handler self variable
__user__ = { if hasattr(toolkit_module, "file_handler"):
"id": user.id, file_handler = True
"email": user.email, print("file_handler: ", file_handler)
"name": user.name,
"role": user.role, if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
} valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
try:
if hasattr(toolkit_module, "UserValves"): function = getattr(toolkit_module, result["name"])
__user__["valves"] = toolkit_module.UserValves( function_result = None
**Tools.get_user_valves_by_id_and_user_id( try:
tool_id, user.id # Get the signature of the function
) sig = inspect.signature(function)
) params = result["parameters"]
except Exception as e:
print(e) # Extra parameters to be passed to the function
extra_params = {
params = {**params, "__user__": __user__} "__model__": model,
if "__messages__" in sig.parameters: "__id__": tool_id,
# Call the function with the '__messages__' parameter included "__messages__": messages,
params = { "__files__": files,
**params, "__event_emitter__": __event_emitter__,
"__messages__": messages, "__event_call__": __event_call__,
} }
if "__files__" in sig.parameters: # Add extra params in contained in function signature
# Call the function with the '__files__' parameter included for key, value in extra_params.items():
params = { if key in sig.parameters:
**params, params[key] = value
"__files__": files,
} if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
if "__model__" in sig.parameters: __user__ = {
# Call the function with the '__model__' parameter included "id": user.id,
params = { "email": user.email,
**params, "name": user.name,
"__model__": model, "role": user.role,
} }
if "__id__" in sig.parameters: try:
# Call the function with the '__id__' parameter included if hasattr(toolkit_module, "UserValves"):
params = { __user__["valves"] = toolkit_module.UserValves(
**params, **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
"__id__": tool_id, )
}
if "__event_emitter__" in sig.parameters:
# Call the function with the '__event_emitter__' parameter included
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
# Call the function with the '__event_call__' parameter included
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
function_result = function(**params)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e: except Exception as e:
print(e) print(e)
# Add the function result to the system prompt params = {**params, "__user__": __user__}
if function_result is not None:
return function_result, citation, file_handler if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
function_result = function(**params)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result is not None:
return function_result, citation, file_handler
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
...@@ -484,87 +446,74 @@ async def chat_completion_functions_handler( ...@@ -484,87 +446,74 @@ async def chat_completion_functions_handler(
filter_ids = get_filter_function_ids(model) filter_ids = get_filter_function_ids(model)
for filter_id in filter_ids: for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id) filter = Functions.get_function_by_id(filter_id)
if filter: if not filter:
if filter_id in webui_app.state.FUNCTIONS: continue
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try: if filter_id in webui_app.state.FUNCTIONS:
if hasattr(function_module, "inlet"): function_module = webui_app.state.FUNCTIONS[filter_id]
inlet = function_module.inlet else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Get the signature of the function # Check if the function has a file_handler variable
sig = inspect.signature(inlet) if hasattr(function_module, "file_handler"):
params = {"body": body} skip_files = function_module.file_handler
if "__user__" in sig.parameters: if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
__user__ = { valves = Functions.get_function_valves_by_id(filter_id)
"id": user.id, function_module.valves = function_module.Valves(
"email": user.email, **(valves if valves else {})
"name": user.name, )
"role": user.role,
} try:
if hasattr(function_module, "inlet"):
try: inlet = function_module.inlet
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves( # Get the signature of the function
**Functions.get_user_valves_by_id_and_user_id( sig = inspect.signature(inlet)
filter_id, user.id params = {"body": body}
)
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
) )
except Exception as e: )
print(e) except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if "__model__" in sig.parameters:
params = {
**params,
"__model__": model,
}
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
body = inlet(**params)
except Exception as e: params = {**params, "__user__": __user__}
print(f"Error: {e}")
raise e if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
body = inlet(**params)
except Exception as e:
print(f"Error: {e}")
raise e
if skip_files: if skip_files:
if "files" in body: if "files" in body:
...@@ -1220,86 +1169,73 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -1220,86 +1169,73 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
for filter_id in filter_ids: for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id) filter = Functions.get_function_by_id(filter_id)
if filter: if not filter:
if filter_id in webui_app.state.FUNCTIONS: continue
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try: if filter_id in webui_app.state.FUNCTIONS:
if hasattr(function_module, "outlet"): function_module = webui_app.state.FUNCTIONS[filter_id]
outlet = function_module.outlet else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Get the signature of the function if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
sig = inspect.signature(outlet) valves = Functions.get_function_valves_by_id(filter_id)
params = {"body": data} function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if "__user__" in sig.parameters: try:
__user__ = { if hasattr(function_module, "outlet"):
"id": user.id, outlet = function_module.outlet
"email": user.email,
"name": user.name, # Get the signature of the function
"role": user.role, sig = inspect.signature(outlet)
} params = {"body": data}
try: # Extra parameters to be passed to the function
if hasattr(function_module, "UserValves"): extra_params = {
__user__["valves"] = function_module.UserValves( "__model__": model,
**Functions.get_user_valves_by_id_and_user_id( "__id__": filter_id,
filter_id, user.id "__event_emitter__": __event_emitter__,
) "__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
) )
except Exception as e: )
print(e) except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if "__model__" in sig.parameters:
params = {
**params,
"__model__": model,
}
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if "__event_call__" in sig.parameters:
params = {
**params,
"__event_call__": __event_call__,
}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e: params = {**params, "__user__": __user__}
print(f"Error: {e}")
return JSONResponse( if inspect.iscoroutinefunction(outlet):
status_code=status.HTTP_400_BAD_REQUEST, data = await outlet(**params)
content={"detail": str(e)}, else:
) data = outlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data return data
...@@ -1387,7 +1323,6 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): ...@@ -1387,7 +1323,6 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
model_id = task_model_id model_id = task_model_id
print(model_id) print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
...@@ -1456,7 +1391,6 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) ...@@ -1456,7 +1391,6 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
model_id = task_model_id model_id = task_model_id
print(model_id) print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
...@@ -1513,7 +1447,6 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)): ...@@ -1513,7 +1447,6 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
model_id = task_model_id model_id = task_model_id
print(model_id) print(model_id)
model = app.state.MODELS[model_id]
template = ''' template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱). Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
...@@ -1583,7 +1516,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ ...@@ -1583,7 +1516,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try: try:
context, citation, file_handler = await get_function_call_response( context, _, _ = await get_function_call_response(
form_data["messages"], form_data["messages"],
form_data.get("files", []), form_data.get("files", []),
form_data["tool_id"], form_data["tool_id"],
...@@ -1647,6 +1580,7 @@ async def upload_pipeline( ...@@ -1647,6 +1580,7 @@ async def upload_pipeline(
os.makedirs(upload_folder, exist_ok=True) os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename) file_path = os.path.join(upload_folder, file.filename)
r = None
try: try:
# Save the uploaded file # Save the uploaded file
with open(file_path, "wb") as buffer: with open(file_path, "wb") as buffer:
...@@ -1670,7 +1604,9 @@ async def upload_pipeline( ...@@ -1670,7 +1604,9 @@ async def upload_pipeline(
print(f"Connection error: {e}") print(f"Connection error: {e}")
detail = "Pipeline not found" detail = "Pipeline not found"
status_code = status.HTTP_404_NOT_FOUND
if r is not None: if r is not None:
status_code = r.status_code
try: try:
res = r.json() res = r.json()
if "detail" in res: if "detail" in res:
...@@ -1679,7 +1615,7 @@ async def upload_pipeline( ...@@ -1679,7 +1615,7 @@ async def upload_pipeline(
pass pass
raise HTTPException( raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND), status_code=status_code,
detail=detail, detail=detail,
) )
finally: finally:
...@@ -1778,8 +1714,6 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_ ...@@ -1778,8 +1714,6 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)): async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
r = None r = None
try: try:
urlIdx
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment