Unverified Commit c41b33c9 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #3013 from open-webui/dev

0.3.3
parents 06976c45 7131ac24
...@@ -5,6 +5,20 @@ All notable changes to this project will be documented in this file. ...@@ -5,6 +5,20 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.3] - 2024-06-12
### Added
- **🛠️ Native Python Function Calling**: Introducing native Python function calling within Open WebUI. We’ve also included a built-in code editor to seamlessly develop and integrate function code within the 'Tools' workspace. With this, you can significantly enhance your LLM’s capabilities by creating custom RAG pipelines, web search tools, and even agent-like features such as sending Discord messages.
- **🌐 DuckDuckGo Integration**: Added DuckDuckGo as a web search provider, giving you more search options.
- **🌏 Enhanced Translations**: Improved translations for Vietnamese and Chinese languages, making the interface more accessible.
### Fixed
- **🔗 Web Search URL Error Handling**: Fixed the issue where a single URL error would disrupt the data loading process in Web Search mode. Now, such errors will be handled gracefully to ensure uninterrupted data loading.
- **🖥️ Frontend Responsiveness**: Resolved the problem where the frontend would stop responding if the backend encounters an error while downloading a model. Improved error handling to maintain frontend stability.
- **🔧 Dependency Issues in pip**: Fixed issues related to pip installations, ensuring all dependencies are correctly managed to prevent installation errors.
## [0.3.2] - 2024-06-10 ## [0.3.2] - 2024-06-10
### Added ### Added
......
...@@ -29,8 +29,12 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature- ...@@ -29,8 +29,12 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-
- ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction. - ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features, allowing for a more dynamic and interactive chat environment.
- 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration. - 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
- 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query. - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, and `Serply` and inject the results directly into your chat experience. - 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, and `Serply` and inject the results directly into your chat experience.
......
...@@ -8,13 +8,15 @@ from fastapi import ( ...@@ -8,13 +8,15 @@ from fastapi import (
Form, Form,
) )
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import requests
import os, shutil, logging, re import os, shutil, logging, re
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import List, Union, Sequence from typing import List, Union, Sequence, Iterator, Any
from chromadb.utils.batch_utils import create_batches from chromadb.utils.batch_utils import create_batches
from langchain_core.documents import Document
from langchain_community.document_loaders import ( from langchain_community.document_loaders import (
WebBaseLoader, WebBaseLoader,
...@@ -70,6 +72,7 @@ from apps.rag.search.searxng import search_searxng ...@@ -70,6 +72,7 @@ from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.serply import search_serply from apps.rag.search.serply import search_serply
from apps.rag.search.duckduckgo import search_duckduckgo
from utils.misc import ( from utils.misc import (
calculate_sha256, calculate_sha256,
...@@ -701,7 +704,7 @@ def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True): ...@@ -701,7 +704,7 @@ def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
# Check if the URL is valid # Check if the URL is valid
if not validate_url(url): if not validate_url(url):
raise ValueError(ERROR_MESSAGES.INVALID_URL) raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader( return SafeWebBaseLoader(
url, url,
verify_ssl=verify_ssl, verify_ssl=verify_ssl,
requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS, requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
...@@ -714,18 +717,13 @@ def validate_url(url: Union[str, Sequence[str]]): ...@@ -714,18 +717,13 @@ def validate_url(url: Union[str, Sequence[str]]):
if isinstance(validators.url(url), validators.ValidationError): if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL) raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_RAG_LOCAL_WEB_FETCH: if not ENABLE_RAG_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses # Check if the URL exists by making a HEAD request
parsed_url = urllib.parse.urlparse(url) try:
# Get IPv4 and IPv6 addresses response = requests.head(url, allow_redirects=True)
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname) if response.status_code != 200:
# Check if any of the resolved addresses are private
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
for ip in ipv4_addresses:
if validators.ipv4(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
for ip in ipv6_addresses:
if validators.ipv6(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL) raise ValueError(ERROR_MESSAGES.INVALID_URL)
except requests.exceptions.RequestException:
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return True return True
elif isinstance(url, Sequence): elif isinstance(url, Sequence):
return all(validate_url(u) for u in url) return all(validate_url(u) for u in url)
...@@ -733,17 +731,6 @@ def validate_url(url: Union[str, Sequence[str]]): ...@@ -733,17 +731,6 @@ def validate_url(url: Union[str, Sequence[str]]):
return False return False
def resolve_hostname(hostname):
# Get address information
addr_info = socket.getaddrinfo(hostname, None)
# Extract IP addresses from address information
ipv4_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET]
ipv6_addresses = [info[4][0] for info in addr_info if info[0] == socket.AF_INET6]
return ipv4_addresses, ipv6_addresses
def search_web(engine: str, query: str) -> list[SearchResult]: def search_web(engine: str, query: str) -> list[SearchResult]:
"""Search the web using a search engine and return the results as a list of SearchResult objects. """Search the web using a search engine and return the results as a list of SearchResult objects.
Will look for a search engine API key in environment variables in the following order: Will look for a search engine API key in environment variables in the following order:
...@@ -820,6 +807,8 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -820,6 +807,8 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
) )
else: else:
raise Exception("No SERPLY_API_KEY found in environment variables") raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo":
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
else: else:
raise Exception("No search engine API key found in environment variables") raise Exception("No search engine API key found in environment variables")
...@@ -827,7 +816,9 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -827,7 +816,9 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
@app.post("/web/search") @app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)): def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
try: try:
logging.info(f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}") logging.info(
f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
)
web_results = search_web( web_results = search_web(
app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
) )
...@@ -1238,6 +1229,33 @@ def reset(user=Depends(get_admin_user)) -> bool: ...@@ -1238,6 +1229,33 @@ def reset(user=Depends(get_admin_user)) -> bool:
return True return True
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
def lazy_load(self) -> Iterator[Document]:
"""Lazy load text from the url(s) in web_path with error handling."""
for path in self.web_paths:
try:
soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
text = soup.get_text(**self.bs_get_text_kwargs)
# Build metadata
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
yield Document(page_content=text, metadata=metadata)
except Exception as e:
# Log the error and continue with the next URL
log.error(f"Error loading {path}: {e}")
if ENV == "dev": if ENV == "dev":
@app.get("/ef") @app.get("/ef")
......
import logging
from apps.rag.search.main import SearchResult
from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args:
query (str): The query to search for
count (int): The number of results to return
Returns:
List[SearchResult]: A list of search results
"""
# Use the DDGS context manager to create a DDGS object
with DDGS() as ddgs:
# Use the ddgs.text() method to perform the search
ddgs_gen = ddgs.text(
query, safesearch="moderate", max_results=count, backend="api"
)
# Check if there are search results
if ddgs_gen:
# Convert the search results into a list
search_results = [r for r in ddgs_gen]
# Create an empty list to store the SearchResult objects
results = []
# Iterate over each search result
for result in search_results:
# Create a SearchResult object and append it to the results list
results.append(
SearchResult(
link=result["href"],
title=result.get("title"),
snippet=result.get("body"),
)
)
print(results)
# Return the list of search results
return results
...@@ -12,14 +12,14 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ...@@ -12,14 +12,14 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serply( def search_serply(
api_key: str, api_key: str,
query: str, query: str,
count: int, count: int,
hl: str = "us", hl: str = "us",
limit: int = 10, limit: int = 10,
device_type: str = "desktop", device_type: str = "desktop",
proxy_location: str = "US" proxy_location: str = "US",
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects. """Search using serper.dev's API and return the results as a list of SearchResult objects.
Args: Args:
...@@ -37,7 +37,7 @@ def search_serply( ...@@ -37,7 +37,7 @@ def search_serply(
"language": "en", "language": "en",
"num": limit, "num": limit,
"gl": proxy_location.upper(), "gl": proxy_location.upper(),
"hl": hl.lower() "hl": hl.lower(),
} }
url = f"{url}{urlencode(query_payload)}" url = f"{url}{urlencode(query_payload)}"
...@@ -45,7 +45,7 @@ def search_serply( ...@@ -45,7 +45,7 @@ def search_serply(
"X-API-KEY": api_key, "X-API-KEY": api_key,
"X-User-Agent": device_type, "X-User-Agent": device_type,
"User-Agent": "open-webui", "User-Agent": "open-webui",
"X-Proxy-Location": proxy_location "X-Proxy-Location": proxy_location,
} }
response = requests.request("GET", url, headers=headers) response = requests.request("GET", url, headers=headers)
......
...@@ -236,10 +236,9 @@ def get_embedding_function( ...@@ -236,10 +236,9 @@ def get_embedding_function(
return lambda query: generate_multiple(query, func) return lambda query: generate_multiple(query, func)
def rag_messages( def get_rag_context(
docs, docs,
messages, messages,
template,
embedding_function, embedding_function,
k, k,
reranking_function, reranking_function,
...@@ -318,16 +317,7 @@ def rag_messages( ...@@ -318,16 +317,7 @@ def rag_messages(
context_string = context_string.strip() context_string = context_string.strip()
ra_content = rag_template( return context_string, citations
template=template,
context=context_string,
query=query,
)
log.debug(f"ra_content: {ra_content}")
messages = add_or_update_system_message(ra_content, messages)
return messages, citations
def get_model_path(model: str, update_model: bool = False): def get_model_path(model: str, update_model: bool = False):
......
...@@ -19,8 +19,6 @@ TIMEOUT_DURATION = 3 ...@@ -19,8 +19,6 @@ TIMEOUT_DURATION = 3
@sio.event @sio.event
async def connect(sid, environ, auth): async def connect(sid, environ, auth):
print("connect ", sid)
user = None user = None
if auth and "token" in auth: if auth and "token" in auth:
data = decode_token(auth["token"]) data = decode_token(auth["token"])
...@@ -37,7 +35,6 @@ async def connect(sid, environ, auth): ...@@ -37,7 +35,6 @@ async def connect(sid, environ, auth):
print(f"user {user.name}({user.id}) connected with session ID {sid}") print(f"user {user.name}({user.id}) connected with session ID {sid}")
print(len(set(USER_POOL)))
await sio.emit("user-count", {"count": len(set(USER_POOL))}) await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("usage", {"models": get_models_in_use()}) await sio.emit("usage", {"models": get_models_in_use()})
...@@ -64,13 +61,11 @@ async def user_join(sid, data): ...@@ -64,13 +61,11 @@ async def user_join(sid, data):
print(f"user {user.name}({user.id}) connected with session ID {sid}") print(f"user {user.name}({user.id}) connected with session ID {sid}")
print(len(set(USER_POOL)))
await sio.emit("user-count", {"count": len(set(USER_POOL))}) await sio.emit("user-count", {"count": len(set(USER_POOL))})
@sio.on("user-count") @sio.on("user-count")
async def user_count(sid): async def user_count(sid):
print("user-count", sid)
await sio.emit("user-count", {"count": len(set(USER_POOL))}) await sio.emit("user-count", {"count": len(set(USER_POOL))})
...@@ -79,14 +74,12 @@ def get_models_in_use(): ...@@ -79,14 +74,12 @@ def get_models_in_use():
models_in_use = [] models_in_use = []
for model_id, data in USAGE_POOL.items(): for model_id, data in USAGE_POOL.items():
models_in_use.append(model_id) models_in_use.append(model_id)
print(f"Models in use: {models_in_use}")
return models_in_use return models_in_use
@sio.on("usage") @sio.on("usage")
async def usage(sid, data): async def usage(sid, data):
print(f'Received "usage" event from {sid}: {data}')
model_id = data["model"] model_id = data["model"]
...@@ -114,7 +107,6 @@ async def usage(sid, data): ...@@ -114,7 +107,6 @@ async def usage(sid, data):
async def remove_after_timeout(sid, model_id): async def remove_after_timeout(sid, model_id):
try: try:
print("remove_after_timeout", sid, model_id)
await asyncio.sleep(TIMEOUT_DURATION) await asyncio.sleep(TIMEOUT_DURATION)
if model_id in USAGE_POOL: if model_id in USAGE_POOL:
print(USAGE_POOL[model_id]["sids"]) print(USAGE_POOL[model_id]["sids"])
...@@ -124,7 +116,6 @@ async def remove_after_timeout(sid, model_id): ...@@ -124,7 +116,6 @@ async def remove_after_timeout(sid, model_id):
if len(USAGE_POOL[model_id]["sids"]) == 0: if len(USAGE_POOL[model_id]["sids"]) == 0:
del USAGE_POOL[model_id] del USAGE_POOL[model_id]
print(f"Removed usage data for {model_id} due to timeout")
# Broadcast the usage data to all clients # Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()}) await sio.emit("usage", {"models": get_models_in_use()})
except asyncio.CancelledError: except asyncio.CancelledError:
...@@ -143,9 +134,6 @@ async def disconnect(sid): ...@@ -143,9 +134,6 @@ async def disconnect(sid):
if len(USER_POOL[user_id]) == 0: if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id] del USER_POOL[user_id]
print(f"user {user_id} disconnected with session ID {sid}")
print(USER_POOL)
await sio.emit("user-count", {"count": len(USER_POOL)}) await sio.emit("user-count", {"count": len(USER_POOL)})
else: else:
print(f"Unknown session ID {sid} disconnected") print(f"Unknown session ID {sid} disconnected")
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Tool(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
content = pw.TextField()
specs = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "tool"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("tool")
...@@ -6,6 +6,7 @@ from apps.webui.routers import ( ...@@ -6,6 +6,7 @@ from apps.webui.routers import (
users, users,
chats, chats,
documents, documents,
tools,
models, models,
prompts, prompts,
configs, configs,
...@@ -26,8 +27,8 @@ from config import ( ...@@ -26,8 +27,8 @@ from config import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN, JWT_EXPIRES_IN,
WEBUI_BANNERS, WEBUI_BANNERS,
AppConfig,
ENABLE_COMMUNITY_SHARING, ENABLE_COMMUNITY_SHARING,
AppConfig,
) )
app = FastAPI() app = FastAPI()
...@@ -38,6 +39,7 @@ app.state.config = AppConfig() ...@@ -38,6 +39,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
...@@ -54,7 +56,7 @@ app.state.config.BANNERS = WEBUI_BANNERS ...@@ -54,7 +56,7 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {} app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.TOOLS = {}
app.add_middleware( app.add_middleware(
...@@ -70,6 +72,7 @@ app.include_router(users.router, prefix="/users", tags=["users"]) ...@@ -70,6 +72,7 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"]) app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"]) app.include_router(memories.router, prefix="/memories", tags=["memories"])
......
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tools DB Schema
####################
class Tool(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
class ToolMeta(BaseModel):
description: Optional[str] = None
class ToolModel(BaseModel):
id: str
user_id: str
name: str
content: str
specs: List[dict]
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class ToolResponse(BaseModel):
id: str
user_id: str
name: str
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ToolForm(BaseModel):
id: str
name: str
content: str
meta: ToolMeta
class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool.create(**tool.model_dump())
if result:
return tool
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try:
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
**updated,
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def delete_tool_by_id(self, id: str) -> bool:
try:
query = Tool.delete().where((Tool.id == id))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Tools = ToolsTable(DB)
...@@ -161,7 +161,7 @@ async def get_archived_session_user_chat_list( ...@@ -161,7 +161,7 @@ async def get_archived_session_user_chat_list(
############################ ############################
@router.post("/archive/all", response_model=List[ChatTitleIdResponse]) @router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user)): async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(user.id) return Chats.archive_all_chats_by_user_id(user.id)
......
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_current_user, get_admin_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from importlib import util
import os
from config import DATA_DIR
TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True)
router = APIRouter()
############################
# GetToolkits
############################
@router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_current_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# ExportToolKits
############################
@router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# CreateNewToolKit
############################
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user)
):
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(form_data.id)
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module
specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetToolkitById
############################
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateToolkitById
############################
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(id)
TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module
specs = get_tools_specs(TOOLS[id])
updated = {
**form_data.model_dump(exclude={"id"}),
"specs": specs,
}
print(updated)
toolkit = Tools.update_tool_by_id(id, updated)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteToolkitById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
result = Tools.delete_tool_by_id(id)
if result:
TOOLS = request.app.state.TOOLS
if id in TOOLS:
del TOOLS[id]
# delete the toolkit file
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
os.remove(toolkit_path)
return result
...@@ -7,6 +7,8 @@ from pydantic import BaseModel ...@@ -7,6 +7,8 @@ from pydantic import BaseModel
from fpdf import FPDF from fpdf import FPDF
import markdown import markdown
import black
from apps.webui.internal.db import DB from apps.webui.internal.db import DB
from utils.utils import get_admin_user from utils.utils import get_admin_user
...@@ -26,6 +28,21 @@ async def get_gravatar( ...@@ -26,6 +28,21 @@ async def get_gravatar(
return get_gravatar_url(email) return get_gravatar_url(email)
class CodeFormatRequest(BaseModel):
code: str
@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
try:
formatted_code = black.format_str(request.code, mode=black.Mode())
return {"code": formatted_code}
except black.NothingChanged:
return {"code": request.code}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
class MarkdownForm(BaseModel): class MarkdownForm(BaseModel):
md: str md: str
......
from importlib import util
import os
from config import TOOLS_DIR
def load_toolkit_module_by_id(toolkit_id):
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
return module.Tools()
else:
raise Exception("No Tools class found")
except Exception as e:
print(f"Error loading module: {toolkit_id}")
# Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error")
raise e
...@@ -368,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs") ...@@ -368,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True) Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Tools DIR
####################################
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
#################################### ####################################
# LITELLM_CONFIG # LITELLM_CONFIG
#################################### ####################################
...@@ -669,16 +677,28 @@ Question: ...@@ -669,16 +677,28 @@ Question:
), ),
) )
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig( SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
"task.search.prompt_length_threshold", "task.search.prompt_length_threshold",
int(
os.environ.get(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
100,
)
),
)
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"task.tools.prompt_template",
os.environ.get( os.environ.get(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD", "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
100, """Tools: {{TOOLS}}
If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
), ),
) )
#################################### ####################################
# WEBUI_SECRET_KEY # WEBUI_SECRET_KEY
#################################### ####################################
......
...@@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum): ...@@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum):
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file." FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string." MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string." NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
......
...@@ -11,6 +11,7 @@ import requests ...@@ -11,6 +11,7 @@ import requests
import mimetypes import mimetypes
import shutil import shutil
import os import os
import inspect
import asyncio import asyncio
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
...@@ -47,15 +48,24 @@ from pydantic import BaseModel ...@@ -47,15 +48,24 @@ from pydantic import BaseModel
from typing import List, Optional from typing import List, Optional
from apps.webui.models.models import Models, ModelModel from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import ( from utils.utils import (
get_admin_user, get_admin_user,
get_verified_user, get_verified_user,
get_current_user, get_current_user,
get_http_authorization_cred, get_http_authorization_cred,
) )
from utils.task import title_generation_template, search_query_generation_template from utils.task import (
title_generation_template,
search_query_generation_template,
tools_function_calling_generation_template,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from apps.rag.utils import rag_messages from apps.rag.utils import get_rag_context, rag_template
from config import ( from config import (
CONFIG_DATA, CONFIG_DATA,
...@@ -82,6 +92,7 @@ from config import ( ...@@ -82,6 +92,7 @@ from config import (
TITLE_GENERATION_PROMPT_TEMPLATE, TITLE_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
AppConfig, AppConfig,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -148,24 +159,117 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = ( ...@@ -148,24 +159,117 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
) )
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
app.state.MODELS = {} app.state.MODELS = {}
origins = ["*"] origins = ["*"]
# Custom middleware to add security headers
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
# async def dispatch(self, request: Request, call_next):
# response: Response = await call_next(request)
# response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
# return response
async def get_function_call_response(messages, tool_id, template, task_model_id, user):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs)
user_message = get_last_user_message(messages)
prompt = (
"History:\n"
+ "\n".join(
[
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
for message in messages[::-1][:4]
]
)
+ f"\nQuery: {user_message}"
)
print(prompt)
payload = {
"model": task_model_id,
"messages": [
{"role": "system", "content": content},
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
}
payload = filter_pipeline(payload, user)
model = app.state.MODELS[task_model_id]
response = None
try:
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
response = await generate_openai_chat_completion(payload, user=user)
content = None
if hasattr(response, "body_iterator"):
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
# Parse the function response
if content is not None:
print(f"content: {content}")
result = json.loads(content)
print(result)
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
function = getattr(toolkit_module, result["name"])
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
# Check if '__user__' is a parameter of the function
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
function_result = function(
**{
**result["parameters"],
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
}
)
else:
# Call the function without modifying the parameters
function_result = function(**result["parameters"])
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result:
return function_result
except Exception as e:
print(f"Error: {e}")
# app.add_middleware(SecurityHeadersMiddleware) return None
class RAGMiddleware(BaseHTTPMiddleware): class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
return_citations = False return_citations = False
...@@ -182,35 +286,95 @@ class RAGMiddleware(BaseHTTPMiddleware): ...@@ -182,35 +286,95 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Parse string to JSON # Parse string to JSON
data = json.loads(body_str) if body_str else {} data = json.loads(body_str) if body_str else {}
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
# Remove the citations from the body
return_citations = data.get("citations", False) return_citations = data.get("citations", False)
if "citations" in data: if "citations" in data:
del data["citations"] del data["citations"]
# Example: Add a new key-value pair or modify existing ones # Set the task model
# data["modified"] = True # Example modification task_model_id = data["model"]
if task_model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
prompt = get_last_user_message(data["messages"])
context = ""
# If tool_ids field is present, call the functions
if "tool_ids" in data:
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
print(tool_id)
response = await get_function_call_response(
messages=data["messages"],
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
if response:
context += ("\n" if context != "" else "") + response
del data["tool_ids"]
print(f"tool_context: {context}")
# If docs field is present, generate RAG completions
if "docs" in data: if "docs" in data:
data = {**data} data = {**data}
data["messages"], citations = rag_messages( rag_context, citations = get_rag_context(
docs=data["docs"], docs=data["docs"],
messages=data["messages"], messages=data["messages"],
template=rag_app.state.config.RAG_TEMPLATE,
embedding_function=rag_app.state.EMBEDDING_FUNCTION, embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K, k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf, reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD, r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH, hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
) )
if rag_context:
context += ("\n" if context != "" else "") + rag_context
del data["docs"] del data["docs"]
log.debug( log.debug(f"rag_context: {rag_context}, citations: {citations}")
f"data['messages']: {data['messages']}, citations: {citations}"
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
) )
modified_body_bytes = json.dumps(data).encode("utf-8") modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one # Replace the request body with the modified one
request._body = modified_body_bytes request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length # Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [ request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")), (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
...@@ -253,7 +417,7 @@ class RAGMiddleware(BaseHTTPMiddleware): ...@@ -253,7 +417,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
yield data yield data
app.add_middleware(RAGMiddleware) app.add_middleware(ChatCompletionMiddleware)
def filter_pipeline(payload, user): def filter_pipeline(payload, user):
...@@ -515,6 +679,7 @@ async def get_task_config(user=Depends(get_verified_user)): ...@@ -515,6 +679,7 @@ async def get_task_config(user=Depends(get_verified_user)):
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
} }
...@@ -524,6 +689,7 @@ class TaskConfigForm(BaseModel): ...@@ -524,6 +689,7 @@ class TaskConfigForm(BaseModel):
TITLE_GENERATION_PROMPT_TEMPLATE: str TITLE_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@app.post("/api/task/config/update") @app.post("/api/task/config/update")
...@@ -539,6 +705,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u ...@@ -539,6 +705,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = ( app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
) )
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
return { return {
"TASK_MODEL": app.state.config.TASK_MODEL, "TASK_MODEL": app.state.config.TASK_MODEL,
...@@ -546,6 +715,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u ...@@ -546,6 +715,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE, "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD, "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
} }
...@@ -659,6 +829,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) ...@@ -659,6 +829,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
return await generate_openai_chat_completion(payload, user=user) return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/tools/completions")
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
print("get_tools_function_calling")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
return await get_function_call_response(
form_data["messages"], form_data["tool_id"], template, model_id, user
)
@app.post("/api/chat/completions") @app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"] model_id = form_data["model"]
......
...@@ -59,4 +59,5 @@ youtube-transcript-api==0.6.2 ...@@ -59,4 +59,5 @@ youtube-transcript-api==0.6.2
pytube==15.0.0 pytube==15.0.0
extract_msg extract_msg
pydub pydub
\ No newline at end of file duckduckgo-search~=6.1.5
\ No newline at end of file
...@@ -8,6 +8,7 @@ cd /d "%SCRIPT_DIR%" || exit /b ...@@ -8,6 +8,7 @@ cd /d "%SCRIPT_DIR%" || exit /b
SET "KEY_FILE=.webui_secret_key" SET "KEY_FILE=.webui_secret_key"
IF "%PORT%"=="" SET PORT=8080 IF "%PORT%"=="" SET PORT=8080
IF "%HOST%"=="" SET HOST=0.0.0.0
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%" SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%" SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"
...@@ -29,4 +30,4 @@ IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " ( ...@@ -29,4 +30,4 @@ IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " (
:: Execute uvicorn :: Execute uvicorn
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%" SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
uvicorn main:app --host 0.0.0.0 --port "%PORT%" --forwarded-allow-ips '*' uvicorn main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*'
...@@ -110,3 +110,8 @@ def search_query_generation_template( ...@@ -110,3 +110,8 @@ def search_query_generation_template(
), ),
) )
return template return template
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs)
return template
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment