Commit 4aab4609 authored by Jun Siang Cheah's avatar Jun Siang Cheah
Browse files

Merge remote-tracking branch 'upstream/dev' into feat/oauth

parents 4ff17acc a2ea6b1b
......@@ -11,8 +11,6 @@ on:
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}
FULL_IMAGE_NAME: ghcr.io/${{ github.repository }}
jobs:
build-main-image:
......@@ -28,6 +26,15 @@ jobs:
- linux/arm64
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Prepare
run: |
platform=${{ matrix.platform }}
......@@ -116,6 +123,15 @@ jobs:
- linux/arm64
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Prepare
run: |
platform=${{ matrix.platform }}
......@@ -207,6 +223,15 @@ jobs:
- linux/arm64
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Prepare
run: |
platform=${{ matrix.platform }}
......@@ -289,6 +314,15 @@ jobs:
runs-on: ubuntu-latest
needs: [ build-main-image ]
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v4
with:
......@@ -335,6 +369,15 @@ jobs:
runs-on: ubuntu-latest
needs: [ build-cuda-image ]
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v4
with:
......@@ -382,6 +425,15 @@ jobs:
runs-on: ubuntu-latest
needs: [ build-ollama-image ]
steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests
uses: actions/download-artifact@v4
with:
......
......@@ -25,7 +25,7 @@ jobs:
--file docker-compose.api.yaml \
--file docker-compose.a1111-test.yaml \
up --detach --build
- name: Wait for Ollama to be up
timeout-minutes: 5
run: |
......@@ -43,7 +43,7 @@ jobs:
uses: cypress-io/github-action@v6
with:
browser: chrome
wait-on: 'http://localhost:3000'
wait-on: "http://localhost:3000"
config: baseUrl=http://localhost:3000
- uses: actions/upload-artifact@v4
......@@ -82,18 +82,18 @@ jobs:
--health-retries 5
ports:
- 5432:5432
# mysql:
# image: mysql
# env:
# MYSQL_ROOT_PASSWORD: mysql
# MYSQL_DATABASE: mysql
# options: >-
# --health-cmd "mysqladmin ping -h localhost"
# --health-interval 10s
# --health-timeout 5s
# --health-retries 5
# ports:
# - 3306:3306
# mysql:
# image: mysql
# env:
# MYSQL_ROOT_PASSWORD: mysql
# MYSQL_DATABASE: mysql
# options: >-
# --health-cmd "mysqladmin ping -h localhost"
# --health-interval 10s
# --health-timeout 5s
# --health-retries 5
# ports:
# - 3306:3306
steps:
- name: Checkout Repository
uses: actions/checkout@v4
......@@ -142,7 +142,6 @@ jobs:
echo "Server has stopped"
exit 1
fi
- name: Test backend with Postgres
if: success() || steps.sqlite.conclusion == 'failure'
......@@ -171,6 +170,25 @@ jobs:
exit 1
fi
# Check that the service reconnects to Postgres after its connections have been closed
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has failed before postgres reconnect check"
exit 1
fi
echo "Terminating all connections to postgres..."
python -c "import os, psycopg2 as pg2; \
conn = pg2.connect(dsn=os.environ['DATABASE_URL'].replace('+pool', '')); \
cur = conn.cursor(); \
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
exit 1
fi
# - name: Test backend with MySQL
# if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
# env:
......
......@@ -5,6 +5,30 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.5] - 2024-06-16
### Added
- **📞 Enhanced Voice Call**: Text-to-speech (TTS) callback now operates in real-time for each sentence, reducing latency by not waiting for full completion.
- **👆 Tap to Interrupt**: During a call, you can now stop the assistant from speaking by simply tapping, instead of using voice. This resolves the issue of the speaker's voice being mistakenly registered as input.
- **😊 Emoji Call**: Toggle this feature on from the Settings > Interface, allowing LLMs to express emotions using emojis during voice calls for a more dynamic interaction.
- **🖱️ Quick Archive/Delete**: Use the Shift key + mouseover on the chat list to swiftly archive or delete items.
- **📝 Markdown Support in Model Descriptions**: You can now format model descriptions with markdown, enabling bold text, links, etc.
- **🧠 Editable Memories**: Adds the capability to modify memories.
- **📋 Admin Panel Sorting**: Introduces the ability to sort users/chats within the admin panel.
- **🌑 Dark Mode for Quick Selectors**: Dark mode now available for chat quick selectors (prompts, models, documents).
- **🔧 Advanced Parameters**: Adds 'num_keep' and 'num_batch' to advanced parameters for customization.
- **📅 Dynamic System Prompts**: New variables '{{CURRENT_DATETIME}}', '{{CURRENT_TIME}}', '{{USER_LOCATION}}' added for system prompts. Ensure '{{USER_LOCATION}}' is toggled on from Settings > Interface.
- **🌐 Tavily Web Search**: Includes Tavily as a web search provider option.
- **🖊️ Federated Auth Usernames**: Ability to set user names for federated authentication.
- **🔗 Auto Clean URLs**: When adding connection URLs, trailing slashes are now automatically removed.
- **🌐 Enhanced Translations**: Improved Chinese and Swedish translations.
### Fixed
- **⏳ AIOHTTP_CLIENT_TIMEOUT**: Introduced a new environment variable 'AIOHTTP_CLIENT_TIMEOUT' for requests to Ollama lasting longer than 5 minutes. Default is 300 seconds; set to blank ('') for no timeout.
- **❌ Message Delete Freeze**: Resolved an issue where message deletion would sometimes cause the web UI to freeze.
## [0.3.4] - 2024-06-12
### Fixed
......
......@@ -160,7 +160,7 @@ Check our Migration Guide available in our [Open WebUI Documentation](https://do
If you want to try out the latest bleeding-edge features and are okay with occasional instability, you can use the `:dev` tag like this:
```bash
docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:dev
docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --add-host=host.docker.internal:host-gateway --restart always ghcr.io/open-webui/open-webui:dev
```
## What's Next? 🌟
......
......@@ -37,6 +37,10 @@ from config import (
ENABLE_IMAGE_GENERATION,
AUTOMATIC1111_BASE_URL,
COMFYUI_BASE_URL,
COMFYUI_CFG_SCALE,
COMFYUI_SAMPLER,
COMFYUI_SCHEDULER,
COMFYUI_SD3,
IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_KEY,
IMAGE_GENERATION_MODEL,
......@@ -78,6 +82,10 @@ app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS
app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
app.state.config.COMFYUI_SD3 = COMFYUI_SD3
@app.get("/config")
......@@ -457,6 +465,18 @@ def generate_image(
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
if app.state.config.COMFYUI_CFG_SCALE:
data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
if app.state.config.COMFYUI_SAMPLER is not None:
data["sampler"] = app.state.config.COMFYUI_SAMPLER
if app.state.config.COMFYUI_SCHEDULER is not None:
data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
if app.state.config.COMFYUI_SD3 is not None:
data["sd3"] = app.state.config.COMFYUI_SD3
data = ImageGenerationPayload(**data)
res = comfyui_generate_image(
......
......@@ -190,6 +190,10 @@ class ImageGenerationPayload(BaseModel):
width: int
height: int
n: int = 1
cfg_scale: Optional[float] = None
sampler: Optional[str] = None
scheduler: Optional[str] = None
sd3: Optional[bool] = None
def comfyui_generate_image(
......@@ -199,6 +203,18 @@ def comfyui_generate_image(
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
if payload.cfg_scale:
comfyui_prompt["3"]["inputs"]["cfg"] = payload.cfg_scale
if payload.sampler:
comfyui_prompt["3"]["inputs"]["sampler"] = payload.sampler
if payload.scheduler:
comfyui_prompt["3"]["inputs"]["scheduler"] = payload.scheduler
if payload.sd3:
comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage"
comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
comfyui_prompt["5"]["inputs"]["width"] = payload.width
......
......@@ -40,6 +40,7 @@ from utils.utils import (
get_verified_user,
get_admin_user,
)
from utils.task import prompt_template
from config import (
......@@ -52,7 +53,7 @@ from config import (
UPLOAD_DIR,
AppConfig,
)
from utils.misc import calculate_sha256
from utils.misc import calculate_sha256, add_or_update_system_message
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
......@@ -199,9 +200,6 @@ def merge_models_lists(model_lists):
return list(merged_models.values())
# user=Depends(get_current_user)
async def get_all_models():
log.info("get_all_models()")
......@@ -817,24 +815,28 @@ async def generate_chat_completion(
"num_thread", None
)
if model_info.params.get("system", None):
system = model_info.params.get("system", None)
if system:
# Check if the payload already has a system message
# If not, add a system message to the payload
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
},
)
payload["messages"] = add_or_update_system_message(
system, payload["messages"]
)
if url_idx == None:
if ":" not in payload["model"]:
......@@ -850,8 +852,7 @@ async def generate_chat_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
print(payload)
log.debug(payload)
return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))
......@@ -879,10 +880,11 @@ class OpenAIChatCompletionForm(BaseModel):
@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion(
form_data: OpenAIChatCompletionForm,
form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
form_data = OpenAIChatCompletionForm(**form_data)
payload = {
**form_data.model_dump(exclude_none=True),
......@@ -914,22 +916,35 @@ async def generate_openai_chat_completion(
else None
)
if model_info.params.get("system", None):
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
"content": system,
},
)
......@@ -1095,17 +1110,13 @@ async def download_file_stream(
raise "Ollama: Could not create blob, Please try again."
# def number_generator():
# for i in range(1, 101):
# yield f"data: {i}\n"
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@app.post("/models/download")
@app.post("/models/download/{url_idx}")
async def download_model(
form_data: UrlForm,
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
......@@ -1134,7 +1145,11 @@ async def download_model(
@app.post("/models/upload")
@app.post("/models/upload/{url_idx}")
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
def upload_model(
file: UploadFile = File(...),
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
if url_idx == None:
url_idx = 0
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
......@@ -1197,137 +1212,3 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None:
# url_idx = 0
# url = app.state.config.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size
# async def file_upload_generator(file):
# print(file)
# try:
# async with aiofiles.open(file_location, "wb") as f:
# completed_size = 0
# while True:
# chunk = await file.read(1024*1024)
# if not chunk:
# break
# await f.write(chunk)
# completed_size += len(chunk)
# progress = (completed_size / total_size) * 100
# print(progress)
# yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
# except Exception as e:
# print(e)
# yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
# finally:
# await file.close()
# print("done")
# yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
# return StreamingResponse(
# file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
# )
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def deprecated_proxy(
path: str, request: Request, user=Depends(get_verified_user)
):
url = app.state.config.OLLAMA_BASE_URLS[0]
target_url = f"{url}/{path}"
body = await request.body()
headers = dict(request.headers)
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
headers.pop("host", None)
headers.pop("authorization", None)
headers.pop("origin", None)
headers.pop("referer", None)
r = None
def get_request():
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if path == "generate":
data = json.loads(body.decode("utf-8"))
if data.get("stream", True):
yield json.dumps({"id": request_id, "done": False}) + "\n"
elif path == "chat":
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# r.close()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
......@@ -20,6 +20,8 @@ from utils.utils import (
get_verified_user,
get_admin_user,
)
from utils.task import prompt_template
from config import (
SRC_LOG_LEVELS,
ENABLE_OPENAI_API,
......@@ -392,22 +394,34 @@ async def generate_chat_completion(
else None
)
if model_info.params.get("system", None):
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
"content": system,
},
)
......@@ -418,7 +432,12 @@ async def generate_chat_completion(
idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"):
payload["user"] = {"name": user.name, "id": user.id}
payload["user"] = {
"name": user.name,
"id": user.id,
"email": user.email,
"role": user.role,
}
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model
......@@ -430,13 +449,11 @@ async def generate_chat_completion(
# Convert the modified body back to JSON
payload = json.dumps(payload)
print(payload)
log.debug(payload)
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
print(payload)
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
......
......@@ -55,6 +55,9 @@ from apps.webui.models.documents import (
DocumentForm,
DocumentResponse,
)
from apps.webui.models.files import (
Files,
)
from apps.rag.utils import (
get_model_path,
......@@ -112,6 +115,7 @@ from config import (
YOUTUBE_LOADER_LANGUAGE,
ENABLE_RAG_WEB_SEARCH,
RAG_WEB_SEARCH_ENGINE,
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
SEARXNG_QUERY_URL,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
......@@ -165,6 +169,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
......@@ -775,6 +780,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SEARXNG_QUERY_URL,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SEARXNG_QUERY_URL found in environment variables")
......@@ -788,6 +794,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.GOOGLE_PSE_ENGINE_ID,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception(
......@@ -799,6 +806,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.BRAVE_SEARCH_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
......@@ -808,6 +816,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPSTACK_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
https_enabled=app.state.config.SERPSTACK_HTTPS,
)
else:
......@@ -818,6 +827,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPER_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SERPER_API_KEY found in environment variables")
......@@ -827,11 +837,16 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPLY_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
else:
raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo":
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
return search_duckduckgo(
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "tavily":
if app.state.config.TAVILY_API_KEY:
return search_tavily(
......@@ -1119,6 +1134,60 @@ def store_doc(
)
# Request body accepted by the POST /process/doc endpoint.
class ProcessDocForm(BaseModel):
# ID of the stored file record; the handler resolves it with Files.get_file_by_id.
file_id: str
# Target collection name. When omitted (None) the handler derives one from
# calculate_sha256 of the file contents, truncated to 63 characters.
collection_name: Optional[str] = None
@app.post("/process/doc")
def process_doc(
    form_data: ProcessDocForm,
    user=Depends(get_current_user),
):
    """Load a stored file and index its contents in the vector database.

    Args:
        form_data: Carries the ``file_id`` of the stored file and an optional
            ``collection_name``. When no name is given, the first 63 characters
            of ``calculate_sha256`` over the file are used instead.
        user: Resolved through ``Depends(get_current_user)``; not read in the
            body (presumably only there to require an authenticated caller).

    Returns:
        A dict with ``status``, ``collection_name`` and ``known_type`` when
        ``store_data_in_vector_db`` reports success. If it returns a falsy
        value the function falls through and returns ``None``.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.PANDOC_NOT_INSTALLED`` when the
            failure message mentions a missing pandoc, otherwise 400 with
            ``ERROR_MESSAGES.DEFAULT(e)``.
    """
    try:
        file = Files.get_file_by_id(form_data.file_id)
        file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")

        collection_name = form_data.collection_name
        # The file is opened even when a collection name was supplied, so a
        # missing upload still fails here. The context manager guarantees the
        # handle is released even if hashing raises.
        with open(file_path, "rb") as f:
            if collection_name is None:
                collection_name = calculate_sha256(f)[:63]

        loader, known_type = get_loader(
            file.filename, file.meta.get("content_type"), file_path
        )
        data = loader.load()

        try:
            result = store_data_in_vector_db(data, collection_name)

            if result:
                return {
                    "status": True,
                    "collection_name": collection_name,
                    "known_type": known_type,
                }
        except Exception as e:
            # NOTE(review): this HTTPException is raised inside the outer
            # ``try`` and is therefore caught by the handler below, which
            # reports it as a 400 rather than a 500.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                # ``detail`` must be serialisable; pass the message, not the
                # exception object.
                detail=str(e),
            )
    except Exception as e:
        log.exception(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
class TextRAGForm(BaseModel):
name: str
content: str
......
import logging
from typing import List, Optional
import requests
from apps.rag.search.main import SearchResult
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
def search_brave(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
Args:
......@@ -29,6 +31,9 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
json_response = response.json()
results = json_response.get("web", {}).get("results", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
......
import logging
from apps.rag.search.main import SearchResult
from typing import List, Optional
from apps.rag.search.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS
......@@ -8,7 +8,9 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
def search_duckduckgo(
query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args:
......@@ -41,6 +43,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
snippet=result.get("body"),
)
)
print(results)
if filter_list:
results = get_filtered_results(results, filter_list)
# Return the list of search results
return results
import json
import logging
from typing import List, Optional
import requests
from apps.rag.search.main import SearchResult
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
......@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse(
api_key: str, search_engine_id: str, query: str, count: int
api_key: str,
search_engine_id: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
......@@ -35,6 +39,8 @@ def search_google_pse(
json_response = response.json()
results = json_response.get("items", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
......
from typing import Optional
from urllib.parse import urlparse
from pydantic import BaseModel
def get_filtered_results(results, filter_list):
    """Keep only the search results whose host matches an allowed domain.

    Search providers do not agree on where the address lives: some raw
    results use the ``"url"`` key, others use ``"link"``, and some callers
    pass already-built result objects exposing a ``link`` attribute. All of
    these shapes are accepted.

    Args:
        results: Iterable of search results (mappings or objects, see above).
        filter_list: Domain suffixes to allow. When empty or ``None`` no
            filtering happens.

    Returns:
        ``results`` itself when no filter is configured; otherwise a new list
        holding, in their original order, the items whose URL host ends with
        one of the entries in ``filter_list``. Items without any URL are
        dropped because they cannot match a domain.
    """
    if not filter_list:
        return results

    # str.endswith accepts a tuple of suffixes, so build it once up front.
    allowed_suffixes = tuple(filter_list)

    filtered_results = []
    for result in results:
        if isinstance(result, dict):
            url = result.get("url") or result.get("link")
        else:
            url = getattr(result, "link", None) or getattr(result, "url", None)
        if not url:
            continue

        domain = urlparse(url).netloc
        if domain.endswith(allowed_suffixes):
            filtered_results.append(result)
    return filtered_results
class SearchResult(BaseModel):
link: str
title: Optional[str]
......
import logging
import requests
from typing import List
from typing import List, Optional
from apps.rag.search.main import SearchResult
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
......@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_searxng(
query_url: str, query: str, count: int, **kwargs
query_url: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
**kwargs,
) -> List[SearchResult]:
"""
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
......@@ -78,6 +82,8 @@ def search_searxng(
json_response = response.json()
results = json_response.get("results", [])
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content")
......
import json
import logging
from typing import List, Optional
import requests
from apps.rag.search.main import SearchResult
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
def search_serper(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
......@@ -29,6 +31,8 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
results = sorted(
json_response.get("organic", []), key=lambda x: x.get("position", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
......
import json
import logging
from typing import List, Optional
import requests
from urllib.parse import urlencode
from apps.rag.search.main import SearchResult
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
......@@ -19,6 +19,7 @@ def search_serply(
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US",
filter_list: Optional[List[str]] = None,
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
......@@ -57,7 +58,8 @@ def search_serply(
results = sorted(
json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
......
import json
import logging
from typing import List, Optional
import requests
from apps.rag.search.main import SearchResult
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
......@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serpstack(
api_key: str, query: str, count: int, https_enabled: bool = True
api_key: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
https_enabled: bool = True,
) -> list[SearchResult]:
"""Search using serpstack.com's and return the results as a list of SearchResult objects.
......@@ -35,6 +39,8 @@ def search_serpstack(
results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
......
......@@ -237,7 +237,7 @@ def get_embedding_function(
def get_rag_context(
docs,
files,
messages,
embedding_function,
k,
......@@ -245,29 +245,29 @@ def get_rag_context(
r,
hybrid_search,
):
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
query = get_last_user_message(messages)
extracted_collections = []
relevant_contexts = []
for doc in docs:
for file in files:
context = None
collection_names = (
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
file["collection_names"]
if file["type"] == "collection"
else [file["collection_name"]]
)
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:
log.debug(f"skipping {doc} as it has already been extracted")
log.debug(f"skipping {file} as it has already been extracted")
continue
try:
if doc["type"] == "text":
context = doc["content"]
if file["type"] == "text":
context = file["content"]
else:
if hybrid_search:
context = query_collection_with_hybrid_search(
......@@ -290,7 +290,7 @@ def get_rag_context(
context = None
if context:
relevant_contexts.append({**context, "source": doc})
relevant_contexts.append({**context, "source": file})
extracted_collections.extend(collection_names)
......
import os
import logging
import json
from peewee import *
from peewee_migrate import Router
from playhouse.db_url import connect
from apps.webui.internal.wrappers import register_connection
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
import os
import logging
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
......@@ -28,12 +29,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
else:
pass
DB = connect(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
# The `register_connection` function encapsulates the logic for setting up
# the database connection based on the connection string, while `connect`
# is a Peewee-specific method to manage the connection state and avoid errors
# when a connection is already open.
try:
DB = register_connection(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
raise
router = Router(
DB,
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
logger=log,
)
router.run()
DB.connect(reuse_if_open=True)
try:
DB.connect(reuse_if_open=True)
except OperationalError as e:
log.info(f"Failed to connect to database again due to: {e}")
pass
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Add a nullable text column ``info`` to the ``user`` table.

``database`` and ``fake`` are accepted but not used by this migration.
"""
# Add the nullable 'info' text field to the 'user' table
migrator.add_fields("user", info=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Undo ``migrate`` by removing the ``info`` field from the ``user`` table.

``database`` and ``fake`` are accepted but not used by this rollback.
"""
# Remove the 'info' field from the 'user' table
migrator.remove_fields("user", "info")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment