Commit f26d80dc authored by Jun Siang Cheah's avatar Jun Siang Cheah
Browse files

Merge remote-tracking branch 'upstream/dev' into feat/oauth

parents 99e7b328 f54a66b8
...@@ -77,6 +77,7 @@ from apps.rag.search.serpstack import search_serpstack ...@@ -77,6 +77,7 @@ from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.serply import search_serply from apps.rag.search.serply import search_serply
from apps.rag.search.duckduckgo import search_duckduckgo from apps.rag.search.duckduckgo import search_duckduckgo
from apps.rag.search.tavily import search_tavily from apps.rag.search.tavily import search_tavily
from apps.rag.search.jina_search import search_jina
from utils.misc import ( from utils.misc import (
calculate_sha256, calculate_sha256,
...@@ -856,6 +857,8 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -856,6 +857,8 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
) )
else: else:
raise Exception("No TAVILY_API_KEY found in environment variables") raise Exception("No TAVILY_API_KEY found in environment variables")
elif engine == "jina":
return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
else: else:
raise Exception("No search engine API key found in environment variables") raise Exception("No search engine API key found in environment variables")
......
import logging
import requests
from yarl import URL
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_jina(query: str, count: int) -> list[SearchResult]:
"""
Search using Jina's Search API and return the results as a list of SearchResult objects.
Args:
query (str): The query to search for
count (int): The number of results to return
Returns:
List[SearchResult]: A list of search results
"""
jina_search_endpoint = "https://s.jina.ai/"
headers = {
"Accept": "application/json",
}
url = str(URL(jina_search_endpoint + query))
response = requests.get(url, headers=headers)
response.raise_for_status()
data = response.json()
results = []
for result in data["data"][:count]:
results.append(
SearchResult(
link=result["url"],
title=result.get("title"),
snippet=result.get("content"),
)
)
return results
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields("tool", valves=pw.TextField(null=True))
migrator.add_fields("function", valves=pw.TextField(null=True))
migrator.add_fields("function", is_active=pw.BooleanField(default=False))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("tool", "valves")
migrator.remove_fields("function", "valves")
migrator.remove_fields("function", "is_active")
...@@ -105,13 +105,15 @@ async def get_status(): ...@@ -105,13 +105,15 @@ async def get_status():
async def get_pipe_models(): async def get_pipe_models():
pipes = Functions.get_functions_by_type("pipe") pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = [] pipe_models = []
for pipe in pipes: for pipe in pipes:
# Check if function is already loaded # Check if function is already loaded
if pipe.id not in app.state.FUNCTIONS: if pipe.id not in app.state.FUNCTIONS:
function_module, function_type = load_function_module_by_id(pipe.id) function_module, function_type, frontmatter = load_function_module_by_id(
pipe.id
)
app.state.FUNCTIONS[pipe.id] = function_module app.state.FUNCTIONS[pipe.id] = function_module
else: else:
function_module = app.state.FUNCTIONS[pipe.id] function_module = app.state.FUNCTIONS[pipe.id]
...@@ -132,7 +134,9 @@ async def get_pipe_models(): ...@@ -132,7 +134,9 @@ async def get_pipe_models():
manifold_pipe_name = p["name"] manifold_pipe_name = p["name"]
if hasattr(function_module, "name"): if hasattr(function_module, "name"):
manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}" manifold_pipe_name = (
f"{function_module.name}{manifold_pipe_name}"
)
pipe_models.append( pipe_models.append(
{ {
......
...@@ -5,8 +5,11 @@ from typing import List, Union, Optional ...@@ -5,8 +5,11 @@ from typing import List, Union, Optional
import time import time
import logging import logging
from apps.webui.internal.db import DB, JSONField from apps.webui.internal.db import DB, JSONField
from apps.webui.models.users import Users
import json import json
import copy
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -25,6 +28,8 @@ class Function(Model): ...@@ -25,6 +28,8 @@ class Function(Model):
type = TextField() type = TextField()
content = TextField() content = TextField()
meta = JSONField() meta = JSONField()
valves = JSONField()
is_active = BooleanField(default=False)
updated_at = BigIntegerField() updated_at = BigIntegerField()
created_at = BigIntegerField() created_at = BigIntegerField()
...@@ -34,6 +39,7 @@ class Function(Model): ...@@ -34,6 +39,7 @@ class Function(Model):
class FunctionMeta(BaseModel): class FunctionMeta(BaseModel):
description: Optional[str] = None description: Optional[str] = None
manifest: Optional[dict] = {}
class FunctionModel(BaseModel): class FunctionModel(BaseModel):
...@@ -43,6 +49,7 @@ class FunctionModel(BaseModel): ...@@ -43,6 +49,7 @@ class FunctionModel(BaseModel):
type: str type: str
content: str content: str
meta: FunctionMeta meta: FunctionMeta
is_active: bool = False
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
...@@ -58,6 +65,7 @@ class FunctionResponse(BaseModel): ...@@ -58,6 +65,7 @@ class FunctionResponse(BaseModel):
type: str type: str
name: str name: str
meta: FunctionMeta meta: FunctionMeta
is_active: bool
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
...@@ -69,6 +77,10 @@ class FunctionForm(BaseModel): ...@@ -69,6 +77,10 @@ class FunctionForm(BaseModel):
meta: FunctionMeta meta: FunctionMeta
class FunctionValves(BaseModel):
valves: Optional[dict] = None
class FunctionsTable: class FunctionsTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
...@@ -104,17 +116,99 @@ class FunctionsTable: ...@@ -104,17 +116,99 @@ class FunctionsTable:
except: except:
return None return None
def get_functions(self) -> List[FunctionModel]: def get_functions(self, active_only=False) -> List[FunctionModel]:
if active_only:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(Function.is_active == True)
]
else:
return [ return [
FunctionModel(**model_to_dict(function)) for function in Function.select() FunctionModel(**model_to_dict(function))
for function in Function.select()
] ]
def get_functions_by_type(self, type: str) -> List[FunctionModel]: def get_functions_by_type(
self, type: str, active_only=False
) -> List[FunctionModel]:
if active_only:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(
Function.type == type, Function.is_active == True
)
]
else:
return [ return [
FunctionModel(**model_to_dict(function)) FunctionModel(**model_to_dict(function))
for function in Function.select().where(Function.type == type) for function in Function.select().where(Function.type == type)
] ]
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
try:
function = Function.get(Function.id == id)
return function.valves if function.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_function_valves_by_id(
self, id: str, valves: dict
) -> Optional[FunctionValves]:
try:
query = Function.update(
**{"valves": valves},
updated_at=int(time.time()),
).where(Function.id == id)
query.execute()
function = Function.get(Function.id == id)
return FunctionValves(**model_to_dict(function))
except:
return None
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
# Check if user has "functions" and "valves" settings
if "functions" not in user_settings:
user_settings["functions"] = {}
if "valves" not in user_settings["functions"]:
user_settings["functions"]["valves"] = {}
return user_settings["functions"]["valves"].get(id, {})
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
# Check if user has "functions" and "valves" settings
if "functions" not in user_settings:
user_settings["functions"] = {}
if "valves" not in user_settings["functions"]:
user_settings["functions"]["valves"] = {}
user_settings["functions"]["valves"][id] = valves
# Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
return user_settings["functions"]["valves"][id]
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
try: try:
query = Function.update( query = Function.update(
...@@ -128,6 +222,19 @@ class FunctionsTable: ...@@ -128,6 +222,19 @@ class FunctionsTable:
except: except:
return None return None
def deactivate_all_functions(self) -> Optional[bool]:
try:
query = Function.update(
**{"is_active": False},
updated_at=int(time.time()),
)
query.execute()
return True
except:
return None
def delete_function_by_id(self, id: str) -> bool: def delete_function_by_id(self, id: str) -> bool:
try: try:
query = Function.delete().where((Function.id == id)) query = Function.delete().where((Function.id == id))
......
...@@ -5,8 +5,11 @@ from typing import List, Union, Optional ...@@ -5,8 +5,11 @@ from typing import List, Union, Optional
import time import time
import logging import logging
from apps.webui.internal.db import DB, JSONField from apps.webui.internal.db import DB, JSONField
from apps.webui.models.users import Users
import json import json
import copy
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -25,6 +28,7 @@ class Tool(Model): ...@@ -25,6 +28,7 @@ class Tool(Model):
content = TextField() content = TextField()
specs = JSONField() specs = JSONField()
meta = JSONField() meta = JSONField()
valves = JSONField()
updated_at = BigIntegerField() updated_at = BigIntegerField()
created_at = BigIntegerField() created_at = BigIntegerField()
...@@ -34,6 +38,7 @@ class Tool(Model): ...@@ -34,6 +38,7 @@ class Tool(Model):
class ToolMeta(BaseModel): class ToolMeta(BaseModel):
description: Optional[str] = None description: Optional[str] = None
manifest: Optional[dict] = {}
class ToolModel(BaseModel): class ToolModel(BaseModel):
...@@ -68,6 +73,10 @@ class ToolForm(BaseModel): ...@@ -68,6 +73,10 @@ class ToolForm(BaseModel):
meta: ToolMeta meta: ToolMeta
class ToolValves(BaseModel):
valves: Optional[dict] = None
class ToolsTable: class ToolsTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
...@@ -106,6 +115,69 @@ class ToolsTable: ...@@ -106,6 +115,69 @@ class ToolsTable:
def get_tools(self) -> List[ToolModel]: def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()] return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try:
tool = Tool.get(Tool.id == id)
return tool.valves if tool.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try:
query = Tool.update(
**{"valves": valves},
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolValves(**model_to_dict(tool))
except:
return None
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
# Check if user has "tools" and "valves" settings
if "tools" not in user_settings:
user_settings["tools"] = {}
if "valves" not in user_settings["tools"]:
user_settings["tools"]["valves"] = {}
return user_settings["tools"]["valves"].get(id, {})
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
# Check if user has "tools" and "valves" settings
if "tools" not in user_settings:
user_settings["tools"] = {}
if "valves" not in user_settings["tools"]:
user_settings["tools"]["valves"] = {}
user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
return user_settings["tools"]["valves"][id]
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try: try:
query = Tool.update( query = Tool.update(
......
...@@ -194,6 +194,29 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): ...@@ -194,6 +194,29 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
) )
@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file:
file_path = Path(file.meta["path"])
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
return FileResponse(file_path)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################ ############################
# Delete File By Id # Delete File By Id
############################ ############################
......
...@@ -69,7 +69,10 @@ async def create_new_function( ...@@ -69,7 +69,10 @@ async def create_new_function(
with open(function_path, "w") as function_file: with open(function_path, "w") as function_file:
function_file.write(form_data.content) function_file.write(form_data.content)
function_module, function_type = load_function_module_by_id(form_data.id) function_module, function_type, frontmatter = load_function_module_by_id(
form_data.id
)
form_data.meta.manifest = frontmatter
FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module FUNCTIONS[form_data.id] = function_module
...@@ -117,13 +120,40 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)): ...@@ -117,13 +120,40 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
) )
############################
# ToggleFunctionById
############################
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
if function:
function = Functions.update_function_by_id(
id, {"is_active": not function.is_active}
)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################ ############################
# UpdateFunctionById # UpdateFunctionById
############################ ############################
@router.post("/id/{id}/update", response_model=Optional[FunctionModel]) @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
async def update_toolkit_by_id( async def update_function_by_id(
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user) request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
): ):
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
...@@ -132,7 +162,8 @@ async def update_toolkit_by_id( ...@@ -132,7 +162,8 @@ async def update_toolkit_by_id(
with open(function_path, "w") as function_file: with open(function_path, "w") as function_file:
function_file.write(form_data.content) function_file.write(form_data.content)
function_module, function_type = load_function_module_by_id(id) function_module, function_type, frontmatter = load_function_module_by_id(id)
form_data.meta.manifest = frontmatter
FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[id] = function_module FUNCTIONS[id] = function_module
...@@ -178,3 +209,188 @@ async def delete_function_by_id( ...@@ -178,3 +209,188 @@ async def delete_function_by_id(
os.remove(function_path) os.remove(function_path)
return result return result
############################
# GetFunctionValves
############################
@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
if function:
try:
valves = Functions.get_function_valves_by_id(id)
return valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# GetFunctionValvesSpec
############################
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_function_valves_spec_by_id(
request: Request, id: str, user=Depends(get_admin_user)
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "Valves"):
Valves = function_module.Valves
return Valves.schema()
return None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateFunctionValves
############################
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_function_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "Valves"):
Valves = function_module.Valves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
Functions.update_function_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# FunctionUserValves
############################
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
function = Functions.get_function_by_id(id)
if function:
try:
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_function_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves
return UserValves.schema()
return None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_function_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data)
Functions.update_user_valves_by_id_and_user_id(
id, user.id, user_valves.model_dump()
)
return user_valves.model_dump()
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
...@@ -6,10 +6,12 @@ from fastapi import APIRouter ...@@ -6,10 +6,12 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_current_user, get_admin_user from utils.utils import get_admin_user, get_verified_user
from utils.tools import get_tools_specs from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -32,7 +34,7 @@ router = APIRouter() ...@@ -32,7 +34,7 @@ router = APIRouter()
@router.get("/", response_model=List[ToolResponse]) @router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_current_user)): async def get_toolkits(user=Depends(get_verified_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()] toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits return toolkits
...@@ -72,7 +74,8 @@ async def create_new_toolkit( ...@@ -72,7 +74,8 @@ async def create_new_toolkit(
with open(toolkit_path, "w") as tool_file: with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content) tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(form_data.id) toolkit_module, frontmatter = load_toolkit_module_by_id(form_data.id)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module TOOLS[form_data.id] = toolkit_module
...@@ -136,7 +139,8 @@ async def update_toolkit_by_id( ...@@ -136,7 +139,8 @@ async def update_toolkit_by_id(
with open(toolkit_path, "w") as tool_file: with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content) tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(id) toolkit_module, frontmatter = load_toolkit_module_by_id(id)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module TOOLS[id] = toolkit_module
...@@ -185,3 +189,187 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin ...@@ -185,3 +189,187 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin
os.remove(toolkit_path) os.remove(toolkit_path)
return result return result
############################
# GetToolValves
############################
@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
try:
valves = Tools.get_tool_valves_by_id(id)
return valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# GetToolValvesSpec
############################
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_toolkit_valves_spec_by_id(
request: Request, id: str, user=Depends(get_admin_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "Valves"):
Valves = toolkit_module.Valves
return Valves.schema()
return None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateToolValves
############################
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_toolkit_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "Valves"):
Valves = toolkit_module.Valves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
Tools.update_tool_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# ToolUserValves
############################
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
try:
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_toolkit_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "UserValves"):
UserValves = toolkit_module.UserValves
return UserValves.schema()
return None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_toolkit_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "UserValves"):
UserValves = toolkit_module.UserValves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data)
Tools.update_user_valves_by_id_and_user_id(
id, user.id, user_valves.model_dump()
)
return user_valves.model_dump()
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
from importlib import util from importlib import util
import os import os
import re
from config import TOOLS_DIR, FUNCTIONS_DIR from config import TOOLS_DIR, FUNCTIONS_DIR
def extract_frontmatter(file_path):
"""
Extract frontmatter as a dictionary from the specified file path.
"""
frontmatter = {}
frontmatter_started = False
frontmatter_ended = False
frontmatter_pattern = re.compile(r"^\s*([a-z_]+):\s*(.*)\s*$", re.IGNORECASE)
try:
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
if '"""' in line:
if not frontmatter_started:
frontmatter_started = True
continue # skip the line with the opening triple quotes
else:
frontmatter_ended = True
break
if frontmatter_started and not frontmatter_ended:
match = frontmatter_pattern.match(line)
if match:
key, value = match.groups()
frontmatter[key.strip()] = value.strip()
except FileNotFoundError:
print(f"Error: The file {file_path} does not exist.")
return {}
except Exception as e:
print(f"An error occurred: {e}")
return {}
return frontmatter
def load_toolkit_module_by_id(toolkit_id): def load_toolkit_module_by_id(toolkit_id):
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py") toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
spec = util.spec_from_file_location(toolkit_id, toolkit_path) spec = util.spec_from_file_location(toolkit_id, toolkit_path)
module = util.module_from_spec(spec) module = util.module_from_spec(spec)
frontmatter = extract_frontmatter(toolkit_path)
try: try:
spec.loader.exec_module(module) spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}") print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"): if hasattr(module, "Tools"):
return module.Tools() return module.Tools(), frontmatter
else: else:
raise Exception("No Tools class found") raise Exception("No Tools class found")
except Exception as e: except Exception as e:
...@@ -28,14 +65,15 @@ def load_function_module_by_id(function_id): ...@@ -28,14 +65,15 @@ def load_function_module_by_id(function_id):
spec = util.spec_from_file_location(function_id, function_path) spec = util.spec_from_file_location(function_id, function_path)
module = util.module_from_spec(spec) module = util.module_from_spec(spec)
frontmatter = extract_frontmatter(function_path)
try: try:
spec.loader.exec_module(module) spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}") print(f"Loaded module: {module.__name__}")
if hasattr(module, "Pipe"): if hasattr(module, "Pipe"):
return module.Pipe(), "pipe" return module.Pipe(), "pipe", frontmatter
elif hasattr(module, "Filter"): elif hasattr(module, "Filter"):
return module.Filter(), "filter" return module.Filter(), "filter", frontmatter
else: else:
raise Exception("No Function class found") raise Exception("No Function class found")
except Exception as e: except Exception as e:
......
...@@ -167,6 +167,12 @@ for version in soup.find_all("h2"): ...@@ -167,6 +167,12 @@ for version in soup.find_all("h2"):
CHANGELOG = changelog_json CHANGELOG = changelog_json
####################################
# SAFE_MODE
####################################
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
#################################### ####################################
# WEBUI_BUILD_HASH # WEBUI_BUILD_HASH
#################################### ####################################
......
...@@ -62,9 +62,7 @@ from apps.webui.models.functions import Functions ...@@ -62,9 +62,7 @@ from apps.webui.models.functions import Functions
from apps.webui.models.users import Users from apps.webui.models.users import Users
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
from apps.webui.utils import load_toolkit_module_by_id
from utils.misc import parse_duration
from utils.utils import ( from utils.utils import (
get_admin_user, get_admin_user,
get_verified_user, get_verified_user,
...@@ -82,6 +80,7 @@ from utils.misc import ( ...@@ -82,6 +80,7 @@ from utils.misc import (
get_last_user_message, get_last_user_message,
add_or_update_system_message, add_or_update_system_message,
stream_message_template, stream_message_template,
parse_duration,
) )
from apps.rag.utils import get_rag_context, rag_template from apps.rag.utils import get_rag_context, rag_template
...@@ -113,6 +112,7 @@ from config import ( ...@@ -113,6 +112,7 @@ from config import (
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
SAFE_MODE,
OAUTH_PROVIDERS, OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP, ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL, OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
...@@ -124,6 +124,11 @@ from config import ( ...@@ -124,6 +124,11 @@ from config import (
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from utils.webhook import post_webhook from utils.webhook import post_webhook
if SAFE_MODE:
print("SAFE MODE ENABLED")
Functions.deactivate_all_functions()
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"]) log.setLevel(SRC_LOG_LEVELS["MAIN"])
...@@ -271,7 +276,7 @@ async def get_function_call_response( ...@@ -271,7 +276,7 @@ async def get_function_call_response(
if tool_id in webui_app.state.TOOLS: if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id] toolkit_module = webui_app.state.TOOLS[tool_id]
else: else:
toolkit_module = load_toolkit_module_by_id(tool_id) toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False file_handler = False
...@@ -280,6 +285,14 @@ async def get_function_call_response( ...@@ -280,6 +285,14 @@ async def get_function_call_response(
file_handler = True file_handler = True
print("file_handler: ", file_handler) print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(
toolkit_module, "Valves"
):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(
**(valves if valves else {})
)
function = getattr(toolkit_module, result["name"]) function = getattr(toolkit_module, result["name"])
function_result = None function_result = None
try: try:
...@@ -289,16 +302,24 @@ async def get_function_call_response( ...@@ -289,16 +302,24 @@ async def get_function_call_response(
if "__user__" in sig.parameters: if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included # Call the function with the '__user__' parameter included
params = { __user__ = {
**params,
"__user__": {
"id": user.id, "id": user.id,
"email": user.email, "email": user.email,
"name": user.name, "name": user.name,
"role": user.role, "role": user.role,
},
} }
try:
if hasattr(toolkit_module, "UserValves"):
__user__["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(
tool_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__messages__" in sig.parameters: if "__messages__" in sig.parameters:
# Call the function with the '__messages__' parameter included # Call the function with the '__messages__' parameter included
params = { params = {
...@@ -386,16 +407,34 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -386,16 +407,34 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
) )
model = app.state.MODELS[model_id] model = app.state.MODELS[model_id]
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get(
"priority", 0
)
return 0
filter_ids = [
function.id
for function in Functions.get_functions_by_type(
"filter", active_only=True
)
]
# Check if the model has any filters # Check if the model has any filters
if "info" in model and "meta" in model["info"]: if "info" in model and "meta" in model["info"]:
for filter_id in model["info"]["meta"].get("filterIds", []): filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id) filter = Functions.get_function_by_id(filter_id)
if filter: if filter:
if filter_id in webui_app.state.FUNCTIONS: if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id] function_module = webui_app.state.FUNCTIONS[filter_id]
else: else:
function_module, function_type = load_function_module_by_id( function_module, function_type, frontmatter = (
filter_id load_function_module_by_id(filter_id)
) )
webui_app.state.FUNCTIONS[filter_id] = function_module webui_app.state.FUNCTIONS[filter_id] = function_module
...@@ -403,30 +442,52 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -403,30 +442,52 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if hasattr(function_module, "file_handler"): if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try: try:
if hasattr(function_module, "inlet"): if hasattr(function_module, "inlet"):
inlet = function_module.inlet inlet = function_module.inlet
if inspect.iscoroutinefunction(inlet): # Get the signature of the function
data = await inlet( sig = inspect.signature(inlet)
data, params = {"body": data}
{
if "__user__" in sig.parameters:
__user__ = {
"id": user.id, "id": user.id,
"email": user.email, "email": user.email,
"name": user.name, "name": user.name,
"role": user.role, "role": user.role,
}, }
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
) )
else:
data = inlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
) )
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if inspect.iscoroutinefunction(inlet):
data = await inlet(**params)
else:
data = inlet(**params)
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
...@@ -857,12 +918,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u ...@@ -857,12 +918,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
pipe = model.get("pipe") pipe = model.get("pipe")
if pipe: if pipe:
form_data["user"] = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
async def job(): async def job():
pipe_id = form_data["model"] pipe_id = form_data["model"]
...@@ -870,14 +925,62 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u ...@@ -870,14 +925,62 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
pipe_id, sub_pipe_id = pipe_id.split(".", 1) pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id) print(pipe_id)
pipe = webui_app.state.FUNCTIONS[pipe_id].pipe # Check if function is already loaded
if pipe_id not in webui_app.state.FUNCTIONS:
function_module, function_type, frontmatter = (
load_function_module_by_id(pipe_id)
)
webui_app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = webui_app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
pipe = function_module.pipe
# Get the signature of the function
sig = inspect.signature(pipe)
params = {"body": form_data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
pipe_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if form_data["stream"]: if form_data["stream"]:
async def stream_content(): async def stream_content():
try:
if inspect.iscoroutinefunction(pipe): if inspect.iscoroutinefunction(pipe):
res = await pipe(body=form_data) res = await pipe(**params)
else: else:
res = pipe(body=form_data) res = pipe(**params)
except Exception as e:
print(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str): if isinstance(res, str):
message = stream_message_template(form_data["model"], res) message = stream_message_template(form_data["model"], res)
...@@ -922,10 +1025,20 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u ...@@ -922,10 +1025,20 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
stream_content(), media_type="text/event-stream" stream_content(), media_type="text/event-stream"
) )
else: else:
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
except Exception as e:
print(f"Error: {e}")
return {"error": {"detail": str(e)}}
if inspect.iscoroutinefunction(pipe): if inspect.iscoroutinefunction(pipe):
res = await pipe(body=form_data) res = await pipe(**params)
else: else:
res = pipe(body=form_data) res = pipe(**params)
if isinstance(res, dict): if isinstance(res, dict):
return res return res
...@@ -1008,7 +1121,12 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -1008,7 +1121,12 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
f"{url}/{filter['id']}/filter/outlet", f"{url}/{filter['id']}/filter/outlet",
headers=headers, headers=headers,
json={ json={
"user": {"id": user.id, "name": user.name, "role": user.role}, "user": {
"id": user.id,
"name": user.name,
"email": user.email,
"role": user.role,
},
"body": data, "body": data,
}, },
) )
...@@ -1033,42 +1151,81 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -1033,42 +1151,81 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
else: else:
pass pass
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
# Check if the model has any filters # Check if the model has any filters
if "info" in model and "meta" in model["info"]: if "info" in model and "meta" in model["info"]:
for filter_id in model["info"]["meta"].get("filterIds", []): filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
# Sort filter_ids by priority, using the get_priority function
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id) filter = Functions.get_function_by_id(filter_id)
if filter: if filter:
if filter_id in webui_app.state.FUNCTIONS: if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id] function_module = webui_app.state.FUNCTIONS[filter_id]
else: else:
function_module, function_type = load_function_module_by_id( function_module, function_type, frontmatter = (
filter_id load_function_module_by_id(filter_id)
) )
webui_app.state.FUNCTIONS[filter_id] = function_module webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try: try:
if hasattr(function_module, "outlet"): if hasattr(function_module, "outlet"):
outlet = function_module.outlet outlet = function_module.outlet
if inspect.iscoroutinefunction(outlet):
data = await outlet( # Get the signature of the function
data, sig = inspect.signature(outlet)
{ params = {"body": data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id, "id": user.id,
"email": user.email, "email": user.email,
"name": user.name, "name": user.name,
"role": user.role, "role": user.role,
}, }
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
) )
else:
data = outlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
) )
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
...@@ -1989,7 +2146,6 @@ async def get_manifest_json(): ...@@ -1989,7 +2146,6 @@ async def get_manifest_json():
"start_url": "/", "start_url": "/",
"display": "standalone", "display": "standalone",
"background_color": "#343541", "background_color": "#343541",
"theme_color": "#343541",
"orientation": "portrait-primary", "orientation": "portrait-primary",
"icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}], "icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
} }
......
...@@ -17,7 +17,9 @@ peewee-migrate==1.12.2 ...@@ -17,7 +17,9 @@ peewee-migrate==1.12.2
psycopg2-binary==2.9.9 psycopg2-binary==2.9.9
PyMySQL==1.1.1 PyMySQL==1.1.1
bcrypt==4.1.3 bcrypt==4.1.3
SQLAlchemy
pymongo
redis
boto3==1.34.110 boto3==1.34.110
argon2-cffi==23.1.0 argon2-cffi==23.1.0
......
...@@ -20,7 +20,9 @@ def get_tools_specs(tools) -> List[dict]: ...@@ -20,7 +20,9 @@ def get_tools_specs(tools) -> List[dict]:
function_list = [ function_list = [
{"name": func, "function": getattr(tools, func)} {"name": func, "function": getattr(tools, func)}
for func in dir(tools) for func in dir(tools)
if callable(getattr(tools, func)) and not func.startswith("__") if callable(getattr(tools, func))
and not func.startswith("__")
and not inspect.isclass(getattr(tools, func))
] ]
specs = [] specs = []
...@@ -65,6 +67,7 @@ def get_tools_specs(tools) -> List[dict]: ...@@ -65,6 +67,7 @@ def get_tools_specs(tools) -> List[dict]:
function function
).parameters.items() ).parameters.items()
if param.default is param.empty if param.default is param.empty
and not (name.startswith("__") and name.endswith("__"))
], ],
}, },
} }
......
...@@ -32,6 +32,10 @@ math { ...@@ -32,6 +32,10 @@ math {
@apply underline; @apply underline;
} }
iframe {
@apply rounded-lg;
}
ol > li { ol > li {
counter-increment: list-number; counter-increment: list-number;
display: block; display: block;
......
...@@ -191,3 +191,233 @@ export const deleteFunctionById = async (token: string, id: string) => { ...@@ -191,3 +191,233 @@ export const deleteFunctionById = async (token: string, id: string) => {
return res; return res;
}; };
export const toggleFunctionById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/toggle`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctionValvesById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctionValvesSpecById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFunctionValvesById = async (token: string, id: string, valves: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...valves
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserValvesById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/user`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserValvesSpecById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/user/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateUserValvesById = async (token: string, id: string, valves: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/user/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...valves
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
...@@ -191,3 +191,201 @@ export const deleteToolById = async (token: string, id: string) => { ...@@ -191,3 +191,201 @@ export const deleteToolById = async (token: string, id: string) => {
return res; return res;
}; };
export const getToolValvesById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getToolValvesSpecById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateToolValvesById = async (token: string, id: string, valves: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...valves
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserValvesById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/user`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserValvesSpecById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/user/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateUserValvesById = async (token: string, id: string, valves: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/user/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...valves
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
'serper', 'serper',
'serply', 'serply',
'duckduckgo', 'duckduckgo',
'tavily' 'tavily',
'jina'
]; ];
let youtubeLanguage = 'en'; let youtubeLanguage = 'en';
......
<script>
import { getContext } from 'svelte';
import Modal from '../common/Modal.svelte';
import Database from './Settings/Database.svelte';
import General from './Settings/General.svelte';
import Users from './Settings/Users.svelte';
import Banners from '$lib/components/admin/Settings/Banners.svelte';
import { toast } from 'svelte-sonner';
import Pipelines from './Settings/Pipelines.svelte';
const i18n = getContext('i18n');
export let show = false;
let selectedTab = 'general';
</script>
<Modal bind:show>
<div>
<div class=" flex justify-between dark:text-gray-300 px-5 pt-4 pb-2">
<div class=" text-lg font-medium self-center">{$i18n.t('Admin Settings')}</div>
<button
class="self-center"
on:click={() => {
show = false;
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-5 h-5"
>
<path
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button>
</div>
</div>
</Modal>
...@@ -127,6 +127,42 @@ ...@@ -127,6 +127,42 @@
} }
onMount(async () => { onMount(async () => {
const onMessageHandler = async (event) => {
if (event.origin === window.origin) {
// Replace with your iframe's origin
console.log('Message received from iframe:', event.data);
if (event.data.type === 'input:prompt') {
console.log(event.data.text);
const inputElement = document.getElementById('chat-textarea');
if (inputElement) {
prompt = event.data.text;
inputElement.focus();
}
}
if (event.data.type === 'action:submit') {
console.log(event.data.text);
if (prompt !== '') {
await tick();
submitPrompt(prompt);
}
}
if (event.data.type === 'input:prompt:submit') {
console.log(event.data.text);
if (prompt !== '') {
await tick();
submitPrompt(event.data.text);
}
}
}
};
window.addEventListener('message', onMessageHandler);
if (!$chatId) { if (!$chatId) {
chatId.subscribe(async (value) => { chatId.subscribe(async (value) => {
if (!value) { if (!value) {
...@@ -138,6 +174,10 @@ ...@@ -138,6 +174,10 @@
await goto('/'); await goto('/');
} }
} }
return () => {
window.removeEventListener('message', onMessageHandler);
};
}); });
////////////////////////// //////////////////////////
...@@ -600,10 +640,14 @@ ...@@ -600,10 +640,14 @@
files = model.info.meta.knowledge; files = model.info.meta.knowledge;
} }
const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1); const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
files = [ files = [
...files, ...files,
...(lastUserMessage?.files?.filter((item) => ...(lastUserMessage?.files?.filter((item) =>
['doc', 'file', 'collection', 'web_search_results'].includes(item.type) ['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
) ?? []),
...(responseMessage?.files?.filter((item) =>
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
) ?? []) ) ?? [])
].filter( ].filter(
// Remove duplicates // Remove duplicates
...@@ -844,6 +888,9 @@ ...@@ -844,6 +888,9 @@
...files, ...files,
...(lastUserMessage?.files?.filter((item) => ...(lastUserMessage?.files?.filter((item) =>
['doc', 'file', 'collection', 'web_search_results'].includes(item.type) ['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
) ?? []),
...(responseMessage?.files?.filter((item) =>
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
) ?? []) ) ?? [])
].filter( ].filter(
// Remove duplicates // Remove duplicates
...@@ -1213,6 +1260,7 @@ ...@@ -1213,6 +1260,7 @@
const getWebSearchResults = async (model: string, parentId: string, responseId: string) => { const getWebSearchResults = async (model: string, parentId: string, responseId: string) => {
const responseMessage = history.messages[responseId]; const responseMessage = history.messages[responseId];
const userMessage = history.messages[parentId];
responseMessage.statusHistory = [ responseMessage.statusHistory = [
{ {
...@@ -1223,7 +1271,7 @@ ...@@ -1223,7 +1271,7 @@
]; ];
messages = messages; messages = messages;
const prompt = history.messages[parentId].content; const prompt = userMessage.content;
let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch( let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch(
(error) => { (error) => {
console.log(error); console.log(error);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment