Unverified commit 1eebb85f, authored by Timothy Jaeryang Baek and committed via GitHub

Merge pull request #3323 from open-webui/dev

0.3.6
parents 9e4dd4b8 b224ba00
@@ -25,7 +25,7 @@ jobs:
           --file docker-compose.api.yaml \
           --file docker-compose.a1111-test.yaml \
           up --detach --build
       - name: Wait for Ollama to be up
         timeout-minutes: 5
         run: |
@@ -43,7 +43,7 @@ jobs:
         uses: cypress-io/github-action@v6
         with:
           browser: chrome
-          wait-on: 'http://localhost:3000'
+          wait-on: "http://localhost:3000"
           config: baseUrl=http://localhost:3000
       - uses: actions/upload-artifact@v4
@@ -82,18 +82,18 @@ jobs:
           --health-retries 5
         ports:
           - 5432:5432
       # mysql:
       #   image: mysql
       #   env:
       #     MYSQL_ROOT_PASSWORD: mysql
       #     MYSQL_DATABASE: mysql
       #   options: >-
       #     --health-cmd "mysqladmin ping -h localhost"
       #     --health-interval 10s
       #     --health-timeout 5s
       #     --health-retries 5
       #   ports:
       #     - 3306:3306
     steps:
       - name: Checkout Repository
         uses: actions/checkout@v4
@@ -142,7 +142,6 @@ jobs:
             echo "Server has stopped"
             exit 1
           fi
       - name: Test backend with Postgres
         if: success() || steps.sqlite.conclusion == 'failure'
@@ -171,6 +170,25 @@ jobs:
             exit 1
           fi
          # Check that the service reconnects to postgres after its connections are closed
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has failed before postgres reconnect check"
exit 1
fi
echo "Terminating all connections to postgres..."
python -c "import os, psycopg2 as pg2; \
conn = pg2.connect(dsn=os.environ['DATABASE_URL'].replace('+pool', '')); \
cur = conn.cursor(); \
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
exit 1
fi
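The inline psycopg2 one-liner above first rewrites `DATABASE_URL` by dropping the SQLAlchemy-style `+pool` driver suffix so the raw DSN can be handed to `psycopg2.connect`. A minimal sketch of that normalization (the example URL below is hypothetical):

```python
def psycopg2_dsn(database_url):
    # Strip the SQLAlchemy-style "+pool" driver suffix so the plain
    # DSN is accepted by psycopg2.connect(dsn=...).
    return database_url.replace("+pool", "")

print(psycopg2_dsn("postgresql+pool://user:pw@localhost:5432/db"))
# postgresql://user:pw@localhost:5432/db
```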
      # - name: Test backend with MySQL
      #   if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
      #   env:
......
@@ -5,6 +5,36 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.6] - 2024-06-27
### Added
- **✨ "Functions" Feature**: You can now use "Functions" such as filter (middleware) and pipe (model) functions directly within the WebUI. While largely compatible with Pipelines, these native functions run inside Open WebUI itself. Example use cases for filter functions include usage monitoring, real-time translation, moderation, and automemory. Pipe functions cover everything from Cohere and Anthropic integration directly within Open WebUI to "Valves" for per-user OpenAI API key usage, and much more. If you encounter issues, a SAFE_MODE option has been introduced.
- **📁 Files API**: Compatible with OpenAI, this feature allows for custom Retrieval-Augmented Generation (RAG) in conjunction with the Filter Function. More examples will be shared on our community platform and official documentation website.
- **🛠️ Tool Enhancements**: Tools now support citations and "Valves". Documentation will be available shortly.
- **🔗 Iframe Support via Files API**: Enables rendering HTML directly into your chat interface using functions and tools. Use cases include playing games like DOOM and Snake, displaying a weather applet, and implementing Anthropic "artifacts"-like features. Stay tuned for updates on our community platform and documentation.
- **🔒 Experimental OAuth Support**: New experimental OAuth support. Check our documentation for more details.
- **🖼️ Custom Background Support**: Set a custom background from Settings > Interface to personalize your experience.
- **🔑 AUTOMATIC1111_API_AUTH Support**: Enhanced security for the AUTOMATIC1111 API.
- **🎨 Code Highlight Optimization**: Improved code highlighting features.
- **🎙️ Voice Interruption Feature**: Reintroduced and now toggleable from Settings > Interface.
- **💤 Wakelock API**: Now in use to prevent screen dimming during important tasks.
- **🔐 API Key Privacy**: All API keys are now hidden by default for better security.
- **🔍 New Web Search Provider**: Added jina_search as a new option.
- **🌐 Enhanced Internationalization (i18n)**: Improved Korean translation and updated Chinese and Ukrainian translations.
### Fixed
- **🔧 Conversation Mode Issue**: Fixed the issue where Conversation Mode remained active after being removed from settings.
- **📏 Scroll Button Obstruction**: Resolved the issue where the scrollToBottom button container obstructed clicks on buttons beneath it.
### Changed
- **⏲️ AIOHTTP_CLIENT_TIMEOUT**: Now set to `None` by default for improved configuration flexibility.
- **📞 Voice Call Enhancements**: Improved by skipping code blocks and expressions during calls.
- **🚫 Error Message Handling**: Disabled the continuation of operations with error messages.
- **🗂️ Playground Relocation**: Moved the Playground from the workspace to the user menu for better user experience.
## [0.3.5] - 2024-06-16
### Added
......
@@ -325,7 +325,7 @@ def transcribe(
     headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
     files = {"file": (filename, open(file_path, "rb"))}
-    data = {"model": "whisper-1"}
+    data = {"model": app.state.config.STT_MODEL}
     print(files, data)
......
 import re
 import requests
+import base64
 from fastapi import (
     FastAPI,
     Request,
@@ -15,7 +16,7 @@ from faster_whisper import WhisperModel
 from constants import ERROR_MESSAGES
 from utils.utils import (
-    get_current_user,
+    get_verified_user,
     get_admin_user,
 )
@@ -36,6 +37,7 @@ from config import (
     IMAGE_GENERATION_ENGINE,
     ENABLE_IMAGE_GENERATION,
     AUTOMATIC1111_BASE_URL,
+    AUTOMATIC1111_API_AUTH,
     COMFYUI_BASE_URL,
     COMFYUI_CFG_SCALE,
     COMFYUI_SAMPLER,
@@ -49,7 +51,6 @@ from config import (
     AppConfig,
 )

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@@ -75,11 +76,10 @@ app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
 app.state.config.MODEL = IMAGE_GENERATION_MODEL
 app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
+app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
 app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
 app.state.config.IMAGE_SIZE = IMAGE_SIZE
 app.state.config.IMAGE_STEPS = IMAGE_STEPS
 app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
@@ -88,6 +88,16 @@ app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
 app.state.config.COMFYUI_SD3 = COMFYUI_SD3
def get_automatic1111_api_auth():
    if app.state.config.AUTOMATIC1111_API_AUTH is None:
        return ""
    auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
    auth1111_base64_encoded_bytes = base64.b64encode(auth1111_byte_string)
    auth1111_base64_encoded_string = auth1111_base64_encoded_bytes.decode("utf-8")
    return f"Basic {auth1111_base64_encoded_string}"
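The helper above builds a standard HTTP Basic `Authorization` value (RFC 7617) from a `user:password` credential string. A standalone sketch of the same encoding, independent of the app state:

```python
import base64

def basic_auth_header(credentials):
    # Encode "user:password" as base64 and prefix with "Basic ",
    # mirroring get_automatic1111_api_auth above; empty or missing
    # credentials yield an empty header value.
    if not credentials:
        return ""
    encoded = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
    return f"Basic {encoded}"

print(basic_auth_header("user:pass"))  # Basic dXNlcjpwYXNz
```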
 @app.get("/config")
 async def get_config(request: Request, user=Depends(get_admin_user)):
     return {
@@ -113,6 +123,7 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
 class EngineUrlUpdateForm(BaseModel):
     AUTOMATIC1111_BASE_URL: Optional[str] = None
+    AUTOMATIC1111_API_AUTH: Optional[str] = None
     COMFYUI_BASE_URL: Optional[str] = None
@@ -120,6 +131,7 @@ class EngineUrlUpdateForm(BaseModel):
 async def get_engine_url(user=Depends(get_admin_user)):
     return {
         "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
+        "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
         "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
     }
@@ -128,7 +140,6 @@ async def get_engine_url(user=Depends(get_admin_user)):
 async def update_engine_url(
     form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
 ):
     if form_data.AUTOMATIC1111_BASE_URL == None:
         app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
     else:
@@ -150,8 +161,14 @@ async def update_engine_url(
     except Exception as e:
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

+    if form_data.AUTOMATIC1111_API_AUTH == None:
+        app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
+    else:
+        app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
+
     return {
         "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
+        "AUTOMATIC1111_API_AUTH": app.state.config.AUTOMATIC1111_API_AUTH,
         "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
         "status": True,
     }
@@ -241,7 +258,7 @@ async def update_image_size(
 @app.get("/models")
-def get_models(user=Depends(get_current_user)):
+def get_models(user=Depends(get_verified_user)):
     try:
         if app.state.config.ENGINE == "openai":
             return [
@@ -262,7 +279,8 @@ def get_models(user=Depends(get_verified_user)):
         else:
             r = requests.get(
-                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
+                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models",
+                headers={"authorization": get_automatic1111_api_auth()},
             )
             models = r.json()
             return list(
@@ -289,7 +307,8 @@ async def get_default_model(user=Depends(get_admin_user)):
         return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
     else:
         r = requests.get(
-            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
+            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+            headers={"authorization": get_automatic1111_api_auth()},
         )
         options = r.json()
         return {"model": options["sd_model_checkpoint"]}
@@ -307,8 +326,10 @@ def set_model_handler(model: str):
         app.state.config.MODEL = model
         return app.state.config.MODEL
     else:
+        api_auth = get_automatic1111_api_auth()
         r = requests.get(
-            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
+            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+            headers={"authorization": api_auth},
         )
         options = r.json()
@@ -317,6 +338,7 @@ def set_model_handler(model: str):
         r = requests.post(
             url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
             json=options,
+            headers={"authorization": api_auth},
         )
         return options
@@ -325,7 +347,7 @@ def set_model_handler(model: str):
 @app.post("/models/default/update")
 def update_default_model(
     form_data: UpdateModelForm,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
     return set_model_handler(form_data.model)
@@ -402,9 +424,8 @@ def save_url_image(url):
 @app.post("/generations")
 def generate_image(
     form_data: GenerateImageForm,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
     width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))

     r = None
@@ -519,6 +540,7 @@ def generate_image(
             r = requests.post(
                 url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                 json=data,
+                headers={"authorization": get_automatic1111_api_auth()},
             )
             res = r.json()
......
@@ -40,6 +40,7 @@ from utils.utils import (
     get_verified_user,
     get_admin_user,
 )
+from utils.task import prompt_template

 from config import (
@@ -52,7 +53,7 @@ from config import (
     UPLOAD_DIR,
     AppConfig,
 )
-from utils.misc import calculate_sha256
+from utils.misc import calculate_sha256, add_or_update_system_message

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@@ -199,9 +200,6 @@ def merge_models_lists(model_lists):
     return list(merged_models.values())

-# user=Depends(get_current_user)
 async def get_all_models():
     log.info("get_all_models()")
@@ -817,24 +815,28 @@ async def generate_chat_completion(
         "num_thread", None
     )

-    if model_info.params.get("system", None):
+    system = model_info.params.get("system", None)
+    if system:
         # Check if the payload already has a system message
         # If not, add a system message to the payload
+        system = prompt_template(
+            system,
+            **(
+                {
+                    "user_name": user.name,
+                    "user_location": (
+                        user.info.get("location") if user.info else None
+                    ),
+                }
+                if user
+                else {}
+            ),
+        )
         if payload.get("messages"):
-            for message in payload["messages"]:
-                if message.get("role") == "system":
-                    message["content"] = (
-                        model_info.params.get("system", None) + message["content"]
-                    )
-                    break
-            else:
-                payload["messages"].insert(
-                    0,
-                    {
-                        "role": "system",
-                        "content": model_info.params.get("system", None),
-                    },
-                )
+            payload["messages"] = add_or_update_system_message(
+                system, payload["messages"]
+            )
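The refactor replaces the inline system-message loop with `add_or_update_system_message` from `utils.misc`. A hypothetical sketch of what such a helper might do, assuming it mirrors the logic it replaces (prepend to an existing system message, otherwise insert one at the front of the list):

```python
def add_or_update_system_message(content, messages):
    # If a system message already leads the conversation, prepend the new
    # content to it; otherwise insert a fresh system message at position 0.
    # This is a sketch, not the project's exact implementation.
    if messages and messages[0].get("role") == "system":
        messages[0]["content"] = content + messages[0]["content"]
    else:
        messages.insert(0, {"role": "system", "content": content})
    return messages

print(add_or_update_system_message("Be brief. ", [{"role": "user", "content": "hi"}]))
```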
     if url_idx == None:
         if ":" not in payload["model"]:
@@ -878,10 +880,11 @@ class OpenAIChatCompletionForm(BaseModel):
 @app.post("/v1/chat/completions")
 @app.post("/v1/chat/completions/{url_idx}")
 async def generate_openai_chat_completion(
-    form_data: OpenAIChatCompletionForm,
+    form_data: dict,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
+    form_data = OpenAIChatCompletionForm(**form_data)
     payload = {
         **form_data.model_dump(exclude_none=True),
@@ -913,22 +916,35 @@ async def generate_openai_chat_completion(
         else None
     )

-    if model_info.params.get("system", None):
+    system = model_info.params.get("system", None)
+    if system:
+        system = prompt_template(
+            system,
+            **(
+                {
+                    "user_name": user.name,
+                    "user_location": (
+                        user.info.get("location") if user.info else None
+                    ),
+                }
+                if user
+                else {}
+            ),
+        )
         # Check if the payload already has a system message
         # If not, add a system message to the payload
         if payload.get("messages"):
             for message in payload["messages"]:
                 if message.get("role") == "system":
-                    message["content"] = (
-                        model_info.params.get("system", None) + message["content"]
-                    )
+                    message["content"] = system + message["content"]
                     break
             else:
                 payload["messages"].insert(
                     0,
                     {
                         "role": "system",
-                        "content": model_info.params.get("system", None),
+                        "content": system,
                     },
                 )
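`prompt_template` (imported from `utils.task`) interpolates per-user variables into the model's system prompt before it is attached to the payload. A simplified sketch, under the assumption that it substitutes `{{USER_NAME}}`/`{{USER_LOCATION}}`-style placeholders — the placeholder syntax here is an assumption, not taken from this diff:

```python
def prompt_template(template, user_name=None, user_location=None):
    # Hypothetical placeholder substitution; missing values are left blank.
    template = template.replace("{{USER_NAME}}", user_name or "")
    template = template.replace("{{USER_LOCATION}}", user_location or "")
    return template

print(prompt_template("Assist {{USER_NAME}}.", user_name="Ada"))  # Assist Ada.
```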
@@ -1094,17 +1110,13 @@ async def download_file_stream(
         raise "Ollama: Could not create blob, Please try again."

-# def number_generator():
-#     for i in range(1, 101):
-#         yield f"data: {i}\n"
 # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
 @app.post("/models/download")
 @app.post("/models/download/{url_idx}")
 async def download_model(
     form_data: UrlForm,
     url_idx: Optional[int] = None,
+    user=Depends(get_admin_user),
 ):
     allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
@@ -1133,7 +1145,11 @@ async def download_model(
 @app.post("/models/upload")
 @app.post("/models/upload/{url_idx}")
-def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
+def upload_model(
+    file: UploadFile = File(...),
+    url_idx: Optional[int] = None,
+    user=Depends(get_admin_user),
+):
     if url_idx == None:
         url_idx = 0
     ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@@ -1196,137 +1212,3 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
         yield f"data: {json.dumps(res)}\n\n"

     return StreamingResponse(file_process_stream(), media_type="text/event-stream")
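The upload endpoint streams progress as server-sent events: each JSON status object is prefixed with `data: ` and terminated by a blank line, matching the `media_type="text/event-stream"` response. A minimal sketch of that framing:

```python
import json

def sse_event(payload):
    # One SSE frame: the "data: <json>" line plus a blank line
    # terminates the event, as in the yield statements above.
    return f"data: {json.dumps(payload)}\n\n"

print(sse_event({"done": True}), end="")
```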
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None:
# url_idx = 0
# url = app.state.config.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size
# async def file_upload_generator(file):
# print(file)
# try:
# async with aiofiles.open(file_location, "wb") as f:
# completed_size = 0
# while True:
# chunk = await file.read(1024*1024)
# if not chunk:
# break
# await f.write(chunk)
# completed_size += len(chunk)
# progress = (completed_size / total_size) * 100
# print(progress)
# yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
# except Exception as e:
# print(e)
# yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
# finally:
# await file.close()
# print("done")
# yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
# return StreamingResponse(
# file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
# )
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def deprecated_proxy(
path: str, request: Request, user=Depends(get_verified_user)
):
url = app.state.config.OLLAMA_BASE_URLS[0]
target_url = f"{url}/{path}"
body = await request.body()
headers = dict(request.headers)
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
headers.pop("host", None)
headers.pop("authorization", None)
headers.pop("origin", None)
headers.pop("referer", None)
r = None
def get_request():
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if path == "generate":
data = json.loads(body.decode("utf-8"))
if data.get("stream", True):
yield json.dumps({"id": request_id, "done": False}) + "\n"
elif path == "chat":
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# r.close()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
@@ -16,10 +16,12 @@ from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
     decode_token,
-    get_current_user,
+    get_verified_user,
     get_verified_user,
     get_admin_user,
 )
+from utils.task import prompt_template

 from config import (
     SRC_LOG_LEVELS,
     ENABLE_OPENAI_API,
@@ -294,7 +296,7 @@ async def get_all_models(raw: bool = False):
 @app.get("/models")
 @app.get("/models/{url_idx}")
-async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
+async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
     if url_idx == None:
         models = await get_all_models()
         if app.state.config.ENABLE_MODEL_FILTER:
@@ -392,22 +394,34 @@ async def generate_chat_completion(
         else None
     )

-    if model_info.params.get("system", None):
+    system = model_info.params.get("system", None)
+    if system:
+        system = prompt_template(
+            system,
+            **(
+                {
+                    "user_name": user.name,
+                    "user_location": (
+                        user.info.get("location") if user.info else None
+                    ),
+                }
+                if user
+                else {}
+            ),
+        )
         # Check if the payload already has a system message
         # If not, add a system message to the payload
         if payload.get("messages"):
             for message in payload["messages"]:
                 if message.get("role") == "system":
-                    message["content"] = (
-                        model_info.params.get("system", None) + message["content"]
-                    )
+                    message["content"] = system + message["content"]
                     break
             else:
                 payload["messages"].insert(
                     0,
                     {
                         "role": "system",
-                        "content": model_info.params.get("system", None),
+                        "content": system,
                     },
                 )
@@ -418,7 +432,12 @@ async def generate_chat_completion(
     idx = model["urlIdx"]

     if "pipeline" in model and model.get("pipeline"):
-        payload["user"] = {"name": user.name, "id": user.id}
+        payload["user"] = {
+            "name": user.name,
+            "id": user.id,
+            "email": user.email,
+            "role": user.role,
+        }

     # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
     # This is a workaround until OpenAI fixes the issue with this model
......
@@ -55,6 +55,9 @@ from apps.webui.models.documents import (
     DocumentForm,
     DocumentResponse,
 )
+from apps.webui.models.files import (
+    Files,
+)
 from apps.rag.utils import (
     get_model_path,
@@ -74,6 +77,7 @@ from apps.rag.search.serpstack import search_serpstack
 from apps.rag.search.serply import search_serply
 from apps.rag.search.duckduckgo import search_duckduckgo
 from apps.rag.search.tavily import search_tavily
+from apps.rag.search.jina_search import search_jina

 from utils.misc import (
     calculate_sha256,
@@ -81,7 +85,7 @@ from utils.misc import (
     sanitize_filename,
     extract_folders_after_data_docs,
 )
-from utils.utils import get_current_user, get_admin_user
+from utils.utils import get_verified_user, get_admin_user

 from config import (
     AppConfig,
@@ -112,6 +116,7 @@ from config import (
     YOUTUBE_LOADER_LANGUAGE,
     ENABLE_RAG_WEB_SEARCH,
     RAG_WEB_SEARCH_ENGINE,
+    RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
     SEARXNG_QUERY_URL,
     GOOGLE_PSE_API_KEY,
     GOOGLE_PSE_ENGINE_ID,
@@ -165,6 +170,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
 app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
 app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
+app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
 app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
 app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
@@ -523,7 +529,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
 @app.get("/template")
-async def get_rag_template(user=Depends(get_current_user)):
+async def get_rag_template(user=Depends(get_verified_user)):
     return {
         "status": True,
         "template": app.state.config.RAG_TEMPLATE,
@@ -580,7 +586,7 @@ class QueryDocForm(BaseModel):
 @app.post("/query/doc")
 def query_doc_handler(
     form_data: QueryDocForm,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
     try:
         if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
@@ -620,7 +626,7 @@ class QueryCollectionsForm(BaseModel):
 @app.post("/query/collection")
 def query_collection_handler(
     form_data: QueryCollectionsForm,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
     try:
         if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
...@@ -651,7 +657,7 @@ def query_collection_handler(
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)):
try:
loader = YoutubeLoader.from_youtube_url(
form_data.url,
...@@ -680,7 +686,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
@app.post("/web")
def store_web(form_data: UrlForm, user=Depends(get_verified_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = get_web_loader(
...@@ -775,6 +781,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SEARXNG_QUERY_URL,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SEARXNG_QUERY_URL found in environment variables")
...@@ -788,6 +795,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.GOOGLE_PSE_ENGINE_ID,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception(
...@@ -799,6 +807,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.BRAVE_SEARCH_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
...@@ -808,6 +817,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPSTACK_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
https_enabled=app.state.config.SERPSTACK_HTTPS,
)
else:
...@@ -818,6 +828,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPER_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SERPER_API_KEY found in environment variables")
...@@ -827,11 +838,16 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPLY_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo":
return search_duckduckgo(
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "tavily":
if app.state.config.TAVILY_API_KEY:
return search_tavily(
...@@ -841,12 +857,14 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
elif engine == "jina":
return search_jina(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
else:
raise Exception("No search engine API key found in environment variables")
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
try:
logging.info(
f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
...@@ -1066,7 +1084,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
def store_doc(
collection_name: Optional[str] = Form(None),
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
...@@ -1119,6 +1137,60 @@ def store_doc(
)
class ProcessDocForm(BaseModel):
file_id: str
collection_name: Optional[str] = None
@app.post("/process/doc")
def process_doc(
form_data: ProcessDocForm,
user=Depends(get_verified_user),
):
try:
file = Files.get_file_by_id(form_data.file_id)
file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")
f = open(file_path, "rb")
collection_name = form_data.collection_name
if collection_name is None:
collection_name = calculate_sha256(f)[:63]
f.close()
loader, known_type = get_loader(
file.filename, file.meta.get("content_type"), file_path
)
data = loader.load()
try:
result = store_data_in_vector_db(data, collection_name)
if result:
return {
"status": True,
"collection_name": collection_name,
"known_type": known_type,
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e,
)
except Exception as e:
log.exception(e)
if "No pandoc was found" in str(e):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
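When no collection name is supplied, `process_doc` above derives one from the file's SHA-256 digest truncated to 63 characters (which fits Chroma's collection-name length limit). A minimal standalone sketch, assuming the repo's `calculate_sha256` helper is a plain hex digest of the stream's contents:

```python
import hashlib
import io

def calculate_sha256(f) -> str:
    # Assumed behavior of the repo helper: hex digest of the stream's contents
    return hashlib.sha256(f.read()).hexdigest()

f = io.BytesIO(b"example file contents")
collection_name = calculate_sha256(f)[:63]
print(len(collection_name))  # 63: a sha256 hexdigest is 64 chars, truncated by one
```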
class TextRAGForm(BaseModel):
name: str
content: str
...@@ -1128,7 +1200,7 @@ class TextRAGForm(BaseModel):
@app.post("/text")
def store_text(
form_data: TextRAGForm,
user=Depends(get_verified_user),
):
collection_name = form_data.collection_name
...
import logging
from typing import List, Optional
import requests
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
Args:
...@@ -29,6 +31,9 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
json_response = response.json()
results = json_response.get("web", {}).get("results", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
...
import logging
from typing import List, Optional
from apps.rag.search.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS
...@@ -8,7 +8,9 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(
query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args:
...@@ -41,6 +43,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
snippet=result.get("body"),
)
)
if filter_list:
results = get_filtered_results(results, filter_list)
# Return the list of search results
return results
import json
import logging
from typing import List, Optional
import requests
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
...@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse(
api_key: str,
search_engine_id: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
...@@ -35,6 +39,8 @@ def search_google_pse(
json_response = response.json()
results = json_response.get("items", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
...
import logging
import requests
from yarl import URL
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_jina(query: str, count: int) -> list[SearchResult]:
"""
Search using Jina's Search API and return the results as a list of SearchResult objects.
Args:
query (str): The query to search for
count (int): The number of results to return
Returns:
List[SearchResult]: A list of search results
"""
jina_search_endpoint = "https://s.jina.ai/"
headers = {
"Accept": "application/json",
}
url = str(URL(jina_search_endpoint + query))
response = requests.get(url, headers=headers)
response.raise_for_status()
data = response.json()
results = []
for result in data["data"][:count]:
results.append(
SearchResult(
link=result["url"],
title=result.get("title"),
snippet=result.get("content"),
)
)
return results
from typing import Optional
from urllib.parse import urlparse
from pydantic import BaseModel
def get_filtered_results(results, filter_list):
if not filter_list:
return results
filtered_results = []
for result in results:
domain = urlparse(result["url"]).netloc
if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
filtered_results.append(result)
return filtered_results
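The domain filter above can be exercised in isolation (the function is copied here so the sketch is self-contained; the sample results are made up). Note that the suffix match via `endswith` admits subdomains by design, but it also admits any string suffix, so `python.org` in the filter list would match a host named `notpython.org`:

```python
from urllib.parse import urlparse

def get_filtered_results(results, filter_list):
    # Keep only results whose host ends with one of the allowed domain suffixes
    if not filter_list:
        return results
    filtered_results = []
    for result in results:
        domain = urlparse(result["url"]).netloc
        if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
            filtered_results.append(result)
    return filtered_results

results = [
    {"url": "https://docs.python.org/3/tutorial/"},
    {"url": "https://example.com/article"},
]
print(get_filtered_results(results, ["python.org"]))
# Only the docs.python.org entry survives the filter
```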
class SearchResult(BaseModel):
link: str
title: Optional[str]
...
import logging
import requests
from typing import List, Optional
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
...@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_searxng(
query_url: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
**kwargs,
) -> List[SearchResult]:
"""
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
...@@ -78,6 +82,8 @@ def search_searxng(
json_response = response.json()
results = json_response.get("results", [])
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content")
...
import json
import logging
from typing import List, Optional
import requests
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serper(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
...@@ -29,6 +31,8 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
results = sorted(
json_response.get("organic", []), key=lambda x: x.get("position", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
...
import json
import logging
from typing import List, Optional
import requests
from urllib.parse import urlencode
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
...@@ -19,6 +19,7 @@ def search_serply(
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US",
filter_list: Optional[List[str]] = None,
) -> list[SearchResult]:
"""Search using serply.io's API and return the results as a list of SearchResult objects.
...@@ -57,7 +58,8 @@ def search_serply(
results = sorted(
json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
...
import json
import logging
from typing import List, Optional
import requests
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
...@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serpstack(
api_key: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
https_enabled: bool = True,
) -> list[SearchResult]:
"""Search using serpstack.com's API and return the results as a list of SearchResult objects.
...@@ -35,6 +39,8 @@ def search_serpstack(
results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
...
...@@ -237,7 +237,7 @@ def get_embedding_function(
def get_rag_context(
files,
messages,
embedding_function,
k,
...@@ -245,29 +245,29 @@ def get_rag_context(
r,
hybrid_search,
):
log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
query = get_last_user_message(messages)
extracted_collections = []
relevant_contexts = []
for file in files:
context = None
collection_names = (
file["collection_names"]
if file["type"] == "collection"
else [file["collection_name"]]
)
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:
log.debug(f"skipping {file} as it has already been extracted")
continue
try:
if file["type"] == "text":
context = file["content"]
else:
if hybrid_search:
context = query_collection_with_hybrid_search(
...@@ -290,7 +290,7 @@ def get_rag_context(
context = None
if context:
relevant_contexts.append({**context, "source": file})
extracted_collections.extend(collection_names)
...
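The collection deduplication in `get_rag_context` can be illustrated standalone: each file contributes its collection names, and a set difference against the already-extracted names skips collections an earlier file covered. The file dictionaries below are hypothetical, mimicking only the shape the handler reads:

```python
# Hypothetical file descriptors in the shape get_rag_context expects
files = [
    {"type": "collection", "collection_names": ["a", "b"]},
    {"type": "doc", "collection_name": "b"},
]

extracted_collections = []
to_query = []
for file in files:
    collection_names = (
        file["collection_names"]
        if file["type"] == "collection"
        else [file["collection_name"]]
    )
    # Skip collections an earlier file already contributed
    collection_names = set(collection_names).difference(extracted_collections)
    if not collection_names:
        continue
    to_query.append(sorted(collection_names))
    extracted_collections.extend(collection_names)

print(to_query)  # the second file's "b" is skipped as already extracted
```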
import os
import logging
import json
from peewee import *
from peewee_migrate import Router
from apps.webui.internal.wrappers import register_connection
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
...@@ -28,12 +29,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
else:
pass
# The `register_connection` function encapsulates the logic for setting up
# the database connection based on the connection string, while `connect`
# is a Peewee-specific method to manage the connection state and avoid errors
# when a connection is already open.
try:
DB = register_connection(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
raise
router = Router(
DB,
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
logger=log,
)
router.run()
try:
DB.connect(reuse_if_open=True)
except OperationalError as e:
log.info(f"Failed to connect to database again due to: {e}")
pass
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class File(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
filename = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "file"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("file")
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Function(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
type = pw.TextField()
content = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "function"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("function")