"vscode:/vscode.git/clone" did not exist on "caae2c98cf7eff885204aa25c710828004e0534c"
Unverified Commit c41b33c9 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #3013 from open-webui/dev

0.3.3
parents 06976c45 7131ac24
......@@ -5,6 +5,20 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.3] - 2024-06-12
### Added
- **🛠️ Native Python Function Calling**: Introducing native Python function calling within Open WebUI. We’ve also included a built-in code editor to seamlessly develop and integrate function code within the 'Tools' workspace. With this, you can significantly enhance your LLM’s capabilities by creating custom RAG pipelines, web search tools, and even agent-like features such as sending Discord messages.
- **🌐 DuckDuckGo Integration**: Added DuckDuckGo as a web search provider, giving you more search options.
- **🌏 Enhanced Translations**: Improved translations for Vietnamese and Chinese languages, making the interface more accessible.
### Fixed
- **🔗 Web Search URL Error Handling**: Fixed the issue where a single URL error would disrupt the data loading process in Web Search mode. Now, such errors will be handled gracefully to ensure uninterrupted data loading.
- **🖥️ Frontend Responsiveness**: Resolved the problem where the frontend would stop responding if the backend encounters an error while downloading a model. Improved error handling to maintain frontend stability.
- **🔧 Dependency Issues in pip**: Fixed issues related to pip installations, ensuring all dependencies are correctly managed to prevent installation errors.
## [0.3.2] - 2024-06-10
### Added
......
......@@ -29,8 +29,12 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-
- ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features, allowing for a more dynamic and interactive chat environment.
- 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
- 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, and `Serply` and inject the results directly into your chat experience.
......
......@@ -8,13 +8,15 @@ from fastapi import (
Form,
)
from fastapi.middleware.cors import CORSMiddleware
import requests
import os, shutil, logging, re
from datetime import datetime
from pathlib import Path
from typing import List, Union, Sequence
from typing import List, Union, Sequence, Iterator, Any
from chromadb.utils.batch_utils import create_batches
from langchain_core.documents import Document
from langchain_community.document_loaders import (
WebBaseLoader,
......@@ -70,6 +72,7 @@ from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.serply import search_serply
from apps.rag.search.duckduckgo import search_duckduckgo
from utils.misc import (
calculate_sha256,
......@@ -701,7 +704,7 @@ def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
# Check if the URL is valid
if not validate_url(url):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader(
return SafeWebBaseLoader(
url,
verify_ssl=verify_ssl,
requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
......@@ -714,18 +717,13 @@ def validate_url(url: Union[str, Sequence[str]]):
if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_RAG_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
parsed_url = urllib.parse.urlparse(url)
# Get IPv4 and IPv6 addresses
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
# Check if any of the resolved addresses are private
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
for ip in ipv4_addresses:
if validators.ipv4(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
for ip in ipv6_addresses:
if validators.ipv6(ip, private=True):
# Check if the URL exists by making a HEAD request
try:
response = requests.head(url, allow_redirects=True)
if response.status_code != 200:
raise ValueError(ERROR_MESSAGES.INVALID_URL)
except requests.exceptions.RequestException:
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return True
elif isinstance(url, Sequence):
return all(validate_url(u) for u in url)
......@@ -733,17 +731,6 @@ def validate_url(url: Union[str, Sequence[str]]):
return False
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)
# Extract IP addresses from address information
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
return ipv4_addresses, ipv6_addresses
def search_web(engine: str, query: str) -> list[SearchResult]:
"""Search the web using a search engine and return the results as a list of SearchResult objects.
Will look for a search engine API key in environment variables in the following order:
......@@ -820,6 +807,8 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo":
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
else:
raise Exception("No search engine API key found in environment variables")
......@@ -827,7 +816,9 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
try:
logging.info(f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}")
logging.info(
f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
)
web_results = search_web(
app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
)
......@@ -1238,6 +1229,33 @@ def reset(user=Depends(get_admin_user)) -> bool:
return True
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
def lazy_load(self) -> Iterator[Document]:
"""Lazy load text from the url(s) in web_path with error handling."""
for path in self.web_paths:
try:
soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
text = soup.get_text(**self.bs_get_text_kwargs)
# Build metadata
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
yield Document(page_content=text, metadata=metadata)
except Exception as e:
# Log the error and continue with the next URL
log.error(f"Error loading {path}: {e}")
if ENV == "dev":
@app.get("/ef")
......
import logging
from apps.rag.search.main import SearchResult
from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args:
query (str): The query to search for
count (int): The number of results to return
Returns:
List[SearchResult]: A list of search results
"""
# Use the DDGS context manager to create a DDGS object
with DDGS() as ddgs:
# Use the ddgs.text() method to perform the search
ddgs_gen = ddgs.text(
query, safesearch="moderate", max_results=count, backend="api"
)
# Check if there are search results
if ddgs_gen:
# Convert the search results into a list
search_results = [r for r in ddgs_gen]
# Create an empty list to store the SearchResult objects
results = []
# Iterate over each search result
for result in search_results:
# Create a SearchResult object and append it to the results list
results.append(
SearchResult(
link=result["href"],
title=result.get("title"),
snippet=result.get("body"),
)
)
print(results)
# Return the list of search results
return results
......@@ -12,14 +12,14 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serply(
api_key: str,
query: str,
count: int,
hl: str = "us",
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US"
) -> list[SearchResult]:
api_key: str,
query: str,
count: int,
hl: str = "us",
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US",
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
......@@ -37,7 +37,7 @@ def search_serply(
"language": "en",
"num": limit,
"gl": proxy_location.upper(),
"hl": hl.lower()
"hl": hl.lower(),
}
url = f"{url}{urlencode(query_payload)}"
......@@ -45,7 +45,7 @@ def search_serply(
"X-API-KEY": api_key,
"X-User-Agent": device_type,
"User-Agent": "open-webui",
"X-Proxy-Location": proxy_location
"X-Proxy-Location": proxy_location,
}
response = requests.request("GET", url, headers=headers)
......
......@@ -236,10 +236,9 @@ def get_embedding_function(
return lambda query: generate_multiple(query, func)
def rag_messages(
def get_rag_context(
docs,
messages,
template,
embedding_function,
k,
reranking_function,
......@@ -318,16 +317,7 @@ def rag_messages(
context_string = context_string.strip()
ra_content = rag_template(
template=template,
context=context_string,
query=query,
)
log.debug(f"ra_content: {ra_content}")
messages = add_or_update_system_message(ra_content, messages)
return messages, citations
return context_string, citations
def get_model_path(model: str, update_model: bool = False):
......
......@@ -19,8 +19,6 @@ TIMEOUT_DURATION = 3
@sio.event
async def connect(sid, environ, auth):
print("connect ", sid)
user = None
if auth and "token" in auth:
data = decode_token(auth["token"])
......@@ -37,7 +35,6 @@ async def connect(sid, environ, auth):
print(f"user {user.name}({user.id}) connected with session ID {sid}")
print(len(set(USER_POOL)))
await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("usage", {"models": get_models_in_use()})
......@@ -64,13 +61,11 @@ async def user_join(sid, data):
print(f"user {user.name}({user.id}) connected with session ID {sid}")
print(len(set(USER_POOL)))
await sio.emit("user-count", {"count": len(set(USER_POOL))})
@sio.on("user-count")
async def user_count(sid):
print("user-count", sid)
await sio.emit("user-count", {"count": len(set(USER_POOL))})
......@@ -79,14 +74,12 @@ def get_models_in_use():
models_in_use = []
for model_id, data in USAGE_POOL.items():
models_in_use.append(model_id)
print(f"Models in use: {models_in_use}")
return models_in_use
@sio.on("usage")
async def usage(sid, data):
print(f'Received "usage" event from {sid}: {data}')
model_id = data["model"]
......@@ -114,7 +107,6 @@ async def usage(sid, data):
async def remove_after_timeout(sid, model_id):
try:
print("remove_after_timeout", sid, model_id)
await asyncio.sleep(TIMEOUT_DURATION)
if model_id in USAGE_POOL:
print(USAGE_POOL[model_id]["sids"])
......@@ -124,7 +116,6 @@ async def remove_after_timeout(sid, model_id):
if len(USAGE_POOL[model_id]["sids"]) == 0:
del USAGE_POOL[model_id]
print(f"Removed usage data for {model_id} due to timeout")
# Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()})
except asyncio.CancelledError:
......@@ -143,9 +134,6 @@ async def disconnect(sid):
if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id]
print(f"user {user_id} disconnected with session ID {sid}")
print(USER_POOL)
await sio.emit("user-count", {"count": len(USER_POOL)})
else:
print(f"Unknown session ID {sid} disconnected")
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Tool(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
content = pw.TextField()
specs = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "tool"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("tool")
......@@ -6,6 +6,7 @@ from apps.webui.routers import (
users,
chats,
documents,
tools,
models,
prompts,
configs,
......@@ -26,8 +27,8 @@ from config import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
WEBUI_BANNERS,
AppConfig,
ENABLE_COMMUNITY_SHARING,
AppConfig,
)
app = FastAPI()
......@@ -38,6 +39,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
......@@ -54,7 +56,7 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.TOOLS = {}
app.add_middleware(
......@@ -70,6 +72,7 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
......
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tools DB Schema
####################
class Tool(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
class ToolMeta(BaseModel):
description: Optional[str] = None
class ToolModel(BaseModel):
id: str
user_id: str
name: str
content: str
specs: List[dict]
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class ToolResponse(BaseModel):
id: str
user_id: str
name: str
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ToolForm(BaseModel):
id: str
name: str
content: str
meta: ToolMeta
class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool.create(**tool.model_dump())
if result:
return tool
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try:
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
**updated,
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def delete_tool_by_id(self, id: str) -> bool:
try:
query = Tool.delete().where((Tool.id == id))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Tools = ToolsTable(DB)
......@@ -161,7 +161,7 @@ async def get_archived_session_user_chat_list(
############################
@router.post("/archive/all", response_model=List[ChatTitleIdResponse])
@router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(user.id)
......
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_current_user, get_admin_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from importlib import util
import os
from config import DATA_DIR
TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True)
router = APIRouter()
############################
# GetToolkits
############################
@router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_current_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# ExportToolKits
############################
@router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# CreateNewToolKit
############################
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user)
):
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(form_data.id)
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module
specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetToolkitById
############################
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateToolkitById
############################
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(id)
TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module
specs = get_tools_specs(TOOLS[id])
updated = {
**form_data.model_dump(exclude={"id"}),
"specs": specs,
}
print(updated)
toolkit = Tools.update_tool_by_id(id, updated)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteToolkitById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
result = Tools.delete_tool_by_id(id)
if result:
TOOLS = request.app.state.TOOLS
if id in TOOLS:
del TOOLS[id]
# delete the toolkit file
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
os.remove(toolkit_path)
return result
......@@ -7,6 +7,8 @@ from pydantic import BaseModel
from fpdf import FPDF
import markdown
import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user
......@@ -26,6 +28,21 @@ async def get_gravatar(
return get_gravatar_url(email)
class CodeFormatRequest(BaseModel):
code: str
@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
try:
formatted_code = black.format_str(request.code, mode=black.Mode())
return {"code": formatted_code}
except black.NothingChanged:
return {"code": request.code}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
class MarkdownForm(BaseModel):
md: str
......
from importlib import util
import os
from config import TOOLS_DIR
def load_toolkit_module_by_id(toolkit_id):
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
return module.Tools()
else:
raise Exception("No Tools class found")
except Exception as e:
print(f"Error loading module: {toolkit_id}")
# Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error")
raise e
......@@ -368,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Tools DIR
####################################
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# LITELLM_CONFIG
####################################
......@@ -669,16 +677,28 @@ Question:
),
)
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
"task.search.prompt_length_threshold",
int(
os.environ.get(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
100,
)
),
)
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"task.tools.prompt_template",
os.environ.get(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
100,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"""Tools: {{TOOLS}}
If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
),
)
####################################
# WEBUI_SECRET_KEY
####################################
......
......@@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum):
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
......
......@@ -11,6 +11,7 @@ import requests
import mimetypes
import shutil
import os
import inspect
import asyncio
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
......@@ -47,15 +48,24 @@ from pydantic import BaseModel
from typing import List, Optional
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import (
get_admin_user,
get_verified_user,
get_current_user,
get_http_authorization_cred,
)
from utils.task import title_generation_template, search_query_generation_template
from utils.task import (
title_generation_template,
search_query_generation_template,
tools_function_calling_generation_template,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from apps.rag.utils import rag_messages
from apps.rag.utils import get_rag_context, rag_template
from config import (
CONFIG_DATA,
......@@ -82,6 +92,7 @@ from config import (
TITLE_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
AppConfig,
)
from constants import ERROR_MESSAGES
......@@ -148,24 +159,117 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
app.state.MODELS = {}
origins = ["*"]
# Custom middleware to add security headers
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
# async def dispatch(self, request: Request, call_next):
# response: Response = await call_next(request)
# response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
# return response
async def get_function_call_response(messages, tool_id, template, task_model_id, user):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs)
user_message = get_last_user_message(messages)
prompt = (
"History:\n"
+ "\n".join(
[
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
for message in messages[::-1][:4]
]
)
+ f"\nQuery: {user_message}"
)
print(prompt)
payload = {
"model": task_model_id,
"messages": [
{"role": "system", "content": content},
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
}
payload = filter_pipeline(payload, user)
model = app.state.MODELS[task_model_id]
response = None
try:
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
response = await generate_openai_chat_completion(payload, user=user)
content = None
if hasattr(response, "body_iterator"):
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
# Parse the function response
if content is not None:
print(f"content: {content}")
result = json.loads(content)
print(result)
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
function = getattr(toolkit_module, result["name"])
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
# Check if '__user__' is a parameter of the function
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
function_result = function(
**{
**result["parameters"],
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
}
)
else:
# Call the function without modifying the parameters
function_result = function(**result["parameters"])
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result:
return function_result
except Exception as e:
print(f"Error: {e}")
# app.add_middleware(SecurityHeadersMiddleware)
return None
class RAGMiddleware(BaseHTTPMiddleware):
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
return_citations = False
......@@ -182,35 +286,95 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
# Remove the citations from the body
return_citations = data.get("citations", False)
if "citations" in data:
del data["citations"]
# Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification
# Set the task model
task_model_id = data["model"]
if task_model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
prompt = get_last_user_message(data["messages"])
context = ""
# If tool_ids field is present, call the functions
if "tool_ids" in data:
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
print(tool_id)
response = await get_function_call_response(
messages=data["messages"],
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
if response:
context += ("\n" if context != "" else "") + response
del data["tool_ids"]
print(f"tool_context: {context}")
# If docs field is present, generate RAG completions
if "docs" in data:
data = {**data}
data["messages"], citations = rag_messages(
rag_context, citations = get_rag_context(
docs=data["docs"],
messages=data["messages"],
template=rag_app.state.config.RAG_TEMPLATE,
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
del data["docs"]
log.debug(
f"data['messages']: {data['messages']}, citations: {citations}"
log.debug(f"rag_context: {rag_context}, citations: {citations}")
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
......@@ -253,7 +417,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
yield data
app.add_middleware(RAGMiddleware)
app.add_middleware(ChatCompletionMiddleware)
def filter_pipeline(payload, user):
......@@ -515,6 +679,7 @@ async def get_task_config(user=Depends(get_verified_user)):
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
......@@ -524,6 +689,7 @@ class TaskConfigForm(BaseModel):
TITLE_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@app.post("/api/task/config/update")
......@@ -539,6 +705,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
return {
"TASK_MODEL": app.state.config.TASK_MODEL,
......@@ -546,6 +715,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
......@@ -659,6 +829,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/tools/completions")
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
print("get_tools_function_calling")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
return await get_function_call_response(
form_data["messages"], form_data["tool_id"], template, model_id, user
)
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]
......
......@@ -59,4 +59,5 @@ youtube-transcript-api==0.6.2
pytube==15.0.0
extract_msg
pydub
\ No newline at end of file
pydub
duckduckgo-search~=6.1.5
\ No newline at end of file
......@@ -8,6 +8,7 @@ cd /d "%SCRIPT_DIR%" || exit /b
SET "KEY_FILE=.webui_secret_key"
IF "%PORT%"=="" SET PORT=8080
IF "%HOST%"=="" SET HOST=0.0.0.0
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"
......@@ -29,4 +30,4 @@ IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " (
:: Execute uvicorn
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
uvicorn main:app --host 0.0.0.0 --port "%PORT%" --forwarded-allow-ips '*'
uvicorn main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*'
......@@ -110,3 +110,8 @@ def search_query_generation_template(
),
)
return template
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs)
return template
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment