Commit f26d80dc authored by Jun Siang Cheah's avatar Jun Siang Cheah
Browse files

Merge remote-tracking branch 'upstream/dev' into feat/oauth

parents 99e7b328 f54a66b8
......@@ -77,6 +77,7 @@ from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.serply import search_serply
from apps.rag.search.duckduckgo import search_duckduckgo
from apps.rag.search.tavily import search_tavily
from apps.rag.search.jina_search import search_jina
from utils.misc import (
calculate_sha256,
......@@ -856,6 +857,8 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
elif engine == "jina":
return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
else:
raise Exception("No search engine API key found in environment variables")
......
import logging
import requests
from yarl import URL
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_jina(query: str, count: int) -> list[SearchResult]:
"""
Search using Jina's Search API and return the results as a list of SearchResult objects.
Args:
query (str): The query to search for
count (int): The number of results to return
Returns:
List[SearchResult]: A list of search results
"""
jina_search_endpoint = "https://s.jina.ai/"
headers = {
"Accept": "application/json",
}
url = str(URL(jina_search_endpoint + query))
response = requests.get(url, headers=headers)
response.raise_for_status()
data = response.json()
results = []
for result in data["data"][:count]:
results.append(
SearchResult(
link=result["url"],
title=result.get("title"),
snippet=result.get("content"),
)
)
return results
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields("tool", valves=pw.TextField(null=True))
migrator.add_fields("function", valves=pw.TextField(null=True))
migrator.add_fields("function", is_active=pw.BooleanField(default=False))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("tool", "valves")
migrator.remove_fields("function", "valves")
migrator.remove_fields("function", "is_active")
......@@ -105,13 +105,15 @@ async def get_status():
async def get_pipe_models():
pipes = Functions.get_functions_by_type("pipe")
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = []
for pipe in pipes:
# Check if function is already loaded
if pipe.id not in app.state.FUNCTIONS:
function_module, function_type = load_function_module_by_id(pipe.id)
function_module, function_type, frontmatter = load_function_module_by_id(
pipe.id
)
app.state.FUNCTIONS[pipe.id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe.id]
......@@ -132,7 +134,9 @@ async def get_pipe_models():
manifold_pipe_name = p["name"]
if hasattr(function_module, "name"):
manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
manifold_pipe_name = (
f"{function_module.name}{manifold_pipe_name}"
)
pipe_models.append(
{
......
......@@ -5,8 +5,11 @@ from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
from apps.webui.models.users import Users
import json
import copy
from config import SRC_LOG_LEVELS
......@@ -25,6 +28,8 @@ class Function(Model):
type = TextField()
content = TextField()
meta = JSONField()
valves = JSONField()
is_active = BooleanField(default=False)
updated_at = BigIntegerField()
created_at = BigIntegerField()
......@@ -34,6 +39,7 @@ class Function(Model):
class FunctionMeta(BaseModel):
description: Optional[str] = None
manifest: Optional[dict] = {}
class FunctionModel(BaseModel):
......@@ -43,6 +49,7 @@ class FunctionModel(BaseModel):
type: str
content: str
meta: FunctionMeta
is_active: bool = False
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
......@@ -58,6 +65,7 @@ class FunctionResponse(BaseModel):
type: str
name: str
meta: FunctionMeta
is_active: bool
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
......@@ -69,6 +77,10 @@ class FunctionForm(BaseModel):
meta: FunctionMeta
class FunctionValves(BaseModel):
valves: Optional[dict] = None
class FunctionsTable:
def __init__(self, db):
self.db = db
......@@ -104,16 +116,98 @@ class FunctionsTable:
except:
return None
def get_functions(self) -> List[FunctionModel]:
return [
FunctionModel(**model_to_dict(function)) for function in Function.select()
]
def get_functions(self, active_only=False) -> List[FunctionModel]:
if active_only:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(Function.is_active == True)
]
else:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select()
]
def get_functions_by_type(
self, type: str, active_only=False
) -> List[FunctionModel]:
if active_only:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(
Function.type == type, Function.is_active == True
)
]
else:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(Function.type == type)
]
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
try:
function = Function.get(Function.id == id)
return function.valves if function.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
return None
def get_functions_by_type(self, type: str) -> List[FunctionModel]:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(Function.type == type)
]
def update_function_valves_by_id(
self, id: str, valves: dict
) -> Optional[FunctionValves]:
try:
query = Function.update(
**{"valves": valves},
updated_at=int(time.time()),
).where(Function.id == id)
query.execute()
function = Function.get(Function.id == id)
return FunctionValves(**model_to_dict(function))
except:
return None
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
# Check if user has "functions" and "valves" settings
if "functions" not in user_settings:
user_settings["functions"] = {}
if "valves" not in user_settings["functions"]:
user_settings["functions"]["valves"] = {}
return user_settings["functions"]["valves"].get(id, {})
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
# Check if user has "functions" and "valves" settings
if "functions" not in user_settings:
user_settings["functions"] = {}
if "valves" not in user_settings["functions"]:
user_settings["functions"]["valves"] = {}
user_settings["functions"]["valves"][id] = valves
# Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
return user_settings["functions"]["valves"][id]
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
try:
......@@ -128,6 +222,19 @@ class FunctionsTable:
except:
return None
def deactivate_all_functions(self) -> Optional[bool]:
try:
query = Function.update(
**{"is_active": False},
updated_at=int(time.time()),
)
query.execute()
return True
except:
return None
def delete_function_by_id(self, id: str) -> bool:
try:
query = Function.delete().where((Function.id == id))
......
......@@ -5,8 +5,11 @@ from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
from apps.webui.models.users import Users
import json
import copy
from config import SRC_LOG_LEVELS
......@@ -25,6 +28,7 @@ class Tool(Model):
content = TextField()
specs = JSONField()
meta = JSONField()
valves = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
......@@ -34,6 +38,7 @@ class Tool(Model):
class ToolMeta(BaseModel):
description: Optional[str] = None
manifest: Optional[dict] = {}
class ToolModel(BaseModel):
......@@ -68,6 +73,10 @@ class ToolForm(BaseModel):
meta: ToolMeta
class ToolValves(BaseModel):
valves: Optional[dict] = None
class ToolsTable:
def __init__(self, db):
self.db = db
......@@ -106,6 +115,69 @@ class ToolsTable:
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try:
tool = Tool.get(Tool.id == id)
return tool.valves if tool.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try:
query = Tool.update(
**{"valves": valves},
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolValves(**model_to_dict(tool))
except:
return None
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
# Check if user has "tools" and "valves" settings
if "tools" not in user_settings:
user_settings["tools"] = {}
if "valves" not in user_settings["tools"]:
user_settings["tools"]["valves"] = {}
return user_settings["tools"]["valves"].get(id, {})
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
# Check if user has "tools" and "valves" settings
if "tools" not in user_settings:
user_settings["tools"] = {}
if "valves" not in user_settings["tools"]:
user_settings["tools"]["valves"] = {}
user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
return user_settings["tools"]["valves"][id]
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
......
......@@ -194,6 +194,29 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
)
@router.get("/{id}/content/{file_name}", response_model=Optional[FileModel])
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file:
file_path = Path(file.meta["path"])
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
return FileResponse(file_path)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# Delete File By Id
############################
......
......@@ -69,7 +69,10 @@ async def create_new_function(
with open(function_path, "w") as function_file:
function_file.write(form_data.content)
function_module, function_type = load_function_module_by_id(form_data.id)
function_module, function_type, frontmatter = load_function_module_by_id(
form_data.id
)
form_data.meta.manifest = frontmatter
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module
......@@ -117,13 +120,40 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user)):
)
############################
# ToggleFunctionById
############################
@router.post("/id/{id}/toggle", response_model=Optional[FunctionModel])
async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
if function:
function = Functions.update_function_by_id(
id, {"is_active": not function.is_active}
)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateFunctionById
############################
@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
async def update_toolkit_by_id(
async def update_function_by_id(
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
):
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
......@@ -132,7 +162,8 @@ async def update_toolkit_by_id(
with open(function_path, "w") as function_file:
function_file.write(form_data.content)
function_module, function_type = load_function_module_by_id(id)
function_module, function_type, frontmatter = load_function_module_by_id(id)
form_data.meta.manifest = frontmatter
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[id] = function_module
......@@ -178,3 +209,188 @@ async def delete_function_by_id(
os.remove(function_path)
return result
############################
# GetFunctionValves
############################
@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
if function:
try:
valves = Functions.get_function_valves_by_id(id)
return valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# GetFunctionValvesSpec
############################
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_function_valves_spec_by_id(
request: Request, id: str, user=Depends(get_admin_user)
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "Valves"):
Valves = function_module.Valves
return Valves.schema()
return None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateFunctionValves
############################
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_function_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "Valves"):
Valves = function_module.Valves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
Functions.update_function_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# FunctionUserValves
############################
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user)):
function = Functions.get_function_by_id(id)
if function:
try:
user_valves = Functions.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_function_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves
return UserValves.schema()
return None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_function_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
function = Functions.get_function_by_id(id)
if function:
if id in request.app.state.FUNCTIONS:
function_module = request.app.state.FUNCTIONS[id]
else:
function_module, function_type, frontmatter = load_function_module_by_id(id)
request.app.state.FUNCTIONS[id] = function_module
if hasattr(function_module, "UserValves"):
UserValves = function_module.UserValves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data)
Functions.update_user_valves_by_id_and_user_id(
id, user.id, user_valves.model_dump()
)
return user_valves.model_dump()
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
......@@ -6,10 +6,12 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_current_user, get_admin_user
from utils.utils import get_admin_user, get_verified_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
......@@ -32,7 +34,7 @@ router = APIRouter()
@router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_current_user)):
async def get_toolkits(user=Depends(get_verified_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
......@@ -72,7 +74,8 @@ async def create_new_toolkit(
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(form_data.id)
toolkit_module, frontmatter = load_toolkit_module_by_id(form_data.id)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module
......@@ -136,7 +139,8 @@ async def update_toolkit_by_id(
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(id)
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
form_data.meta.manifest = frontmatter
TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module
......@@ -185,3 +189,187 @@ async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin
os.remove(toolkit_path)
return result
############################
# GetToolValves
############################
@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
try:
valves = Tools.get_tool_valves_by_id(id)
return valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# GetToolValvesSpec
############################
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_toolkit_valves_spec_by_id(
request: Request, id: str, user=Depends(get_admin_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "Valves"):
Valves = toolkit_module.Valves
return Valves.schema()
return None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateToolValves
############################
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_toolkit_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_admin_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "Valves"):
Valves = toolkit_module.Valves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
valves = Valves(**form_data)
Tools.update_tool_valves_by_id(id, valves.model_dump())
return valves.model_dump()
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# ToolUserValves
############################
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
try:
user_valves = Tools.get_user_valves_by_id_and_user_id(id, user.id)
return user_valves
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_toolkit_user_valves_spec_by_id(
request: Request, id: str, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "UserValves"):
UserValves = toolkit_module.UserValves
return UserValves.schema()
return None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_toolkit_user_valves_by_id(
request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
if id in request.app.state.TOOLS:
toolkit_module = request.app.state.TOOLS[id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
request.app.state.TOOLS[id] = toolkit_module
if hasattr(toolkit_module, "UserValves"):
UserValves = toolkit_module.UserValves
try:
form_data = {k: v for k, v in form_data.items() if v is not None}
user_valves = UserValves(**form_data)
Tools.update_user_valves_by_id_and_user_id(
id, user.id, user_valves.model_dump()
)
return user_valves.model_dump()
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
from importlib import util
import os
import re
from config import TOOLS_DIR, FUNCTIONS_DIR
def extract_frontmatter(file_path):
"""
Extract frontmatter as a dictionary from the specified file path.
"""
frontmatter = {}
frontmatter_started = False
frontmatter_ended = False
frontmatter_pattern = re.compile(r"^\s*([a-z_]+):\s*(.*)\s*$", re.IGNORECASE)
try:
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
if '"""' in line:
if not frontmatter_started:
frontmatter_started = True
continue # skip the line with the opening triple quotes
else:
frontmatter_ended = True
break
if frontmatter_started and not frontmatter_ended:
match = frontmatter_pattern.match(line)
if match:
key, value = match.groups()
frontmatter[key.strip()] = value.strip()
except FileNotFoundError:
print(f"Error: The file {file_path} does not exist.")
return {}
except Exception as e:
print(f"An error occurred: {e}")
return {}
return frontmatter
def load_toolkit_module_by_id(toolkit_id):
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
module = util.module_from_spec(spec)
frontmatter = extract_frontmatter(toolkit_path)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
return module.Tools()
return module.Tools(), frontmatter
else:
raise Exception("No Tools class found")
except Exception as e:
......@@ -28,14 +65,15 @@ def load_function_module_by_id(function_id):
spec = util.spec_from_file_location(function_id, function_path)
module = util.module_from_spec(spec)
frontmatter = extract_frontmatter(function_path)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Pipe"):
return module.Pipe(), "pipe"
return module.Pipe(), "pipe", frontmatter
elif hasattr(module, "Filter"):
return module.Filter(), "filter"
return module.Filter(), "filter", frontmatter
else:
raise Exception("No Function class found")
except Exception as e:
......
......@@ -167,6 +167,12 @@ for version in soup.find_all("h2"):
CHANGELOG = changelog_json
####################################
# SAFE_MODE
####################################
SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true"
####################################
# WEBUI_BUILD_HASH
####################################
......
......@@ -62,9 +62,7 @@ from apps.webui.models.functions import Functions
from apps.webui.models.users import Users
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
from apps.webui.utils import load_toolkit_module_by_id
from utils.misc import parse_duration
from utils.utils import (
get_admin_user,
get_verified_user,
......@@ -82,6 +80,7 @@ from utils.misc import (
get_last_user_message,
add_or_update_system_message,
stream_message_template,
parse_duration,
)
from apps.rag.utils import get_rag_context, rag_template
......@@ -113,6 +112,7 @@ from config import (
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
SAFE_MODE,
OAUTH_PROVIDERS,
ENABLE_OAUTH_SIGNUP,
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
......@@ -124,6 +124,11 @@ from config import (
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from utils.webhook import post_webhook
if SAFE_MODE:
print("SAFE MODE ENABLED")
Functions.deactivate_all_functions()
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
......@@ -271,7 +276,7 @@ async def get_function_call_response(
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module = load_toolkit_module_by_id(tool_id)
toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False
......@@ -280,6 +285,14 @@ async def get_function_call_response(
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(
toolkit_module, "Valves"
):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(
**(valves if valves else {})
)
function = getattr(toolkit_module, result["name"])
function_result = None
try:
......@@ -289,16 +302,24 @@ async def get_function_call_response(
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
params = {
**params,
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(toolkit_module, "UserValves"):
__user__["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(
tool_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__messages__" in sig.parameters:
# Call the function with the '__messages__' parameter included
params = {
......@@ -386,54 +407,94 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
)
model = app.state.MODELS[model_id]
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get(
"priority", 0
)
return 0
filter_ids = [
function.id
for function in Functions.get_functions_by_type(
"filter", active_only=True
)
]
# Check if the model has any filters
if "info" in model and "meta" in model["info"]:
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(
filter_id
)
webui_app.state.FUNCTIONS[filter_id] = function_module
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
try:
if hasattr(function_module, "inlet"):
inlet = function_module.inlet
try:
if hasattr(function_module, "inlet"):
inlet = function_module.inlet
if inspect.iscoroutinefunction(inlet):
data = await inlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
else:
data = inlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
# Get the signature of the function
sig = inspect.signature(inlet)
params = {"body": data}
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if inspect.iscoroutinefunction(inlet):
data = await inlet(**params)
else:
data = inlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Set the task model
task_model_id = data["model"]
......@@ -857,12 +918,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
pipe = model.get("pipe")
if pipe:
form_data["user"] = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
async def job():
pipe_id = form_data["model"]
......@@ -870,14 +925,62 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id)
pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
# Check if function is already loaded
if pipe_id not in webui_app.state.FUNCTIONS:
function_module, function_type, frontmatter = (
load_function_module_by_id(pipe_id)
)
webui_app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = webui_app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
pipe = function_module.pipe
# Get the signature of the function
sig = inspect.signature(pipe)
params = {"body": form_data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
pipe_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if form_data["stream"]:
async def stream_content():
if inspect.iscoroutinefunction(pipe):
res = await pipe(body=form_data)
else:
res = pipe(body=form_data)
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
except Exception as e:
print(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str):
message = stream_message_template(form_data["model"], res)
......@@ -922,10 +1025,20 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
stream_content(), media_type="text/event-stream"
)
else:
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
except Exception as e:
print(f"Error: {e}")
return {"error": {"detail": str(e)}}
if inspect.iscoroutinefunction(pipe):
res = await pipe(body=form_data)
res = await pipe(**params)
else:
res = pipe(body=form_data)
res = pipe(**params)
if isinstance(res, dict):
return res
......@@ -1008,7 +1121,12 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json={
"user": {"id": user.id, "name": user.name, "role": user.role},
"user": {
"id": user.id,
"name": user.name,
"email": user.email,
"role": user.role,
},
"body": data,
},
)
......@@ -1033,49 +1151,88 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
else:
pass
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
# Check if the model has any filters
if "info" in model and "meta" in model["info"]:
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(
filter_id
)
webui_app.state.FUNCTIONS[filter_id] = function_module
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
try:
if hasattr(function_module, "outlet"):
outlet = function_module.outlet
if inspect.iscoroutinefunction(outlet):
data = await outlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
else:
data = outlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
# Sort filter_ids by priority, using the get_priority function
filter_ids.sort(key=get_priority)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try:
if hasattr(function_module, "outlet"):
outlet = function_module.outlet
# Get the signature of the function
sig = inspect.signature(outlet)
params = {"body": data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data
......@@ -1989,7 +2146,6 @@ async def get_manifest_json():
"start_url": "/",
"display": "standalone",
"background_color": "#343541",
"theme_color": "#343541",
"orientation": "portrait-primary",
"icons": [{"src": "/static/logo.png", "type": "image/png", "sizes": "500x500"}],
}
......
......@@ -17,7 +17,9 @@ peewee-migrate==1.12.2
psycopg2-binary==2.9.9
PyMySQL==1.1.1
bcrypt==4.1.3
SQLAlchemy
pymongo
redis
boto3==1.34.110
argon2-cffi==23.1.0
......
......@@ -20,7 +20,9 @@ def get_tools_specs(tools) -> List[dict]:
function_list = [
{"name": func, "function": getattr(tools, func)}
for func in dir(tools)
if callable(getattr(tools, func)) and not func.startswith("__")
if callable(getattr(tools, func))
and not func.startswith("__")
and not inspect.isclass(getattr(tools, func))
]
specs = []
......@@ -65,6 +67,7 @@ def get_tools_specs(tools) -> List[dict]:
function
).parameters.items()
if param.default is param.empty
and not (name.startswith("__") and name.endswith("__"))
],
},
}
......
......@@ -32,6 +32,10 @@ math {
@apply underline;
}
iframe {
@apply rounded-lg;
}
ol > li {
counter-increment: list-number;
display: block;
......
......@@ -191,3 +191,233 @@ export const deleteFunctionById = async (token: string, id: string) => {
return res;
};
export const toggleFunctionById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/toggle`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctionValvesById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getFunctionValvesSpecById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateFunctionValvesById = async (token: string, id: string, valves: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...valves
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserValvesById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/user`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserValvesSpecById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/user/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateUserValvesById = async (token: string, id: string, valves: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/valves/user/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...valves
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
......@@ -191,3 +191,201 @@ export const deleteToolById = async (token: string, id: string) => {
return res;
};
export const getToolValvesById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getToolValvesSpecById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateToolValvesById = async (token: string, id: string, valves: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...valves
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserValvesById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/user`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserValvesSpecById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/user/spec`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateUserValvesById = async (token: string, id: string, valves: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/valves/user/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}`
},
body: JSON.stringify({
...valves
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err.detail;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
......@@ -19,7 +19,8 @@
'serper',
'serply',
'duckduckgo',
'tavily'
'tavily',
'jina'
];
let youtubeLanguage = 'en';
......
<script>
import { getContext } from 'svelte';
import Modal from '../common/Modal.svelte';
import Database from './Settings/Database.svelte';
import General from './Settings/General.svelte';
import Users from './Settings/Users.svelte';
import Banners from '$lib/components/admin/Settings/Banners.svelte';
import { toast } from 'svelte-sonner';
import Pipelines from './Settings/Pipelines.svelte';
const i18n = getContext('i18n');
export let show = false;
let selectedTab = 'general';
</script>
<Modal bind:show>
<div>
<div class=" flex justify-between dark:text-gray-300 px-5 pt-4 pb-2">
<div class=" text-lg font-medium self-center">{$i18n.t('Admin Settings')}</div>
<button
class="self-center"
on:click={() => {
show = false;
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 20 20"
fill="currentColor"
class="w-5 h-5"
>
<path
d="M6.28 5.22a.75.75 0 00-1.06 1.06L8.94 10l-3.72 3.72a.75.75 0 101.06 1.06L10 11.06l3.72 3.72a.75.75 0 101.06-1.06L11.06 10l3.72-3.72a.75.75 0 00-1.06-1.06L10 8.94 6.28 5.22z"
/>
</svg>
</button>
</div>
</div>
</Modal>
......@@ -127,6 +127,42 @@
}
onMount(async () => {
const onMessageHandler = async (event) => {
if (event.origin === window.origin) {
// Replace with your iframe's origin
console.log('Message received from iframe:', event.data);
if (event.data.type === 'input:prompt') {
console.log(event.data.text);
const inputElement = document.getElementById('chat-textarea');
if (inputElement) {
prompt = event.data.text;
inputElement.focus();
}
}
if (event.data.type === 'action:submit') {
console.log(event.data.text);
if (prompt !== '') {
await tick();
submitPrompt(prompt);
}
}
if (event.data.type === 'input:prompt:submit') {
console.log(event.data.text);
if (prompt !== '') {
await tick();
submitPrompt(event.data.text);
}
}
}
};
window.addEventListener('message', onMessageHandler);
if (!$chatId) {
chatId.subscribe(async (value) => {
if (!value) {
......@@ -138,6 +174,10 @@
await goto('/');
}
}
return () => {
window.removeEventListener('message', onMessageHandler);
};
});
//////////////////////////
......@@ -600,10 +640,14 @@
files = model.info.meta.knowledge;
}
const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
files = [
...files,
...(lastUserMessage?.files?.filter((item) =>
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
) ?? []),
...(responseMessage?.files?.filter((item) =>
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
) ?? [])
].filter(
// Remove duplicates
......@@ -844,6 +888,9 @@
...files,
...(lastUserMessage?.files?.filter((item) =>
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
) ?? []),
...(responseMessage?.files?.filter((item) =>
['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
) ?? [])
].filter(
// Remove duplicates
......@@ -1213,6 +1260,7 @@
const getWebSearchResults = async (model: string, parentId: string, responseId: string) => {
const responseMessage = history.messages[responseId];
const userMessage = history.messages[parentId];
responseMessage.statusHistory = [
{
......@@ -1223,7 +1271,7 @@
];
messages = messages;
const prompt = history.messages[parentId].content;
const prompt = userMessage.content;
let searchQuery = await generateSearchQuery(localStorage.token, model, messages, prompt).catch(
(error) => {
console.log(error);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment