Unverified Commit 1eebb85f authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #3323 from open-webui/dev

0.3.6
parents 9e4dd4b8 b224ba00
...@@ -167,6 +167,12 @@ for version in soup.find_all("h2"): ...@@ -167,6 +167,12 @@ for version in soup.find_all("h2"):
CHANGELOG = changelog_json CHANGELOG = changelog_json
####################################
# SAFE_MODE
####################################
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
#################################### ####################################
# WEBUI_BUILD_HASH # WEBUI_BUILD_HASH
#################################### ####################################
...@@ -299,6 +305,135 @@ JWT_EXPIRES_IN = PersistentConfig( ...@@ -299,6 +305,135 @@ JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1") "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
) )
####################################
# OAuth config
####################################
ENABLE_OAUTH_SIGNUP = PersistentConfig(
"ENABLE_OAUTH_SIGNUP",
"oauth.enable_signup",
os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
)
OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig(
"OAUTH_MERGE_ACCOUNTS_BY_EMAIL",
"oauth.merge_accounts_by_email",
os.environ.get("OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "False").lower() == "true",
)
OAUTH_PROVIDERS = {}
GOOGLE_CLIENT_ID = PersistentConfig(
"GOOGLE_CLIENT_ID",
"oauth.google.client_id",
os.environ.get("GOOGLE_CLIENT_ID", ""),
)
GOOGLE_CLIENT_SECRET = PersistentConfig(
"GOOGLE_CLIENT_SECRET",
"oauth.google.client_secret",
os.environ.get("GOOGLE_CLIENT_SECRET", ""),
)
GOOGLE_OAUTH_SCOPE = PersistentConfig(
"GOOGLE_OAUTH_SCOPE",
"oauth.google.scope",
os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
)
MICROSOFT_CLIENT_ID = PersistentConfig(
"MICROSOFT_CLIENT_ID",
"oauth.microsoft.client_id",
os.environ.get("MICROSOFT_CLIENT_ID", ""),
)
MICROSOFT_CLIENT_SECRET = PersistentConfig(
"MICROSOFT_CLIENT_SECRET",
"oauth.microsoft.client_secret",
os.environ.get("MICROSOFT_CLIENT_SECRET", ""),
)
MICROSOFT_CLIENT_TENANT_ID = PersistentConfig(
"MICROSOFT_CLIENT_TENANT_ID",
"oauth.microsoft.tenant_id",
os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""),
)
MICROSOFT_OAUTH_SCOPE = PersistentConfig(
"MICROSOFT_OAUTH_SCOPE",
"oauth.microsoft.scope",
os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
)
OAUTH_CLIENT_ID = PersistentConfig(
"OAUTH_CLIENT_ID",
"oauth.oidc.client_id",
os.environ.get("OAUTH_CLIENT_ID", ""),
)
OAUTH_CLIENT_SECRET = PersistentConfig(
"OAUTH_CLIENT_SECRET",
"oauth.oidc.client_secret",
os.environ.get("OAUTH_CLIENT_SECRET", ""),
)
OPENID_PROVIDER_URL = PersistentConfig(
"OPENID_PROVIDER_URL",
"oauth.oidc.provider_url",
os.environ.get("OPENID_PROVIDER_URL", ""),
)
OAUTH_SCOPES = PersistentConfig(
"OAUTH_SCOPES",
"oauth.oidc.scopes",
os.environ.get("OAUTH_SCOPES", "openid email profile"),
)
OAUTH_PROVIDER_NAME = PersistentConfig(
"OAUTH_PROVIDER_NAME",
"oauth.oidc.provider_name",
os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
)
def load_oauth_providers():
OAUTH_PROVIDERS.clear()
if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:
OAUTH_PROVIDERS["google"] = {
"client_id": GOOGLE_CLIENT_ID.value,
"client_secret": GOOGLE_CLIENT_SECRET.value,
"server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
"scope": GOOGLE_OAUTH_SCOPE.value,
}
if (
MICROSOFT_CLIENT_ID.value
and MICROSOFT_CLIENT_SECRET.value
and MICROSOFT_CLIENT_TENANT_ID.value
):
OAUTH_PROVIDERS["microsoft"] = {
"client_id": MICROSOFT_CLIENT_ID.value,
"client_secret": MICROSOFT_CLIENT_SECRET.value,
"server_metadata_url": f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration",
"scope": MICROSOFT_OAUTH_SCOPE.value,
}
if (
OAUTH_CLIENT_ID.value
and OAUTH_CLIENT_SECRET.value
and OPENID_PROVIDER_URL.value
):
OAUTH_PROVIDERS["oidc"] = {
"client_id": OAUTH_CLIENT_ID.value,
"client_secret": OAUTH_CLIENT_SECRET.value,
"server_metadata_url": OPENID_PROVIDER_URL.value,
"scope": OAUTH_SCOPES.value,
"name": OAUTH_PROVIDER_NAME.value,
}
load_oauth_providers()
#################################### ####################################
# Static DIR # Static DIR
#################################### ####################################
...@@ -377,6 +512,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools") ...@@ -377,6 +512,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True) Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Functions DIR
####################################
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
#################################### ####################################
# LITELLM_CONFIG # LITELLM_CONFIG
#################################### ####################################
...@@ -426,12 +569,15 @@ OLLAMA_API_BASE_URL = os.environ.get( ...@@ -426,12 +569,15 @@ OLLAMA_API_BASE_URL = os.environ.get(
) )
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "") OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "300") AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
if AIOHTTP_CLIENT_TIMEOUT == "": if AIOHTTP_CLIENT_TIMEOUT == "":
AIOHTTP_CLIENT_TIMEOUT = None AIOHTTP_CLIENT_TIMEOUT = None
else: else:
AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT) try:
AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
except:
AIOHTTP_CLIENT_TIMEOUT = 300
K8S_FLAG = os.environ.get("K8S_FLAG", "") K8S_FLAG = os.environ.get("K8S_FLAG", "")
...@@ -719,6 +865,16 @@ WEBUI_SECRET_KEY = os.environ.get( ...@@ -719,6 +865,16 @@ WEBUI_SECRET_KEY = os.environ.get(
), # DEPRECATED: remove at next major version ), # DEPRECATED: remove at next major version
) )
WEBUI_SESSION_COOKIE_SAME_SITE = os.environ.get(
"WEBUI_SESSION_COOKIE_SAME_SITE",
os.environ.get("WEBUI_SESSION_COOKIE_SAME_SITE", "lax"),
)
WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
"WEBUI_SESSION_COOKIE_SECURE",
os.environ.get("WEBUI_SESSION_COOKIE_SECURE", "false").lower() == "true",
)
if WEBUI_AUTH and WEBUI_SECRET_KEY == "": if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
...@@ -903,6 +1059,18 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig( ...@@ -903,6 +1059,18 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
os.getenv("RAG_WEB_SEARCH_ENGINE", ""), os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
) )
# You can provide a list of your own websites to filter after performing a web search.
# This ensures the highest level of safety and reliability of the information sources.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
"RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
"rag.rag.web.search.domain.filter_list",
[
# "wikipedia.com",
# "wikimedia.org",
# "wikidata.org",
],
)
SEARXNG_QUERY_URL = PersistentConfig( SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL", "SEARXNG_QUERY_URL",
"rag.web.search.searxng_query_url", "rag.web.search.searxng_query_url",
...@@ -1001,6 +1169,11 @@ AUTOMATIC1111_BASE_URL = PersistentConfig( ...@@ -1001,6 +1169,11 @@ AUTOMATIC1111_BASE_URL = PersistentConfig(
"image_generation.automatic1111.base_url", "image_generation.automatic1111.base_url",
os.getenv("AUTOMATIC1111_BASE_URL", ""), os.getenv("AUTOMATIC1111_BASE_URL", ""),
) )
AUTOMATIC1111_API_AUTH = PersistentConfig(
"AUTOMATIC1111_API_AUTH",
"image_generation.automatic1111.api_auth",
os.getenv("AUTOMATIC1111_API_AUTH", ""),
)
COMFYUI_BASE_URL = PersistentConfig( COMFYUI_BASE_URL = PersistentConfig(
"COMFYUI_BASE_URL", "COMFYUI_BASE_URL",
......
import base64
import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
import json import json
import markdown import markdown
...@@ -11,9 +16,11 @@ import requests ...@@ -11,9 +16,11 @@ import requests
import mimetypes import mimetypes
import shutil import shutil
import os import os
import uuid
import inspect import inspect
import asyncio import asyncio
from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
...@@ -22,7 +29,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware ...@@ -22,7 +29,8 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse, Response from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import app as socket_app from apps.socket.main import app as socket_app
...@@ -41,29 +49,43 @@ from apps.openai.main import ( ...@@ -41,29 +49,43 @@ from apps.openai.main import (
from apps.audio.main import app as audio_app from apps.audio.main import app as audio_app
from apps.images.main import app as images_app from apps.images.main import app as images_app
from apps.rag.main import app as rag_app from apps.rag.main import app as rag_app
from apps.webui.main import app as webui_app from apps.webui.main import (
app as webui_app,
get_pipe_models,
generate_function_chat_completion,
)
from pydantic import BaseModel from pydantic import BaseModel
from typing import List, Optional from typing import List, Optional, Iterator, Generator, Union
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models, ModelModel from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools from apps.webui.models.tools import Tools
from apps.webui.utils import load_toolkit_module_by_id from apps.webui.models.functions import Functions
from apps.webui.models.users import Users
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
from utils.utils import ( from utils.utils import (
get_admin_user, get_admin_user,
get_verified_user, get_verified_user,
get_current_user, get_current_user,
get_http_authorization_cred, get_http_authorization_cred,
get_password_hash,
create_token,
) )
from utils.task import ( from utils.task import (
title_generation_template, title_generation_template,
search_query_generation_template, search_query_generation_template,
tools_function_calling_generation_template, tools_function_calling_generation_template,
) )
from utils.misc import get_last_user_message, add_or_update_system_message from utils.misc import (
get_last_user_message,
add_or_update_system_message,
stream_message_template,
parse_duration,
)
from apps.rag.utils import get_rag_context, rag_template from apps.rag.utils import get_rag_context, rag_template
...@@ -76,6 +98,7 @@ from config import ( ...@@ -76,6 +98,7 @@ from config import (
VERSION, VERSION,
CHANGELOG, CHANGELOG,
FRONTEND_BUILD_DIR, FRONTEND_BUILD_DIR,
UPLOAD_DIR,
CACHE_DIR, CACHE_DIR,
STATIC_DIR, STATIC_DIR,
ENABLE_OPENAI_API, ENABLE_OPENAI_API,
...@@ -93,9 +116,22 @@ from config import ( ...@@ -93,9 +116,22 @@ from config import (
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
SAFE_MODE,
OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
WEBUI_SECRET_KEY,
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
AppConfig, AppConfig,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from utils.webhook import post_webhook
if SAFE_MODE:
print("SAFE MODE ENABLED")
Functions.deactivate_all_functions()
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -168,7 +204,16 @@ app.state.MODELS = {} ...@@ -168,7 +204,16 @@ app.state.MODELS = {}
origins = ["*"] origins = ["*"]
async def get_function_call_response(messages, tool_id, template, task_model_id, user): ##################################
#
# ChatCompletion Middleware
#
##################################
async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user
):
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2) tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs) content = tools_function_calling_generation_template(template, tools_specs)
...@@ -205,12 +250,7 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, ...@@ -205,12 +250,7 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
response = None response = None
try: try:
if model["owned_by"] == "ollama": response = await generate_chat_completions(form_data=payload, user=user)
response = await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
response = await generate_openai_chat_completion(payload, user=user)
content = None content = None
...@@ -231,84 +271,241 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, ...@@ -231,84 +271,241 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
result = json.loads(content) result = json.loads(content)
print(result) print(result)
citation = None
# Call the function # Call the function
if "name" in result: if "name" in result:
if tool_id in webui_app.state.TOOLS: if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id] toolkit_module = webui_app.state.TOOLS[tool_id]
else: else:
toolkit_module = load_toolkit_module_by_id(tool_id) toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(
toolkit_module, "Valves"
):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(
**(valves if valves else {})
)
function = getattr(toolkit_module, result["name"]) function = getattr(toolkit_module, result["name"])
function_result = None function_result = None
try: try:
# Get the signature of the function # Get the signature of the function
sig = inspect.signature(function) sig = inspect.signature(function)
# Check if '__user__' is a parameter of the function params = result["parameters"]
if "__user__" in sig.parameters: if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included # Call the function with the '__user__' parameter included
function_result = function( __user__ = {
**{ "id": user.id,
**result["parameters"], "email": user.email,
"__user__": { "name": user.name,
"id": user.id, "role": user.role,
"email": user.email, }
"name": user.name,
"role": user.role, try:
}, if hasattr(toolkit_module, "UserValves"):
} __user__["valves"] = toolkit_module.UserValves(
) **Tools.get_user_valves_by_id_and_user_id(
tool_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__messages__" in sig.parameters:
# Call the function with the '__messages__' parameter included
params = {
**params,
"__messages__": messages,
}
if "__files__" in sig.parameters:
# Call the function with the '__files__' parameter included
params = {
**params,
"__files__": files,
}
if "__model__" in sig.parameters:
# Call the function with the '__model__' parameter included
params = {
**params,
"__model__": model,
}
if "__id__" in sig.parameters:
# Call the function with the '__id__' parameter included
params = {
**params,
"__id__": tool_id,
}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else: else:
# Call the function without modifying the parameters function_result = function(**params)
function_result = function(**result["parameters"])
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e: except Exception as e:
print(e) print(e)
# Add the function result to the system prompt # Add the function result to the system prompt
if function_result: if function_result is not None:
return function_result return function_result, citation, file_handler
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
return None return None, None, False
class ChatCompletionMiddleware(BaseHTTPMiddleware): class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
return_citations = False data_items = []
if request.method == "POST" and ( show_citations = False
"/ollama/api/chat" in request.url.path citations = []
or "/chat/completions" in request.url.path
if request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
): ):
log.debug(f"request.url.path: {request.url.path}") log.debug(f"request.url.path: {request.url.path}")
# Read the original request body # Read the original request body
body = await request.body() body = await request.body()
# Decode body to string
body_str = body.decode("utf-8") body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {} data = json.loads(body_str) if body_str else {}
user = get_current_user( user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization")) request,
get_http_authorization_cred(request.headers.get("Authorization")),
) )
# Flag to skip RAG completions if file_handler is present in tools/functions
# Remove the citations from the body skip_files = False
return_citations = data.get("citations", False) if data.get("citations"):
if "citations" in data: show_citations = True
del data["citations"] del data["citations"]
# Set the task model model_id = data["model"]
task_model_id = data["model"] if model_id not in app.state.MODELS:
if task_model_id not in app.state.MODELS:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found", detail="Model not found",
) )
model = app.state.MODELS[model_id]
# Check if the user has a custom task model def get_priority(function_id):
# If the user has a custom task model, use that model function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get(
"priority", 0
)
return 0
filter_ids = [
function.id for function in Functions.get_global_filter_functions()
]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type(
"filter", active_only=True
)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try:
if hasattr(function_module, "inlet"):
inlet = function_module.inlet
# Get the signature of the function
sig = inspect.signature(inlet)
params = {"body": data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if inspect.iscoroutinefunction(inlet):
data = await inlet(**params)
else:
data = inlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Set the task model
task_model_id = data["model"]
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama": if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if ( if (
app.state.config.TASK_MODEL app.state.config.TASK_MODEL
...@@ -331,55 +528,71 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -331,55 +528,71 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
for tool_id in data["tool_ids"]: for tool_id in data["tool_ids"]:
print(tool_id) print(tool_id)
try: try:
response = await get_function_call_response( response, citation, file_handler = (
messages=data["messages"], await get_function_call_response(
tool_id=tool_id, messages=data["messages"],
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, files=data.get("files", []),
task_model_id=task_model_id, tool_id=tool_id,
user=user, template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
) )
if response: print(file_handler)
if isinstance(response, str):
context += ("\n" if context != "" else "") + response context += ("\n" if context != "" else "") + response
if citation:
citations.append(citation)
show_citations = True
if file_handler:
skip_files = True
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
del data["tool_ids"] del data["tool_ids"]
print(f"tool_context: {context}") print(f"tool_context: {context}")
# If docs field is present, generate RAG completions # If files field is present, generate RAG completions
if "docs" in data: # If skip_files is True, skip the RAG completions
data = {**data} if "files" in data:
rag_context, citations = get_rag_context( if not skip_files:
docs=data["docs"], data = {**data}
messages=data["messages"], rag_context, rag_citations = get_rag_context(
embedding_function=rag_app.state.EMBEDDING_FUNCTION, files=data["files"],
k=rag_app.state.config.TOP_K, messages=data["messages"],
reranking_function=rag_app.state.sentence_transformer_rf, embedding_function=rag_app.state.EMBEDDING_FUNCTION,
r=rag_app.state.config.RELEVANCE_THRESHOLD, k=rag_app.state.config.TOP_K,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH, reranking_function=rag_app.state.sentence_transformer_rf,
) r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
log.debug(f"rag_context: {rag_context}, citations: {citations}")
if rag_context: if rag_citations:
context += ("\n" if context != "" else "") + rag_context citations.extend(rag_citations)
del data["docs"] del data["files"]
log.debug(f"rag_context: {rag_context}, citations: {citations}") if show_citations and len(citations) > 0:
data_items.append({"citations": citations})
if context != "": if context != "":
system_prompt = rag_template( system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt rag_app.state.config.RAG_TEMPLATE, context, prompt
) )
print(system_prompt) print(system_prompt)
data["messages"] = add_or_update_system_message( data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"] system_prompt, data["messages"]
) )
modified_body_bytes = json.dumps(data).encode("utf-8") modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one # Replace the request body with the modified one
request._body = modified_body_bytes request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length # Set custom header to ensure content-length matches new body length
...@@ -392,43 +605,54 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -392,43 +605,54 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
], ],
] ]
response = await call_next(request) response = await call_next(request)
if return_citations:
# Inject the citations into the response
if isinstance(response, StreamingResponse): if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line # If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type") content_type = response.headers.get("Content-Type")
if "text/event-stream" in content_type: if "text/event-stream" in content_type:
return StreamingResponse( return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, citations), self.openai_stream_wrapper(response.body_iterator, data_items),
) )
if "application/x-ndjson" in content_type: if "application/x-ndjson" in content_type:
return StreamingResponse( return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, citations), self.ollama_stream_wrapper(response.body_iterator, data_items),
) )
else:
return response
# If it's not a chat completion request, just pass it through
response = await call_next(request)
return response return response
async def _receive(self, body: bytes): async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False} return {"type": "http.request", "body": body, "more_body": False}
async def openai_stream_wrapper(self, original_generator, citations): async def openai_stream_wrapper(self, original_generator, data_items):
yield f"data: {json.dumps({'citations': citations})}\n\n" for item in data_items:
yield f"data: {json.dumps(item)}\n\n"
async for data in original_generator: async for data in original_generator:
yield data yield data
async def ollama_stream_wrapper(self, original_generator, citations): async def ollama_stream_wrapper(self, original_generator, data_items):
yield f"{json.dumps({'citations': citations})}\n" for item in data_items:
yield f"{json.dumps(item)}\n"
async for data in original_generator: async for data in original_generator:
yield data yield data
app.add_middleware(ChatCompletionMiddleware) app.add_middleware(ChatCompletionMiddleware)
##################################
#
# Pipeline Middleware
#
##################################
def filter_pipeline(payload, user): def filter_pipeline(payload, user):
user = {"id": user.id, "name": user.name, "role": user.role} user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"] model_id = payload["model"]
filters = [ filters = [
model model
...@@ -516,7 +740,8 @@ class PipelineMiddleware(BaseHTTPMiddleware): ...@@ -516,7 +740,8 @@ class PipelineMiddleware(BaseHTTPMiddleware):
data = json.loads(body_str) if body_str else {} data = json.loads(body_str) if body_str else {}
user = get_current_user( user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization")) request,
get_http_authorization_cred(request.headers.get("Authorization")),
) )
try: try:
...@@ -584,7 +809,6 @@ async def update_embedding_function(request: Request, call_next): ...@@ -584,7 +809,6 @@ async def update_embedding_function(request: Request, call_next):
app.mount("/ws", socket_app) app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app) app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app) app.mount("/openai", openai_app)
...@@ -598,17 +822,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION ...@@ -598,17 +822,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models(): async def get_all_models():
pipe_models = []
openai_models = [] openai_models = []
ollama_models = [] ollama_models = []
pipe_models = await get_pipe_models()
if app.state.config.ENABLE_OPENAI_API: if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models() openai_models = await get_openai_models()
openai_models = openai_models["data"] openai_models = openai_models["data"]
if app.state.config.ENABLE_OLLAMA_API: if app.state.config.ENABLE_OLLAMA_API:
ollama_models = await get_ollama_models() ollama_models = await get_ollama_models()
ollama_models = [ ollama_models = [
{ {
"id": model["model"], "id": model["model"],
...@@ -621,9 +846,9 @@ async def get_all_models(): ...@@ -621,9 +846,9 @@ async def get_all_models():
for model in ollama_models["models"] for model in ollama_models["models"]
] ]
models = openai_models + ollama_models models = pipe_models + openai_models + ollama_models
custom_models = Models.get_all_models()
custom_models = Models.get_all_models()
for custom_model in custom_models: for custom_model in custom_models:
if custom_model.base_model_id == None: if custom_model.base_model_id == None:
for model in models: for model in models:
...@@ -686,6 +911,200 @@ async def get_models(user=Depends(get_verified_user)): ...@@ -686,6 +911,200 @@ async def get_models(user=Depends(get_verified_user)):
return {"data": models} return {"data": models}
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
pipe = model.get("pipe")
if pipe:
return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)
else:
return await generate_openai_chat_completion(form_data, user=user)
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data
model_id = data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json={
"user": {
"id": user.id,
"name": user.name,
"email": user.email,
"role": user.role,
},
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
# Sort filter_ids by priority, using the get_priority function
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try:
if hasattr(function_module, "outlet"):
outlet = function_module.outlet
# Get the signature of the function
sig = inspect.signature(outlet)
params = {"body": data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data
##################################
#
# Task Endpoints
#
##################################
# TODO: Refactor task API endpoints below into a separate file
@app.get("/api/task/config") @app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)): async def get_task_config(user=Depends(get_verified_user)):
return { return {
...@@ -791,12 +1210,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): ...@@ -791,12 +1210,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
content={"detail": e.args[1]}, content={"detail": e.args[1]},
) )
if model["owned_by"] == "ollama": return await generate_chat_completions(form_data=payload, user=user)
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/query/completions") @app.post("/api/task/query/completions")
...@@ -856,12 +1270,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) ...@@ -856,12 +1270,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
content={"detail": e.args[1]}, content={"detail": e.args[1]},
) )
if model["owned_by"] == "ollama": return await generate_chat_completions(form_data=payload, user=user)
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/emoji/completions") @app.post("/api/task/emoji/completions")
...@@ -925,12 +1334,7 @@ Message: """{{prompt}}""" ...@@ -925,12 +1334,7 @@ Message: """{{prompt}}"""
content={"detail": e.args[1]}, content={"detail": e.args[1]},
) )
if model["owned_by"] == "ollama": return await generate_chat_completions(form_data=payload, user=user)
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/tools/completions") @app.post("/api/task/tools/completions")
...@@ -961,8 +1365,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ ...@@ -961,8 +1365,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try: try:
context = await get_function_call_response( context, citation, file_handler = await get_function_call_response(
form_data["messages"], form_data["tool_id"], template, model_id, user form_data["messages"],
form_data.get("files", []),
form_data["tool_id"],
template,
model_id,
user,
) )
return context return context
except Exception as e: except Exception as e:
...@@ -972,94 +1381,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ ...@@ -972,94 +1381,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
) )
@app.post("/api/chat/completions") ##################################
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): #
model_id = form_data["model"] # Pipelines Endpoints
if model_id not in app.state.MODELS: #
raise HTTPException( ##################################
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
print(model)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**form_data), user=user
)
else:
return await generate_openai_chat_completion(form_data, user=user)
# TODO: Refactor pipelines API endpoints below into a separate file
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data
model_id = data["model"]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
print(model_id)
if model_id in app.state.MODELS:
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json={
"user": {"id": user.id, "name": user.name, "role": user.role},
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
return data
@app.get("/api/pipelines/list") @app.get("/api/pipelines/list")
...@@ -1382,6 +1711,13 @@ async def update_pipeline_valves( ...@@ -1382,6 +1711,13 @@ async def update_pipeline_valves(
) )
##################################
#
# Config Endpoints
#
##################################
@app.get("/api/config") @app.get("/api/config")
async def get_app_config(): async def get_app_config():
# Checking and Handling the Absence of 'ui' in CONFIG_DATA # Checking and Handling the Absence of 'ui' in CONFIG_DATA
...@@ -1416,6 +1752,12 @@ async def get_app_config(): ...@@ -1416,6 +1752,12 @@ async def get_app_config():
"engine": audio_app.state.config.STT_ENGINE, "engine": audio_app.state.config.STT_ENGINE,
}, },
}, },
"oauth": {
"providers": {
name: config.get("name", name)
for name, config in OAUTH_PROVIDERS.items()
}
},
} }
...@@ -1445,6 +1787,9 @@ async def update_model_filter_config( ...@@ -1445,6 +1787,9 @@ async def update_model_filter_config(
} }
# TODO: webhook endpoint should be under config endpoints
@app.get("/api/webhook") @app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)): async def get_webhook_url(user=Depends(get_admin_user)):
return { return {
...@@ -1494,6 +1839,154 @@ async def get_app_latest_release_version(): ...@@ -1494,6 +1839,154 @@ async def get_app_latest_release_version():
) )
############################
# OAuth Login & Callback
############################
oauth = OAuth()
for provider_name, provider_config in OAUTH_PROVIDERS.items():
oauth.register(
name=provider_name,
client_id=provider_config["client_id"],
client_secret=provider_config["client_secret"],
server_metadata_url=provider_config["server_metadata_url"],
client_kwargs={
"scope": provider_config["scope"],
},
)
# SessionMiddleware is used by authlib for oauth
if len(OAUTH_PROVIDERS) > 0:
app.add_middleware(
SessionMiddleware,
secret_key=WEBUI_SECRET_KEY,
session_cookie="oui-session",
same_site=WEBUI_SESSION_COOKIE_SAME_SITE,
https_only=WEBUI_SESSION_COOKIE_SECURE,
)
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
redirect_uri = request.url_for("oauth_callback", provider=provider)
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
# OAuth login logic is as follows:
# 1. Attempt to find a user with matching subject ID, tied to the provider
# 2. If OAUTH_MERGE_ACCOUNTS_BY_EMAIL is true, find a user with the email address provided via OAuth
# - This is considered insecure in general, as OAuth providers do not always verify email addresses
# 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user
# - Email addresses are considered unique, so we fail registration if the email address is alreayd taken
@app.get("/oauth/{provider}/callback")
async def oauth_callback(provider: str, request: Request, response: Response):
if provider not in OAUTH_PROVIDERS:
raise HTTPException(404)
client = oauth.create_client(provider)
try:
token = await client.authorize_access_token(request)
except Exception as e:
log.warning(f"OAuth callback error: {e}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
user_data: UserInfo = token["userinfo"]
sub = user_data.get("sub")
if not sub:
log.warning(f"OAuth callback failed, sub is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
provider_sub = f"{provider}@{sub}"
email = user_data.get("email", "").lower()
# We currently mandate that email addresses are provided
if not email:
log.warning(f"OAuth callback failed, email is missing: {user_data}")
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
if not user:
# If the user does not exist, check if merging is enabled
if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value:
# Check if the user exists by email
user = Users.get_user_by_email(email)
if user:
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if not user:
# If the user does not exist, check if signups are enabled
if ENABLE_OAUTH_SIGNUP.value:
# Check if an existing user with the same email already exists
existing_user = Users.get_user_by_email(user_data.get("email", "").lower())
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_url = user_data.get("picture", "")
if picture_url:
# Download the profile image into a base64 string
try:
async with aiohttp.ClientSession() as session:
async with session.get(picture_url) as resp:
picture = await resp.read()
base64_encoded_picture = base64.b64encode(picture).decode(
"utf-8"
)
guessed_mime_type = mimetypes.guess_type(picture_url)[0]
if guessed_mime_type is None:
# assume JPG, browsers are tolerant enough of image formats
guessed_mime_type = "image/jpeg"
picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}"
except Exception as e:
log.error(f"Error downloading profile image '{picture_url}': {e}")
picture_url = ""
if not picture_url:
picture_url = "/user.png"
user = Auths.insert_new_auth(
email=email,
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get("name", "User"),
profile_image_url=picture_url,
role=webui_app.state.config.DEFAULT_USER_ROLE,
oauth_sub=provider_sub,
)
if webui_app.state.config.WEBHOOK_URL:
post_webhook(
webui_app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
"message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
"user": user.model_dump_json(exclude_none=True),
},
)
else:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
jwt_token = create_token(
data={"id": user.id},
expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN),
)
# Set the cookie token
response.set_cookie(
key="token",
value=token,
httponly=True, # Ensures the cookie is not accessible via JavaScript
)
# Redirect back to the frontend with the JWT token
redirect_url = f"{request.base_url}auth#token={jwt_token}"
return RedirectResponse(url=redirect_url)
@app.get("/manifest.json") @app.get("/manifest.json")
async def get_manifest_json(): async def get_manifest_json():
return { return {
...@@ -1502,7 +1995,6 @@ async def get_manifest_json(): ...@@ -1502,7 +1995,6 @@ async def get_manifest_json():
"start_url": "/", "start_url": "/",
"display": "standalone", "display": "standalone",
"background_color": "#343541", "background_color": "#343541",
"theme_color": "#343541",
"orientation": "portrait-primary", "orientation": "portrait-primary",
"icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}], "icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
} }
......
...@@ -17,11 +17,17 @@ peewee-migrate==1.12.2 ...@@ -17,11 +17,17 @@ peewee-migrate==1.12.2
psycopg2-binary==2.9.9 psycopg2-binary==2.9.9
PyMySQL==1.1.1 PyMySQL==1.1.1
bcrypt==4.1.3 bcrypt==4.1.3
SQLAlchemy
pymongo
redis
boto3==1.34.110 boto3==1.34.110
argon2-cffi==23.1.0 argon2-cffi==23.1.0
APScheduler==3.10.4 APScheduler==3.10.4
# AI libraries
openai
anthropic
google-generativeai==0.5.4 google-generativeai==0.5.4
langchain==0.2.0 langchain==0.2.0
...@@ -52,6 +58,7 @@ rank-bm25==0.2.2 ...@@ -52,6 +58,7 @@ rank-bm25==0.2.2
faster-whisper==1.0.2 faster-whisper==1.0.2
PyJWT[crypto]==2.8.0 PyJWT[crypto]==2.8.0
authlib==1.3.0
black==24.4.2 black==24.4.2
langfuse==2.33.0 langfuse==2.33.0
......
...@@ -3,7 +3,9 @@ import hashlib ...@@ -3,7 +3,9 @@ import hashlib
import json import json
import re import re
from datetime import timedelta from datetime import timedelta
from typing import Optional, List from typing import Optional, List, Tuple
import uuid
import time
def get_last_user_message(messages: List[dict]) -> str: def get_last_user_message(messages: List[dict]) -> str:
...@@ -28,6 +30,21 @@ def get_last_assistant_message(messages: List[dict]) -> str: ...@@ -28,6 +30,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
return None return None
def get_system_message(messages: List[dict]) -> dict:
for message in messages:
if message["role"] == "system":
return message
return None
def remove_system_message(messages: List[dict]) -> List[dict]:
return [message for message in messages if message["role"] != "system"]
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
return get_system_message(messages), remove_system_message(messages)
def add_or_update_system_message(content: str, messages: List[dict]): def add_or_update_system_message(content: str, messages: List[dict]):
""" """
Adds a new system message at the beginning of the messages list Adds a new system message at the beginning of the messages list
...@@ -47,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]): ...@@ -47,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
return messages return messages
def stream_message_template(model: str, message: str):
return {
"id": f"{model}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": message},
"logprobs": None,
"finish_reason": None,
}
],
}
def get_gravatar_url(email): def get_gravatar_url(email):
# Trim leading and trailing whitespace from # Trim leading and trailing whitespace from
# an email address and force all characters # an email address and force all characters
......
...@@ -24,10 +24,16 @@ def prompt_template( ...@@ -24,10 +24,16 @@ def prompt_template(
if user_name: if user_name:
# Replace {{USER_NAME}} in the template with the user's name # Replace {{USER_NAME}} in the template with the user's name
template = template.replace("{{USER_NAME}}", user_name) template = template.replace("{{USER_NAME}}", user_name)
else:
# Replace {{USER_NAME}} in the template with "Unknown"
template = template.replace("{{USER_NAME}}", "Unknown")
if user_location: if user_location:
# Replace {{USER_LOCATION}} in the template with the current location # Replace {{USER_LOCATION}} in the template with the current location
template = template.replace("{{USER_LOCATION}}", user_location) template = template.replace("{{USER_LOCATION}}", user_location)
else:
# Replace {{USER_LOCATION}} in the template with "Unknown"
template = template.replace("{{USER_LOCATION}}", "Unknown")
return template return template
......
...@@ -20,7 +20,9 @@ def get_tools_specs(tools) -> List[dict]: ...@@ -20,7 +20,9 @@ def get_tools_specs(tools) -> List[dict]:
function_list = [ function_list = [
{"name": func, "function": getattr(tools, func)} {"name": func, "function": getattr(tools, func)}
for func in dir(tools) for func in dir(tools)
if callable(getattr(tools, func)) and not func.startswith("__") if callable(getattr(tools, func))
and not func.startswith("__")
and not inspect.isclass(getattr(tools, func))
] ]
specs = [] specs = []
...@@ -65,6 +67,7 @@ def get_tools_specs(tools) -> List[dict]: ...@@ -65,6 +67,7 @@ def get_tools_specs(tools) -> List[dict]:
function function
).parameters.items() ).parameters.items()
if param.default is param.empty if param.default is param.empty
and not (name.startswith("__") and name.endswith("__"))
], ],
}, },
} }
......
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends from fastapi import HTTPException, status, Depends, Request
from apps.webui.models.users import Users from apps.webui.models.users import Users
...@@ -24,7 +24,7 @@ ALGORITHM = "HS256" ...@@ -24,7 +24,7 @@ ALGORITHM = "HS256"
# Auth Utils # Auth Utils
############## ##############
bearer_security = HTTPBearer() bearer_security = HTTPBearer(auto_error=False)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
...@@ -75,13 +75,26 @@ def get_http_authorization_cred(auth_header: str): ...@@ -75,13 +75,26 @@ def get_http_authorization_cred(auth_header: str):
def get_current_user( def get_current_user(
request: Request,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
): ):
token = None
if auth_token is not None:
token = auth_token.credentials
if token is None and "token" in request.cookies:
token = request.cookies.get("token")
if token is None:
raise HTTPException(status_code=403, detail="Not authenticated")
# auth by api key # auth by api key
if auth_token.credentials.startswith("sk-"): if token.startswith("sk-"):
return get_current_user_by_api_key(auth_token.credentials) return get_current_user_by_api_key(token)
# auth by jwt token # auth by jwt token
data = decode_token(auth_token.credentials) data = decode_token(token)
if data != None and "id" in data: if data != None and "id" in data:
user = Users.get_user_by_id(data["id"]) user = Users.get_user_by_id(data["id"])
if user is None: if user is None:
......
{ {
"name": "open-webui", "name": "open-webui",
"version": "0.3.5", "version": "0.3.6",
"lockfileVersion": 3, "lockfileVersion": 3,
"requires": true, "requires": true,
"packages": { "packages": {
"": { "": {
"name": "open-webui", "name": "open-webui",
"version": "0.3.5", "version": "0.3.6",
"dependencies": { "dependencies": {
"@codemirror/lang-javascript": "^6.2.2", "@codemirror/lang-javascript": "^6.2.2",
"@codemirror/lang-python": "^6.1.6", "@codemirror/lang-python": "^6.1.6",
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
"async": "^3.2.5", "async": "^3.2.5",
"bits-ui": "^0.19.7", "bits-ui": "^0.19.7",
"codemirror": "^6.0.1", "codemirror": "^6.0.1",
"crc-32": "^1.2.2",
"dayjs": "^1.11.10", "dayjs": "^1.11.10",
"eventsource-parser": "^1.1.2", "eventsource-parser": "^1.1.2",
"file-saver": "^2.0.5", "file-saver": "^2.0.5",
...@@ -28,11 +29,12 @@ ...@@ -28,11 +29,12 @@
"katex": "^0.16.9", "katex": "^0.16.9",
"marked": "^9.1.0", "marked": "^9.1.0",
"mermaid": "^10.9.1", "mermaid": "^10.9.1",
"pyodide": "^0.26.0-alpha.4", "pyodide": "^0.26.1",
"socket.io-client": "^4.7.5", "socket.io-client": "^4.2.0",
"sortablejs": "^1.15.2", "sortablejs": "^1.15.2",
"svelte-sonner": "^0.3.19", "svelte-sonner": "^0.3.19",
"tippy.js": "^6.3.7", "tippy.js": "^6.3.7",
"turndown": "^7.2.0",
"uuid": "^9.0.1" "uuid": "^9.0.1"
}, },
"devDependencies": { "devDependencies": {
...@@ -999,6 +1001,11 @@ ...@@ -999,6 +1001,11 @@
"svelte": ">=3 <5" "svelte": ">=3 <5"
} }
}, },
"node_modules/@mixmark-io/domino": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/@mixmark-io/domino/-/domino-2.2.0.tgz",
"integrity": "sha512-Y28PR25bHXUg88kCV7nivXrP2Nj2RueZ3/l/jdx6J9f8J4nsEGcgX0Qe6lt7Pa+J79+kPiJU3LguR6O/6zrLOw=="
},
"node_modules/@nodelib/fs.scandir": { "node_modules/@nodelib/fs.scandir": {
"version": "2.1.5", "version": "2.1.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
...@@ -2266,11 +2273,6 @@ ...@@ -2266,11 +2273,6 @@
"dev": true, "dev": true,
"optional": true "optional": true
}, },
"node_modules/base-64": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/base-64/-/base-64-1.0.0.tgz",
"integrity": "sha512-kwDPIFCGx0NZHog36dj+tHiwP4QMzsZ3AgMViUBKI0+V5n4U0ufTCUMhnQ04diaRI8EX/QcPfql7zlhZ7j4zgg=="
},
"node_modules/base64-js": { "node_modules/base64-js": {
"version": "1.5.1", "version": "1.5.1",
"resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz",
...@@ -3063,6 +3065,17 @@ ...@@ -3063,6 +3065,17 @@
"layout-base": "^1.0.0" "layout-base": "^1.0.0"
} }
}, },
"node_modules/crc-32": {
"version": "1.2.2",
"resolved": "https://registry.npmjs.org/crc-32/-/crc-32-1.2.2.tgz",
"integrity": "sha512-ROmzCKrTnOwybPcJApAA6WBWij23HVfGVNKqqrZpuyZOHqK2CwHSvpGuyt/UNNvaIjEd8X5IFGp4Mh+Ie1IHJQ==",
"bin": {
"crc32": "bin/crc32.njs"
},
"engines": {
"node": ">=0.8"
}
},
"node_modules/crelt": { "node_modules/crelt": {
"version": "1.0.6", "version": "1.0.6",
"resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz", "resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz",
...@@ -3984,37 +3997,17 @@ ...@@ -3984,37 +3997,17 @@
} }
}, },
"node_modules/engine.io-client": { "node_modules/engine.io-client": {
"version": "6.5.3", "version": "6.5.4",
"resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.5.3.tgz", "resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.5.4.tgz",
"integrity": "sha512-9Z0qLB0NIisTRt1DZ/8U2k12RJn8yls/nXMZLn+/N8hANT3TcYjKFKcwbw5zFQiN4NTde3TSY9zb79e1ij6j9Q==", "integrity": "sha512-GeZeeRjpD2qf49cZQ0Wvh/8NJNfeXkXXcoGh+F77oEAgo9gUHwT1fCRxSNU+YEEaysOJTnsFHmM5oAcPy4ntvQ==",
"dependencies": { "dependencies": {
"@socket.io/component-emitter": "~3.1.0", "@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1", "debug": "~4.3.1",
"engine.io-parser": "~5.2.1", "engine.io-parser": "~5.2.1",
"ws": "~8.11.0", "ws": "~8.17.1",
"xmlhttprequest-ssl": "~2.0.0" "xmlhttprequest-ssl": "~2.0.0"
} }
}, },
"node_modules/engine.io-client/node_modules/ws": {
"version": "8.11.0",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz",
"integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==",
"engines": {
"node": ">=10.0.0"
},
"peerDependencies": {
"bufferutil": "^4.0.1",
"utf-8-validate": "^5.0.2"
},
"peerDependenciesMeta": {
"bufferutil": {
"optional": true
},
"utf-8-validate": {
"optional": true
}
}
},
"node_modules/engine.io-parser": { "node_modules/engine.io-parser": {
"version": "5.2.2", "version": "5.2.2",
"resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.2.tgz", "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.2.tgz",
...@@ -7551,11 +7544,10 @@ ...@@ -7551,11 +7544,10 @@
} }
}, },
"node_modules/pyodide": { "node_modules/pyodide": {
"version": "0.26.0-alpha.4", "version": "0.26.1",
"resolved": "https://registry.npmjs.org/pyodide/-/pyodide-0.26.0-alpha.4.tgz", "resolved": "https://registry.npmjs.org/pyodide/-/pyodide-0.26.1.tgz",
"integrity": "sha512-Ixuczq99DwhQlE+Bt0RaS6Ln9MHSZOkbU6iN8azwaeorjHtr7ukaxh+FeTxViFrp2y+ITyKgmcobY+JnBPcULw==", "integrity": "sha512-P+Gm88nwZqY7uBgjbQH8CqqU6Ei/rDn7pS1t02sNZsbyLJMyE2OVXjgNuqVT3KqYWnyGREUN0DbBUCJqk8R0ew==",
"dependencies": { "dependencies": {
"base-64": "^1.0.0",
"ws": "^8.5.0" "ws": "^8.5.0"
}, },
"engines": { "engines": {
...@@ -9065,6 +9057,14 @@ ...@@ -9065,6 +9057,14 @@
"node": "*" "node": "*"
} }
}, },
"node_modules/turndown": {
"version": "7.2.0",
"resolved": "https://registry.npmjs.org/turndown/-/turndown-7.2.0.tgz",
"integrity": "sha512-eCZGBN4nNNqM9Owkv9HAtWRYfLA4h909E/WGAWWBpmB275ehNhZyk87/Tpvjbp0jjNl9XwCsbe6bm6CqFsgD+A==",
"dependencies": {
"@mixmark-io/domino": "^2.2.0"
}
},
"node_modules/tweetnacl": { "node_modules/tweetnacl": {
"version": "0.14.5", "version": "0.14.5",
"resolved": "https://registry.npmjs.org/tweetnacl/-/tweetnacl-0.14.5.tgz", "resolved": "https://registry.npmjs.org/tweetnacl/-/tweetnacl-0.14.5.tgz",
...@@ -10382,9 +10382,9 @@ ...@@ -10382,9 +10382,9 @@
"integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==" "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ=="
}, },
"node_modules/ws": { "node_modules/ws": {
"version": "8.17.0", "version": "8.17.1",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.17.0.tgz", "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz",
"integrity": "sha512-uJq6108EgZMAl20KagGkzCKfMEjxmKvZHG7Tlq0Z6nOky7YF7aq4mOx6xK8TJ/i1LeK4Qus7INktacctDgY8Ow==", "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==",
"engines": { "engines": {
"node": ">=10.0.0" "node": ">=10.0.0"
}, },
......
{ {
"name": "open-webui", "name": "open-webui",
"version": "0.3.5", "version": "0.3.6",
"private": true, "private": true,
"scripts": { "scripts": {
"dev": "npm run pyodide:fetch && vite dev --host", "dev": "npm run pyodide:fetch && vite dev --host",
...@@ -56,6 +56,7 @@ ...@@ -56,6 +56,7 @@
"async": "^3.2.5", "async": "^3.2.5",
"bits-ui": "^0.19.7", "bits-ui": "^0.19.7",
"codemirror": "^6.0.1", "codemirror": "^6.0.1",
"crc-32": "^1.2.2",
"dayjs": "^1.11.10", "dayjs": "^1.11.10",
"eventsource-parser": "^1.1.2", "eventsource-parser": "^1.1.2",
"file-saver": "^2.0.5", "file-saver": "^2.0.5",
...@@ -68,11 +69,12 @@ ...@@ -68,11 +69,12 @@
"katex": "^0.16.9", "katex": "^0.16.9",
"marked": "^9.1.0", "marked": "^9.1.0",
"mermaid": "^10.9.1", "mermaid": "^10.9.1",
"pyodide": "^0.26.0-alpha.4", "pyodide": "^0.26.1",
"socket.io-client": "^4.7.5", "socket.io-client": "^4.2.0",
"sortablejs": "^1.15.2", "sortablejs": "^1.15.2",
"svelte-sonner": "^0.3.19", "svelte-sonner": "^0.3.19",
"tippy.js": "^6.3.7", "tippy.js": "^6.3.7",
"turndown": "^7.2.0",
"uuid": "^9.0.1" "uuid": "^9.0.1"
} }
} }
...@@ -59,6 +59,7 @@ dependencies = [ ...@@ -59,6 +59,7 @@ dependencies = [
"faster-whisper==1.0.2", "faster-whisper==1.0.2",
"PyJWT[crypto]==2.8.0", "PyJWT[crypto]==2.8.0",
"authlib==1.3.0",
"black==24.4.2", "black==24.4.2",
"langfuse==2.33.0", "langfuse==2.33.0",
......
...@@ -31,6 +31,8 @@ asgiref==3.8.1 ...@@ -31,6 +31,8 @@ asgiref==3.8.1
# via opentelemetry-instrumentation-asgi # via opentelemetry-instrumentation-asgi
attrs==23.2.0 attrs==23.2.0
# via aiohttp # via aiohttp
authlib==1.3.0
# via open-webui
av==11.0.0 av==11.0.0
# via faster-whisper # via faster-whisper
backoff==2.2.1 backoff==2.2.1
...@@ -93,6 +95,7 @@ coloredlogs==15.0.1 ...@@ -93,6 +95,7 @@ coloredlogs==15.0.1
compressed-rtf==1.0.6 compressed-rtf==1.0.6
# via extract-msg # via extract-msg
cryptography==42.0.7 cryptography==42.0.7
# via authlib
# via msoffcrypto-tool # via msoffcrypto-tool
# via pyjwt # via pyjwt
ctranslate2==4.2.1 ctranslate2==4.2.1
...@@ -395,6 +398,7 @@ pandas==2.2.2 ...@@ -395,6 +398,7 @@ pandas==2.2.2
# via open-webui # via open-webui
passlib==1.7.4 passlib==1.7.4
# via open-webui # via open-webui
# via passlib
pathspec==0.12.1 pathspec==0.12.1
# via black # via black
pcodedmp==1.2.6 pcodedmp==1.2.6
...@@ -453,6 +457,7 @@ pygments==2.18.0 ...@@ -453,6 +457,7 @@ pygments==2.18.0
# via rich # via rich
pyjwt==2.8.0 pyjwt==2.8.0
# via open-webui # via open-webui
# via pyjwt
pymysql==1.1.0 pymysql==1.1.0
# via open-webui # via open-webui
pypandoc==1.13 pypandoc==1.13
...@@ -554,9 +559,6 @@ scipy==1.13.0 ...@@ -554,9 +559,6 @@ scipy==1.13.0
# via sentence-transformers # via sentence-transformers
sentence-transformers==2.7.0 sentence-transformers==2.7.0
# via open-webui # via open-webui
setuptools==69.5.1
# via ctranslate2
# via opentelemetry-instrumentation
shapely==2.0.4 shapely==2.0.4
# via rapidocr-onnxruntime # via rapidocr-onnxruntime
shellingham==1.5.4 shellingham==1.5.4
...@@ -651,6 +653,7 @@ uvicorn==0.22.0 ...@@ -651,6 +653,7 @@ uvicorn==0.22.0
# via chromadb # via chromadb
# via fastapi # via fastapi
# via open-webui # via open-webui
# via uvicorn
uvloop==0.19.0 uvloop==0.19.0
# via uvicorn # via uvicorn
validators==0.28.1 validators==0.28.1
...@@ -678,3 +681,6 @@ youtube-transcript-api==0.6.2 ...@@ -678,3 +681,6 @@ youtube-transcript-api==0.6.2
# via open-webui # via open-webui
zipp==3.18.1 zipp==3.18.1
# via importlib-metadata # via importlib-metadata
setuptools==69.5.1
# via ctranslate2
# via opentelemetry-instrumentation
...@@ -31,6 +31,8 @@ asgiref==3.8.1 ...@@ -31,6 +31,8 @@ asgiref==3.8.1
# via opentelemetry-instrumentation-asgi # via opentelemetry-instrumentation-asgi
attrs==23.2.0 attrs==23.2.0
# via aiohttp # via aiohttp
authlib==1.3.0
# via open-webui
av==11.0.0 av==11.0.0
# via faster-whisper # via faster-whisper
backoff==2.2.1 backoff==2.2.1
...@@ -93,6 +95,7 @@ coloredlogs==15.0.1 ...@@ -93,6 +95,7 @@ coloredlogs==15.0.1
compressed-rtf==1.0.6 compressed-rtf==1.0.6
# via extract-msg # via extract-msg
cryptography==42.0.7 cryptography==42.0.7
# via authlib
# via msoffcrypto-tool # via msoffcrypto-tool
# via pyjwt # via pyjwt
ctranslate2==4.2.1 ctranslate2==4.2.1
...@@ -395,6 +398,7 @@ pandas==2.2.2 ...@@ -395,6 +398,7 @@ pandas==2.2.2
# via open-webui # via open-webui
passlib==1.7.4 passlib==1.7.4
# via open-webui # via open-webui
# via passlib
pathspec==0.12.1 pathspec==0.12.1
# via black # via black
pcodedmp==1.2.6 pcodedmp==1.2.6
...@@ -453,6 +457,7 @@ pygments==2.18.0 ...@@ -453,6 +457,7 @@ pygments==2.18.0
# via rich # via rich
pyjwt==2.8.0 pyjwt==2.8.0
# via open-webui # via open-webui
# via pyjwt
pymysql==1.1.0 pymysql==1.1.0
# via open-webui # via open-webui
pypandoc==1.13 pypandoc==1.13
...@@ -554,9 +559,6 @@ scipy==1.13.0 ...@@ -554,9 +559,6 @@ scipy==1.13.0
# via sentence-transformers # via sentence-transformers
sentence-transformers==2.7.0 sentence-transformers==2.7.0
# via open-webui # via open-webui
setuptools==69.5.1
# via ctranslate2
# via opentelemetry-instrumentation
shapely==2.0.4 shapely==2.0.4
# via rapidocr-onnxruntime # via rapidocr-onnxruntime
shellingham==1.5.4 shellingham==1.5.4
...@@ -651,6 +653,7 @@ uvicorn==0.22.0 ...@@ -651,6 +653,7 @@ uvicorn==0.22.0
# via chromadb # via chromadb
# via fastapi # via fastapi
# via open-webui # via open-webui
# via uvicorn
uvloop==0.19.0 uvloop==0.19.0
# via uvicorn # via uvicorn
validators==0.28.1 validators==0.28.1
...@@ -678,3 +681,6 @@ youtube-transcript-api==0.6.2 ...@@ -678,3 +681,6 @@ youtube-transcript-api==0.6.2
# via open-webui # via open-webui
zipp==3.18.1 zipp==3.18.1
# via importlib-metadata # via importlib-metadata
setuptools==69.5.1
# via ctranslate2
# via opentelemetry-instrumentation
const packages = [ const packages = [
'micropip',
'packaging',
'requests', 'requests',
'beautifulsoup4', 'beautifulsoup4',
'numpy', 'numpy',
...@@ -11,20 +13,64 @@ const packages = [ ...@@ -11,20 +13,64 @@ const packages = [
]; ];
import { loadPyodide } from 'pyodide'; import { loadPyodide } from 'pyodide';
import { writeFile, copyFile, readdir } from 'fs/promises'; import { writeFile, readFile, copyFile, readdir, rmdir } from 'fs/promises';
async function downloadPackages() { async function downloadPackages() {
console.log('Setting up pyodide + micropip'); console.log('Setting up pyodide + micropip');
const pyodide = await loadPyodide({
packageCacheDir: 'static/pyodide' let pyodide;
}); try {
await pyodide.loadPackage('micropip'); pyodide = await loadPyodide({
const micropip = pyodide.pyimport('micropip'); packageCacheDir: 'static/pyodide'
console.log('Downloading Pyodide packages:', packages); });
await micropip.install(packages); } catch (err) {
console.log('Pyodide packages downloaded, freezing into lock file'); console.error('Failed to load Pyodide:', err);
const lockFile = await micropip.freeze(); return;
await writeFile('static/pyodide/pyodide-lock.json', lockFile); }
const packageJson = JSON.parse(await readFile('package.json'));
const pyodideVersion = packageJson.dependencies.pyodide.replace('^', '');
try {
const pyodidePackageJson = JSON.parse(await readFile('static/pyodide/package.json'));
const pyodidePackageVersion = pyodidePackageJson.version.replace('^', '');
if (pyodideVersion !== pyodidePackageVersion) {
console.log('Pyodide version mismatch, removing static/pyodide directory');
await rmdir('static/pyodide', { recursive: true });
}
} catch (e) {
console.log('Pyodide package not found, proceeding with download.');
}
try {
console.log('Loading micropip package');
await pyodide.loadPackage('micropip');
const micropip = pyodide.pyimport('micropip');
console.log('Downloading Pyodide packages:', packages);
try {
for (const pkg of packages) {
console.log(`Installing package: ${pkg}`);
await micropip.install(pkg);
}
} catch (err) {
console.error('Package installation failed:', err);
return;
}
console.log('Pyodide packages downloaded, freezing into lock file');
try {
const lockFile = await micropip.freeze();
await writeFile('static/pyodide/pyodide-lock.json', lockFile);
} catch (err) {
console.error('Failed to write lock file:', err);
}
} catch (err) {
console.error('Failed to load or install micropip:', err);
}
} }
async function copyPyodide() { async function copyPyodide() {
......
...@@ -32,6 +32,10 @@ math { ...@@ -32,6 +32,10 @@ math {
@apply underline; @apply underline;
} }
iframe {
@apply rounded-lg;
}
ol > li { ol > li {
counter-increment: list-number; counter-increment: list-number;
display: block; display: block;
......
...@@ -13,6 +13,12 @@ ...@@ -13,6 +13,12 @@
href="/opensearch.xml" href="/opensearch.xml"
/> />
<script>
function resizeIframe(obj) {
obj.style.height = obj.contentWindow.document.documentElement.scrollHeight + 'px';
}
</script>
<script> <script>
// On page load or when changing themes, best to add inline in `head` to avoid FOUC // On page load or when changing themes, best to add inline in `head` to avoid FOUC
(() => { (() => {
......
...@@ -90,7 +90,8 @@ export const getSessionUser = async (token: string) => { ...@@ -90,7 +90,8 @@ export const getSessionUser = async (token: string) => {
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
Authorization: `Bearer ${token}` Authorization: `Bearer ${token}`
} },
credentials: 'include'
}) })
.then(async (res) => { .then(async (res) => {
if (!res.ok) throw await res.json(); if (!res.ok) throw await res.json();
...@@ -117,6 +118,7 @@ export const userSignIn = async (email: string, password: string) => { ...@@ -117,6 +118,7 @@ export const userSignIn = async (email: string, password: string) => {
headers: { headers: {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
}, },
credentials: 'include',
body: JSON.stringify({ body: JSON.stringify({
email: email, email: email,
password: password password: password
...@@ -153,6 +155,7 @@ export const userSignUp = async ( ...@@ -153,6 +155,7 @@ export const userSignUp = async (
headers: { headers: {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
}, },
credentials: 'include',
body: JSON.stringify({ body: JSON.stringify({
name: name, name: name,
email: email, email: email,
......
import { WEBUI_API_BASE_URL } from '$lib/constants';
export const uploadFile = async (token: string, file: File) => {
const data = new FormData();
data.append('file', file);
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/files/`, {
method: 'POST',
headers: {
Accept: 'application/json',
authorization: `Bearer ${token}`
},
body: data
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFiles = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/files/`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFileById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/files/${id}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFileContentById = async (id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/files/${id}/content`, {
method: 'GET',
headers: {
Accept: 'application/json'
},
credentials: 'include'
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return await res.blob();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const deleteFileById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/files/${id}`, {
method: 'DELETE',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const deleteAllFiles = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/files/all`, {
method: 'DELETE',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
import { WEBUI_API_BASE_URL } from '$lib/constants';
export const createNewFunction = async (token: string, func: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/create`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...func
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctions = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const exportFunctions = async (token: string = '') => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/export`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctionById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFunctionById = async (token: string, id: string, func: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...func
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const deleteFunctionById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/delete`, {
method: 'DELETE',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const toggleFunctionById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/toggle`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const toggleGlobalById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/toggle/global`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctionValvesById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctionValvesSpecById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFunctionValvesById = async (token: string, id: string, valves: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...valves
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserValvesById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/user`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserValvesSpecById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/user/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateUserValvesById = async (token: string, id: string, valves: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/user/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...valves
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
...@@ -164,6 +164,37 @@ export const updateQuerySettings = async (token: string, settings: QuerySettings ...@@ -164,6 +164,37 @@ export const updateQuerySettings = async (token: string, settings: QuerySettings
return res; return res;
}; };
export const processDocToVectorDB = async (token: string, file_id: string) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/process/doc`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
file_id: file_id
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => { export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => {
const data = new FormData(); const data = new FormData();
data.append('file', file); data.append('file', file);
......
...@@ -191,3 +191,201 @@ export const deleteToolById = async (token: string, id: string) => { ...@@ -191,3 +191,201 @@ export const deleteToolById = async (token: string, id: string) => {
return res; return res;
}; };
export const getToolValvesById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getToolValvesSpecById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateToolValvesById = async (token: string, id: string, valves: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...valves
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserValvesById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/user`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserValvesSpecById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/user/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateUserValvesById = async (token: string, id: string, valves: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/user/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...valves
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment