Unverified Commit 1eebb85f authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #3323 from open-webui/dev

0.3.6
parents 9e4dd4b8 b224ba00
......@@ -43,7 +43,7 @@ jobs:
uses: cypress-io/github-action@v6
with:
browser: chrome
wait-on: 'http://localhost:3000'
wait-on: "http://localhost:3000"
config: baseUrl=http://localhost:3000
- uses: actions/upload-artifact@v4
......@@ -82,18 +82,18 @@ jobs:
--health-retries 5
ports:
- 5432:5432
# mysql:
# image: mysql
# env:
# MYSQL_ROOT_PASSWORD: mysql
# MYSQL_DATABASE: mysql
# options: >-
# --health-cmd "mysqladmin ping -h localhost"
# --health-interval 10s
# --health-timeout 5s
# --health-retries 5
# ports:
# - 3306:3306
# mysql:
# image: mysql
# env:
# MYSQL_ROOT_PASSWORD: mysql
# MYSQL_DATABASE: mysql
# options: >-
# --health-cmd "mysqladmin ping -h localhost"
# --health-interval 10s
# --health-timeout 5s
# --health-retries 5
# ports:
# - 3306:3306
steps:
- name: Checkout Repository
uses: actions/checkout@v4
......@@ -143,7 +143,6 @@ jobs:
exit 1
fi
- name: Test backend with Postgres
if: success() || steps.sqlite.conclusion == 'failure'
env:
......@@ -171,6 +170,25 @@ jobs:
exit 1
fi
# Check that service will reconnect to postgres when connection will be closed
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has failed before postgres reconnect check"
exit 1
fi
echo "Terminating all connections to postgres..."
python -c "import os, psycopg2 as pg2; \
conn = pg2.connect(dsn=os.environ['DATABASE_URL'].replace('+pool', '')); \
cur = conn.cursor(); \
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
exit 1
fi
# - name: Test backend with MySQL
# if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
# env:
......
......@@ -5,6 +5,36 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.6] - 2024-06-27
### Added
- **✨ "Functions" Feature**: You can now utilize "Functions" like filters (middleware) and pipe (model) functions directly within the WebUI. While largely compatible with Pipelines, these native functions can be executed easily within Open WebUI. Example use cases for filter functions include usage monitoring, real-time translation, moderation, and automemory. For pipe functions, the scope ranges from Cohere and Anthropic integration directly within Open WebUI, enabling "Valves" for per-user OpenAI API key usage, and much more. If you encounter issues, SAFE_MODE has been introduced.
- **📁 Files API**: Compatible with OpenAI, this feature allows for custom Retrieval-Augmented Generation (RAG) in conjunction with the Filter Function. More examples will be shared on our community platform and official documentation website.
- **🛠️ Tool Enhancements**: Tools now support citations and "Valves". Documentation will be available shortly.
- **🔗 Iframe Support via Files API**: Enables rendering HTML directly into your chat interface using functions and tools. Use cases include playing games like DOOM and Snake, displaying a weather applet, and implementing Anthropic "artifacts"-like features. Stay tuned for updates on our community platform and documentation.
- **🔒 Experimental OAuth Support**: New experimental OAuth support. Check our documentation for more details.
- **🖼️ Custom Background Support**: Set a custom background from Settings > Interface to personalize your experience.
- **🔑 AUTOMATIC1111_API_AUTH Support**: Enhanced security for the AUTOMATIC1111 API.
- **🎨 Code Highlight Optimization**: Improved code highlighting features.
- **🎙️ Voice Interruption Feature**: Reintroduced and now toggleable from Settings > Interface.
- **💤 Wakelock API**: Now in use to prevent screen dimming during important tasks.
- **🔐 API Key Privacy**: All API keys are now hidden by default for better security.
- **🔍 New Web Search Provider**: Added jina_search as a new option.
- **🌐 Enhanced Internationalization (i18n)**: Improved Korean translation and updated Chinese and Ukrainian translations.
### Fixed
- **🔧 Conversation Mode Issue**: Fixed the issue where Conversation Mode remained active after being removed from settings.
- **📏 Scroll Button Obstruction**: Resolved the issue where the scrollToBottom button container obstructed clicks on buttons beneath it.
### Changed
- **⏲️ AIOHTTP_CLIENT_TIMEOUT**: Now set to `None` by default for improved configuration flexibility.
- **📞 Voice Call Enhancements**: Improved by skipping code blocks and expressions during calls.
- **🚫 Error Message Handling**: Disabled the continuation of operations with error messages.
- **🗂️ Playground Relocation**: Moved the Playground from the workspace to the user menu for better user experience.
## [0.3.5] - 2024-06-16
### Added
......
......@@ -325,7 +325,7 @@ def transcribe(
headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
files = {"file": (filename, open(file_path, "rb"))}
data = {"model": "whisper-1"}
data = {"model": app.state.config.STT_MODEL}
print(files, data)
......
import re
import requests
import base64
from fastapi import (
FastAPI,
Request,
......@@ -15,7 +16,7 @@ from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (
get_current_user,
get_verified_user,
get_admin_user,
)
......@@ -36,6 +37,7 @@ from config import (
IMAGE_GENERATION_ENGINE,
ENABLE_IMAGE_GENERATION,
AUTOMATIC1111_BASE_URL,
AUTOMATIC1111_API_AUTH,
COMFYUI_BASE_URL,
COMFYUI_CFG_SCALE,
COMFYUI_SAMPLER,
......@@ -49,7 +51,6 @@ from config import (
AppConfig,
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])
......@@ -75,11 +76,10 @@ app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.MODEL = IMAGE_GENERATION_MODEL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS
app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
......@@ -88,6 +88,16 @@ app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
app.state.config.COMFYUI_SD3 = COMFYUI_SD3
def get_automatic1111_api_auth():
if app.state.config.AUTOMATIC1111_API_AUTH == None:
return ""
else:
auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
return f"Basic {auth1111_base64_encoded_string}"
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {
......@@ -113,6 +123,7 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user
class EngineUrlUpdateForm(BaseModel):
AUTOMATIC1111_BASE_URL: Optional[str] = None
AUTOMATIC1111_API_AUTH: Optional[str] = None
COMFYUI_BASE_URL: Optional[str] = None
......@@ -120,6 +131,7 @@ class EngineUrlUpdateForm(BaseModel):
async def get_engine_url(user=Depends(get_admin_user)):
return {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
}
......@@ -128,7 +140,6 @@ async def get_engine_url(user=Depends(get_admin_user)):
async def update_engine_url(
form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
if form_data.AUTOMATIC1111_BASE_URL == None:
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else:
......@@ -150,8 +161,14 @@ async def update_engine_url(
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
if form_data.AUTOMATIC1111_API_AUTH == None:
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
else:
app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
return {
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
"status": True,
}
......@@ -241,7 +258,7 @@ async def update_image_size(
@app.get("/models")
def get_models(user=Depends(get_current_user)):
def get_models(user=Depends(get_verified_user)):
try:
if app.state.config.ENGINE == "openai":
return [
......@@ -262,7 +279,8 @@ def get_models(user=Depends(get_current_user)):
else:
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
headers={"authorization": get_automatic1111_api_auth()},
)
models = r.json()
return list(
......@@ -289,7 +307,8 @@ async def get_default_model(user=Depends(get_admin_user)):
return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
else:
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": get_automatic1111_api_auth()},
)
options = r.json()
return {"model": options["sd_model_checkpoint"]}
......@@ -307,8 +326,10 @@ def set_model_handler(model: str):
app.state.config.MODEL = model
return app.state.config.MODEL
else:
api_auth = get_automatic1111_api_auth()
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
headers={"authorization": api_auth},
)
options = r.json()
......@@ -317,6 +338,7 @@ def set_model_handler(model: str):
r = requests.post(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
json=options,
headers={"authorization": api_auth},
)
return options
......@@ -325,7 +347,7 @@ def set_model_handler(model: str):
@app.post("/models/default/update")
def update_default_model(
form_data: UpdateModelForm,
user=Depends(get_current_user),
user=Depends(get_verified_user),
):
return set_model_handler(form_data.model)
......@@ -402,9 +424,8 @@ def save_url_image(url):
@app.post("/generations")
def generate_image(
form_data: GenerateImageForm,
user=Depends(get_current_user),
user=Depends(get_verified_user),
):
width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
r = None
......@@ -519,6 +540,7 @@ def generate_image(
r = requests.post(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
headers={"authorization": get_automatic1111_api_auth()},
)
res = r.json()
......
......@@ -40,6 +40,7 @@ from utils.utils import (
get_verified_user,
get_admin_user,
)
from utils.task import prompt_template
from config import (
......@@ -52,7 +53,7 @@ from config import (
UPLOAD_DIR,
AppConfig,
)
from utils.misc import calculate_sha256
from utils.misc import calculate_sha256, add_or_update_system_message
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
......@@ -199,9 +200,6 @@ def merge_models_lists(model_lists):
return list(merged_models.values())
# user=Depends(get_current_user)
async def get_all_models():
log.info("get_all_models()")
......@@ -817,23 +815,27 @@ async def generate_chat_completion(
"num_thread", None
)
if model_info.params.get("system", None):
system = model_info.params.get("system", None)
if system:
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
break
else:
payload["messages"].insert(
0,
system = prompt_template(
system,
**(
{
"role": "system",
"content": model_info.params.get("system", None),
},
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
if payload.get("messages"):
payload["messages"] = add_or_update_system_message(
system, payload["messages"]
)
if url_idx == None:
......@@ -878,10 +880,11 @@ class OpenAIChatCompletionForm(BaseModel):
@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion(
form_data: OpenAIChatCompletionForm,
form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
form_data = OpenAIChatCompletionForm(**form_data)
payload = {
**form_data.model_dump(exclude_none=True),
......@@ -913,22 +916,35 @@ async def generate_openai_chat_completion(
else None
)
if model_info.params.get("system", None):
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
"content": system,
},
)
......@@ -1094,17 +1110,13 @@ async def download_file_stream(
raise "Ollama: Could not create blob, Please try again."
# def number_generator():
# for i in range(1, 101):
# yield f"data: {i}\n"
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@app.post("/models/download")
@app.post("/models/download/{url_idx}")
async def download_model(
form_data: UrlForm,
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
......@@ -1133,7 +1145,11 @@ async def download_model(
@app.post("/models/upload")
@app.post("/models/upload/{url_idx}")
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
def upload_model(
file: UploadFile = File(...),
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
if url_idx == None:
url_idx = 0
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
......@@ -1196,137 +1212,3 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None:
# url_idx = 0
# url = app.state.config.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size
# async def file_upload_generator(file):
# print(file)
# try:
# async with aiofiles.open(file_location, "wb") as f:
# completed_size = 0
# while True:
# chunk = await file.read(1024*1024)
# if not chunk:
# break
# await f.write(chunk)
# completed_size += len(chunk)
# progress = (completed_size / total_size) * 100
# print(progress)
# yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
# except Exception as e:
# print(e)
# yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
# finally:
# await file.close()
# print("done")
# yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
# return StreamingResponse(
# file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
# )
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def deprecated_proxy(
path: str, request: Request, user=Depends(get_verified_user)
):
url = app.state.config.OLLAMA_BASE_URLS[0]
target_url = f"{url}/{path}"
body = await request.body()
headers = dict(request.headers)
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
headers.pop("host", None)
headers.pop("authorization", None)
headers.pop("origin", None)
headers.pop("referer", None)
r = None
def get_request():
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if path == "generate":
data = json.loads(body.decode("utf-8"))
if data.get("stream", True):
yield json.dumps({"id": request_id, "done": False}) + "\n"
elif path == "chat":
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# r.close()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
......@@ -16,10 +16,12 @@ from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
get_current_user,
get_verified_user,
get_verified_user,
get_admin_user,
)
from utils.task import prompt_template
from config import (
SRC_LOG_LEVELS,
ENABLE_OPENAI_API,
......@@ -294,7 +296,7 @@ async def get_all_models(raw: bool = False):
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
if url_idx == None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
......@@ -392,22 +394,34 @@ async def generate_chat_completion(
else None
)
if model_info.params.get("system", None):
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
"content": system,
},
)
......@@ -418,7 +432,12 @@ async def generate_chat_completion(
idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"):
payload["user"] = {"name": user.name, "id": user.id}
payload["user"] = {
"name": user.name,
"id": user.id,
"email": user.email,
"role": user.role,
}
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model
......
......@@ -55,6 +55,9 @@ from apps.webui.models.documents import (
DocumentForm,
DocumentResponse,
)
from apps.webui.models.files import (
Files,
)
from apps.rag.utils import (
get_model_path,
......@@ -74,6 +77,7 @@ from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.serply import search_serply
from apps.rag.search.duckduckgo import search_duckduckgo
from apps.rag.search.tavily import search_tavily
from apps.rag.search.jina_search import search_jina
from utils.misc import (
calculate_sha256,
......@@ -81,7 +85,7 @@ from utils.misc import (
sanitize_filename,
extract_folders_after_data_docs,
)
from utils.utils import get_current_user, get_admin_user
from utils.utils import get_verified_user, get_admin_user
from config import (
AppConfig,
......@@ -112,6 +116,7 @@ from config import (
YOUTUBE_LOADER_LANGUAGE,
ENABLE_RAG_WEB_SEARCH,
RAG_WEB_SEARCH_ENGINE,
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
SEARXNG_QUERY_URL,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
......@@ -165,6 +170,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
......@@ -523,7 +529,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
async def get_rag_template(user=Depends(get_verified_user)):
return {
"status": True,
"template": app.state.config.RAG_TEMPLATE,
......@@ -580,7 +586,7 @@ class QueryDocForm(BaseModel):
@app.post("/query/doc")
def query_doc_handler(
form_data: QueryDocForm,
user=Depends(get_current_user),
user=Depends(get_verified_user),
):
try:
if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
......@@ -620,7 +626,7 @@ class QueryCollectionsForm(BaseModel):
@app.post("/query/collection")
def query_collection_handler(
form_data: QueryCollectionsForm,
user=Depends(get_current_user),
user=Depends(get_verified_user),
):
try:
if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
......@@ -651,7 +657,7 @@ def query_collection_handler(
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
def store_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)):
try:
loader = YoutubeLoader.from_youtube_url(
form_data.url,
......@@ -680,7 +686,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
def store_web(form_data: UrlForm, user=Depends(get_verified_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = get_web_loader(
......@@ -775,6 +781,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SEARXNG_QUERY_URL,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SEARXNG_QUERY_URL found in environment variables")
......@@ -788,6 +795,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.GOOGLE_PSE_ENGINE_ID,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception(
......@@ -799,6 +807,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.BRAVE_SEARCH_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
......@@ -808,6 +817,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPSTACK_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
https_enabled=app.state.config.SERPSTACK_HTTPS,
)
else:
......@@ -818,6 +828,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPER_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SERPER_API_KEY found in environment variables")
......@@ -827,11 +838,16 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPLY_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo":
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
return search_duckduckgo(
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "tavily":
if app.state.config.TAVILY_API_KEY:
return search_tavily(
......@@ -841,12 +857,14 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
elif engine == "jina":
return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
else:
raise Exception("No search engine API key found in environment variables")
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
try:
logging.info(
f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
......@@ -1066,7 +1084,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
def store_doc(
collection_name: Optional[str] = Form(None),
file: UploadFile = File(...),
user=Depends(get_current_user),
user=Depends(get_verified_user),
):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
......@@ -1119,6 +1137,60 @@ def store_doc(
)
class ProcessDocForm(BaseModel):
file_id: str
collection_name: Optional[str] = None
@app.post("/process/doc")
def process_doc(
form_data: ProcessDocForm,
user=Depends(get_verified_user),
):
try:
file = Files.get_file_by_id(form_data.file_id)
file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")
f = open(file_path, "rb")
collection_name = form_data.collection_name
if collection_name == None:
collection_name = calculate_sha256(f)[:63]
f.close()
loader, known_type = get_loader(
file.filename, file.meta.get("content_type"), file_path
)
data = loader.load()
try:
result = store_data_in_vector_db(data, collection_name)
if result:
return {
"status": True,
"collection_name": collection_name,
"known_type": known_type,
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e,
)
except Exception as e:
log.exception(e)
if "No pandoc was found" in str(e):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
class TextRAGForm(BaseModel):
name: str
content: str
......@@ -1128,7 +1200,7 @@ class TextRAGForm(BaseModel):
@app.post("/text")
def store_text(
form_data: TextRAGForm,
user=Depends(get_current_user),
user=Depends(get_verified_user),
):
collection_name = form_data.collection_name
......
import logging
from typing import List, Optional
import requests
from apps.rag.search.main import SearchResult
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
def search_brave(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
Args:
......@@ -29,6 +31,9 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
json_response = response.json()
results = json_response.get("web", {}).get("results", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
......
import logging
from apps.rag.search.main import SearchResult
from typing import List, Optional
from apps.rag.search.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS
......@@ -8,7 +8,9 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
def search_duckduckgo(
query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args:
......@@ -41,6 +43,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
snippet=result.get("body"),
)
)
print(results)
if filter_list:
results = get_filtered_results(results, filter_list)
# Return the list of search results
return results
import json
import logging
from typing import List, Optional
import requests
from apps.rag.search.main import SearchResult
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
......@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse(
api_key: str, search_engine_id: str, query: str, count: int
api_key: str,
search_engine_id: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
......@@ -35,6 +39,8 @@ def search_google_pse(
json_response = response.json()
results = json_response.get("items", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
......
import logging
import requests
from yarl import URL
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_jina(query: str, count: int) -> list[SearchResult]:
"""
Search using Jina's Search API and return the results as a list of SearchResult objects.
Args:
query (str): The query to search for
count (int): The number of results to return
Returns:
List[SearchResult]: A list of search results
"""
jina_search_endpoint = "https://s.jina.ai/"
headers = {
"Accept": "application/json",
}
url = str(URL(jina_search_endpoint + query))
response = requests.get(url, headers=headers)
response.raise_for_status()
data = response.json()
results = []
for result in data["data"][:count]:
results.append(
SearchResult(
link=result["url"],
title=result.get("title"),
snippet=result.get("content"),
)
)
return results
from typing import Optional
from urllib.parse import urlparse
from pydantic import BaseModel
def get_filtered_results(results, filter_list):
if not filter_list:
return results
filtered_results = []
for result in results:
domain = urlparse(result["url"]).netloc
if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
filtered_results.append(result)
return filtered_results
class SearchResult(BaseModel):
link: str
title: Optional[str]
......
import logging
import requests
from typing import List
from typing import List, Optional
from apps.rag.search.main import SearchResult
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
......@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_searxng(
query_url: str, query: str, count: int, **kwargs
query_url: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
**kwargs,
) -> List[SearchResult]:
"""
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
......@@ -78,6 +82,8 @@ def search_searxng(
json_response = response.json()
results = json_response.get("results", [])
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content")
......
import json
import logging
from typing import List, Optional
import requests
from apps.rag.search.main import SearchResult
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
def search_serper(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
......@@ -29,6 +31,8 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
results = sorted(
json_response.get("organic", []), key=lambda x: x.get("position", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
......
import json
import logging
from typing import List, Optional
import requests
from urllib.parse import urlencode
from apps.rag.search.main import SearchResult
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
......@@ -19,6 +19,7 @@ def search_serply(
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US",
filter_list: Optional[List[str]] = None,
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
......@@ -57,7 +58,8 @@ def search_serply(
results = sorted(
json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
......
import json
import logging
from typing import List, Optional
import requests
from apps.rag.search.main import SearchResult
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
......@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serpstack(
api_key: str, query: str, count: int, https_enabled: bool = True
api_key: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
https_enabled: bool = True,
) -> list[SearchResult]:
"""Search using serpstack.com's and return the results as a list of SearchResult objects.
......@@ -35,6 +39,8 @@ def search_serpstack(
results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
......
......@@ -237,7 +237,7 @@ def get_embedding_function(
def get_rag_context(
docs,
files,
messages,
embedding_function,
k,
......@@ -245,29 +245,29 @@ def get_rag_context(
r,
hybrid_search,
):
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
query = get_last_user_message(messages)
extracted_collections = []
relevant_contexts = []
for doc in docs:
for file in files:
context = None
collection_names = (
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
file["collection_names"]
if file["type"] == "collection"
else [file["collection_name"]]
)
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:
log.debug(f"skipping {doc} as it has already been extracted")
log.debug(f"skipping {file} as it has already been extracted")
continue
try:
if doc["type"] == "text":
context = doc["content"]
if file["type"] == "text":
context = file["content"]
else:
if hybrid_search:
context = query_collection_with_hybrid_search(
......@@ -290,7 +290,7 @@ def get_rag_context(
context = None
if context:
relevant_contexts.append({**context, "source": doc})
relevant_contexts.append({**context, "source": file})
extracted_collections.extend(collection_names)
......
import os
import logging
import json
from peewee import *
from peewee_migrate import Router
from playhouse.db_url import connect
from apps.webui.internal.wrappers import register_connection
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
import os
import logging
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
......@@ -28,12 +29,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
else:
pass
DB = connect(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
# The `register_connection` function encapsulates the logic for setting up
# the database connection based on the connection string, while `connect`
# is a Peewee-specific method to manage the connection state and avoid errors
# when a connection is already open.
try:
DB = register_connection(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
raise
router = Router(
DB,
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
logger=log,
)
router.run()
DB.connect(reuse_if_open=True)
try:
DB.connect(reuse_if_open=True)
except OperationalError as e:
log.info(f"Failed to connect to database again due to: {e}")
pass
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class File(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
filename = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "file"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("file")
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Function(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
type = pw.TextField()
content = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "function"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("function")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment