Commit 4aab4609 authored by Jun Siang Cheah's avatar Jun Siang Cheah
Browse files

Merge remote-tracking branch 'upstream/dev' into feat/oauth

parents 4ff17acc a2ea6b1b
...@@ -11,8 +11,6 @@ on: ...@@ -11,8 +11,6 @@ on:
env: env:
REGISTRY: ghcr.io REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}
FULL_IMAGE_NAME: ghcr.io/${{ github.repository }}
jobs: jobs:
build-main-image: build-main-image:
...@@ -28,6 +26,15 @@ jobs: ...@@ -28,6 +26,15 @@ jobs:
- linux/arm64 - linux/arm64
steps: steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Prepare - name: Prepare
run: | run: |
platform=${{ matrix.platform }} platform=${{ matrix.platform }}
...@@ -116,6 +123,15 @@ jobs: ...@@ -116,6 +123,15 @@ jobs:
- linux/arm64 - linux/arm64
steps: steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Prepare - name: Prepare
run: | run: |
platform=${{ matrix.platform }} platform=${{ matrix.platform }}
...@@ -207,6 +223,15 @@ jobs: ...@@ -207,6 +223,15 @@ jobs:
- linux/arm64 - linux/arm64
steps: steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Prepare - name: Prepare
run: | run: |
platform=${{ matrix.platform }} platform=${{ matrix.platform }}
...@@ -289,6 +314,15 @@ jobs: ...@@ -289,6 +314,15 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [ build-main-image ] needs: [ build-main-image ]
steps: steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests - name: Download digests
uses: actions/download-artifact@v4 uses: actions/download-artifact@v4
with: with:
...@@ -335,6 +369,15 @@ jobs: ...@@ -335,6 +369,15 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [ build-cuda-image ] needs: [ build-cuda-image ]
steps: steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests - name: Download digests
uses: actions/download-artifact@v4 uses: actions/download-artifact@v4
with: with:
...@@ -382,6 +425,15 @@ jobs: ...@@ -382,6 +425,15 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [ build-ollama-image ] needs: [ build-ollama-image ]
steps: steps:
# GitHub Packages requires the entire repository name to be in lowercase
# although the repository owner has a lowercase username, this prevents some people from running actions after forking
- name: Set repository and image name to lowercase
run: |
echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV}
echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV}
env:
IMAGE_NAME: '${{ github.repository }}'
- name: Download digests - name: Download digests
uses: actions/download-artifact@v4 uses: actions/download-artifact@v4
with: with:
......
...@@ -25,7 +25,7 @@ jobs: ...@@ -25,7 +25,7 @@ jobs:
--file docker-compose.api.yaml \ --file docker-compose.api.yaml \
--file docker-compose.a1111-test.yaml \ --file docker-compose.a1111-test.yaml \
up --detach --build up --detach --build
- name: Wait for Ollama to be up - name: Wait for Ollama to be up
timeout-minutes: 5 timeout-minutes: 5
run: | run: |
...@@ -43,7 +43,7 @@ jobs: ...@@ -43,7 +43,7 @@ jobs:
uses: cypress-io/github-action@v6 uses: cypress-io/github-action@v6
with: with:
browser: chrome browser: chrome
wait-on: 'http://localhost:3000' wait-on: "http://localhost:3000"
config: baseUrl=http://localhost:3000 config: baseUrl=http://localhost:3000
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
...@@ -82,18 +82,18 @@ jobs: ...@@ -82,18 +82,18 @@ jobs:
--health-retries 5 --health-retries 5
ports: ports:
- 5432:5432 - 5432:5432
# mysql: # mysql:
# image: mysql # image: mysql
# env: # env:
# MYSQL_ROOT_PASSWORD: mysql # MYSQL_ROOT_PASSWORD: mysql
# MYSQL_DATABASE: mysql # MYSQL_DATABASE: mysql
# options: >- # options: >-
# --health-cmd "mysqladmin ping -h localhost" # --health-cmd "mysqladmin ping -h localhost"
# --health-interval 10s # --health-interval 10s
# --health-timeout 5s # --health-timeout 5s
# --health-retries 5 # --health-retries 5
# ports: # ports:
# - 3306:3306 # - 3306:3306
steps: steps:
- name: Checkout Repository - name: Checkout Repository
uses: actions/checkout@v4 uses: actions/checkout@v4
...@@ -142,7 +142,6 @@ jobs: ...@@ -142,7 +142,6 @@ jobs:
echo "Server has stopped" echo "Server has stopped"
exit 1 exit 1
fi fi
- name: Test backend with Postgres - name: Test backend with Postgres
if: success() || steps.sqlite.conclusion == 'failure' if: success() || steps.sqlite.conclusion == 'failure'
...@@ -171,6 +170,25 @@ jobs: ...@@ -171,6 +170,25 @@ jobs:
exit 1 exit 1
fi fi
# Check that service will reconnect to postgres when connection will be closed
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has failed before postgres reconnect check"
exit 1
fi
echo "Terminating all connections to postgres..."
python -c "import os, psycopg2 as pg2; \
conn = pg2.connect(dsn=os.environ['DATABASE_URL'].replace('+pool', '')); \
cur = conn.cursor(); \
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
exit 1
fi
# - name: Test backend with MySQL # - name: Test backend with MySQL
# if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure' # if: success() || steps.sqlite.conclusion == 'failure' || steps.postgres.conclusion == 'failure'
# env: # env:
......
...@@ -5,6 +5,30 @@ All notable changes to this project will be documented in this file. ...@@ -5,6 +5,30 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.5] - 2024-06-16
### Added
- **📞 Enhanced Voice Call**: Text-to-speech (TTS) callback now operates in real-time for each sentence, reducing latency by not waiting for full completion.
- **👆 Tap to Interrupt**: During a call, you can now stop the assistant from speaking by simply tapping, instead of using voice. This resolves the issue of the speaker's voice being mistakenly registered as input.
- **😊 Emoji Call**: Toggle this feature on from the Settings > Interface, allowing LLMs to express emotions using emojis during voice calls for a more dynamic interaction.
- **🖱️ Quick Archive/Delete**: Use the Shift key + mouseover on the chat list to swiftly archive or delete items.
- **📝 Markdown Support in Model Descriptions**: You can now format model descriptions with markdown, enabling bold text, links, etc.
- **🧠 Editable Memories**: Adds the capability to modify memories.
- **📋 Admin Panel Sorting**: Introduces the ability to sort users/chats within the admin panel.
- **🌑 Dark Mode for Quick Selectors**: Dark mode now available for chat quick selectors (prompts, models, documents).
- **🔧 Advanced Parameters**: Adds 'num_keep' and 'num_batch' to advanced parameters for customization.
- **📅 Dynamic System Prompts**: New variables '{{CURRENT_DATETIME}}', '{{CURRENT_TIME}}', '{{USER_LOCATION}}' added for system prompts. Ensure '{{USER_LOCATION}}' is toggled on from Settings > Interface.
- **🌐 Tavily Web Search**: Includes Tavily as a web search provider option.
- **🖊️ Federated Auth Usernames**: Ability to set user names for federated authentication.
- **🔗 Auto Clean URLs**: When adding connection URLs, trailing slashes are now automatically removed.
- **🌐 Enhanced Translations**: Improved Chinese and Swedish translations.
### Fixed
- **⏳ AIOHTTP_CLIENT_TIMEOUT**: Introduced a new environment variable 'AIOHTTP_CLIENT_TIMEOUT' for requests to Ollama lasting longer than 5 minutes. Default is 300 seconds; set to blank ('') for no timeout.
- **❌ Message Delete Freeze**: Resolved an issue where message deletion would sometimes cause the web UI to freeze.
## [0.3.4] - 2024-06-12 ## [0.3.4] - 2024-06-12
### Fixed ### Fixed
......
...@@ -160,7 +160,7 @@ Check our Migration Guide available in our [Open WebUI Documentation](https://do ...@@ -160,7 +160,7 @@ Check our Migration Guide available in our [Open WebUI Documentation](https://do
If you want to try out the latest bleeding-edge features and are okay with occasional instability, you can use the `:dev` tag like this: If you want to try out the latest bleeding-edge features and are okay with occasional instability, you can use the `:dev` tag like this:
```bash ```bash
docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:dev docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --add-host=host.docker.internal:host-gateway --restart always ghcr.io/open-webui/open-webui:dev
``` ```
## What's Next? 🌟 ## What's Next? 🌟
......
...@@ -37,6 +37,10 @@ from config import ( ...@@ -37,6 +37,10 @@ from config import (
ENABLE_IMAGE_GENERATION, ENABLE_IMAGE_GENERATION,
AUTOMATIC1111_BASE_URL, AUTOMATIC1111_BASE_URL,
COMFYUI_BASE_URL, COMFYUI_BASE_URL,
COMFYUI_CFG_SCALE,
COMFYUI_SAMPLER,
COMFYUI_SCHEDULER,
COMFYUI_SD3,
IMAGES_OPENAI_API_BASE_URL, IMAGES_OPENAI_API_BASE_URL,
IMAGES_OPENAI_API_KEY, IMAGES_OPENAI_API_KEY,
IMAGE_GENERATION_MODEL, IMAGE_GENERATION_MODEL,
...@@ -78,6 +82,10 @@ app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL ...@@ -78,6 +82,10 @@ app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.IMAGE_SIZE = IMAGE_SIZE app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS app.state.config.IMAGE_STEPS = IMAGE_STEPS
app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
app.state.config.COMFYUI_SD3 = COMFYUI_SD3
@app.get("/config") @app.get("/config")
...@@ -457,6 +465,18 @@ def generate_image( ...@@ -457,6 +465,18 @@ def generate_image(
if form_data.negative_prompt is not None: if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt data["negative_prompt"] = form_data.negative_prompt
if app.state.config.COMFYUI_CFG_SCALE:
data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
if app.state.config.COMFYUI_SAMPLER is not None:
data["sampler"] = app.state.config.COMFYUI_SAMPLER
if app.state.config.COMFYUI_SCHEDULER is not None:
data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
if app.state.config.COMFYUI_SD3 is not None:
data["sd3"] = app.state.config.COMFYUI_SD3
data = ImageGenerationPayload(**data) data = ImageGenerationPayload(**data)
res = comfyui_generate_image( res = comfyui_generate_image(
......
...@@ -190,6 +190,10 @@ class ImageGenerationPayload(BaseModel): ...@@ -190,6 +190,10 @@ class ImageGenerationPayload(BaseModel):
width: int width: int
height: int height: int
n: int = 1 n: int = 1
cfg_scale: Optional[float] = None
sampler: Optional[str] = None
scheduler: Optional[str] = None
sd3: Optional[bool] = None
def comfyui_generate_image( def comfyui_generate_image(
...@@ -199,6 +203,18 @@ def comfyui_generate_image( ...@@ -199,6 +203,18 @@ def comfyui_generate_image(
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT) comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
if payload.cfg_scale:
comfyui_prompt["3"]["inputs"]["cfg"] = payload.cfg_scale
if payload.sampler:
comfyui_prompt["3"]["inputs"]["sampler"] = payload.sampler
if payload.scheduler:
comfyui_prompt["3"]["inputs"]["scheduler"] = payload.scheduler
if payload.sd3:
comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage"
comfyui_prompt["4"]["inputs"]["ckpt_name"] = model comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
comfyui_prompt["5"]["inputs"]["width"] = payload.width comfyui_prompt["5"]["inputs"]["width"] = payload.width
......
...@@ -40,6 +40,7 @@ from utils.utils import ( ...@@ -40,6 +40,7 @@ from utils.utils import (
get_verified_user, get_verified_user,
get_admin_user, get_admin_user,
) )
from utils.task import prompt_template
from config import ( from config import (
...@@ -52,7 +53,7 @@ from config import ( ...@@ -52,7 +53,7 @@ from config import (
UPLOAD_DIR, UPLOAD_DIR,
AppConfig, AppConfig,
) )
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256, add_or_update_system_message
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
...@@ -199,9 +200,6 @@ def merge_models_lists(model_lists): ...@@ -199,9 +200,6 @@ def merge_models_lists(model_lists):
return list(merged_models.values()) return list(merged_models.values())
# user=Depends(get_current_user)
async def get_all_models(): async def get_all_models():
log.info("get_all_models()") log.info("get_all_models()")
...@@ -817,24 +815,28 @@ async def generate_chat_completion( ...@@ -817,24 +815,28 @@ async def generate_chat_completion(
"num_thread", None "num_thread", None
) )
if model_info.params.get("system", None): system = model_info.params.get("system", None)
if system:
# Check if the payload already has a system message # Check if the payload already has a system message
# If not, add a system message to the payload # If not, add a system message to the payload
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
if payload.get("messages"): if payload.get("messages"):
for message in payload["messages"]: payload["messages"] = add_or_update_system_message(
if message.get("role") == "system": system, payload["messages"]
message["content"] = ( )
model_info.params.get("system", None) + message["content"]
)
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
},
)
if url_idx == None: if url_idx == None:
if ":" not in payload["model"]: if ":" not in payload["model"]:
...@@ -850,8 +852,7 @@ async def generate_chat_completion( ...@@ -850,8 +852,7 @@ async def generate_chat_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
log.debug(payload)
print(payload)
return await post_streaming_url(f"{url}/api/chat", json.dumps(payload)) return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))
...@@ -879,10 +880,11 @@ class OpenAIChatCompletionForm(BaseModel): ...@@ -879,10 +880,11 @@ class OpenAIChatCompletionForm(BaseModel):
@app.post("/v1/chat/completions") @app.post("/v1/chat/completions")
@app.post("/v1/chat/completions/{url_idx}") @app.post("/v1/chat/completions/{url_idx}")
async def generate_openai_chat_completion( async def generate_openai_chat_completion(
form_data: OpenAIChatCompletionForm, form_data: dict,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
form_data = OpenAIChatCompletionForm(**form_data)
payload = { payload = {
**form_data.model_dump(exclude_none=True), **form_data.model_dump(exclude_none=True),
...@@ -914,22 +916,35 @@ async def generate_openai_chat_completion( ...@@ -914,22 +916,35 @@ async def generate_openai_chat_completion(
else None else None
) )
if model_info.params.get("system", None): system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message # Check if the payload already has a system message
# If not, add a system message to the payload # If not, add a system message to the payload
if payload.get("messages"): if payload.get("messages"):
for message in payload["messages"]: for message in payload["messages"]:
if message.get("role") == "system": if message.get("role") == "system":
message["content"] = ( message["content"] = system + message["content"]
model_info.params.get("system", None) + message["content"]
)
break break
else: else:
payload["messages"].insert( payload["messages"].insert(
0, 0,
{ {
"role": "system", "role": "system",
"content": model_info.params.get("system", None), "content": system,
}, },
) )
...@@ -1095,17 +1110,13 @@ async def download_file_stream( ...@@ -1095,17 +1110,13 @@ async def download_file_stream(
raise "Ollama: Could not create blob, Please try again." raise "Ollama: Could not create blob, Please try again."
# def number_generator():
# for i in range(1, 101):
# yield f"data: {i}\n"
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@app.post("/models/download") @app.post("/models/download")
@app.post("/models/download/{url_idx}") @app.post("/models/download/{url_idx}")
async def download_model( async def download_model(
form_data: UrlForm, form_data: UrlForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user),
): ):
allowed_hosts = ["https://huggingface.co/", "https://github.com/"] allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
...@@ -1134,7 +1145,11 @@ async def download_model( ...@@ -1134,7 +1145,11 @@ async def download_model(
@app.post("/models/upload") @app.post("/models/upload")
@app.post("/models/upload/{url_idx}") @app.post("/models/upload/{url_idx}")
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): def upload_model(
file: UploadFile = File(...),
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
if url_idx == None: if url_idx == None:
url_idx = 0 url_idx = 0
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
...@@ -1197,137 +1212,3 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): ...@@ -1197,137 +1212,3 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
yield f"data: {json.dumps(res)}\n\n" yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream") return StreamingResponse(file_process_stream(), media_type="text/event-stream")
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None:
# url_idx = 0
# url = app.state.config.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size
# async def file_upload_generator(file):
# print(file)
# try:
# async with aiofiles.open(file_location, "wb") as f:
# completed_size = 0
# while True:
# chunk = await file.read(1024*1024)
# if not chunk:
# break
# await f.write(chunk)
# completed_size += len(chunk)
# progress = (completed_size / total_size) * 100
# print(progress)
# yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
# except Exception as e:
# print(e)
# yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
# finally:
# await file.close()
# print("done")
# yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
# return StreamingResponse(
# file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
# )
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def deprecated_proxy(
path: str, request: Request, user=Depends(get_verified_user)
):
url = app.state.config.OLLAMA_BASE_URLS[0]
target_url = f"{url}/{path}"
body = await request.body()
headers = dict(request.headers)
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
headers.pop("host", None)
headers.pop("authorization", None)
headers.pop("origin", None)
headers.pop("referer", None)
r = None
def get_request():
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if path == "generate":
data = json.loads(body.decode("utf-8"))
if data.get("stream", True):
yield json.dumps({"id": request_id, "done": False}) + "\n"
elif path == "chat":
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# r.close()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
...@@ -20,6 +20,8 @@ from utils.utils import ( ...@@ -20,6 +20,8 @@ from utils.utils import (
get_verified_user, get_verified_user,
get_admin_user, get_admin_user,
) )
from utils.task import prompt_template
from config import ( from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
ENABLE_OPENAI_API, ENABLE_OPENAI_API,
...@@ -392,22 +394,34 @@ async def generate_chat_completion( ...@@ -392,22 +394,34 @@ async def generate_chat_completion(
else None else None
) )
if model_info.params.get("system", None): system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message # Check if the payload already has a system message
# If not, add a system message to the payload # If not, add a system message to the payload
if payload.get("messages"): if payload.get("messages"):
for message in payload["messages"]: for message in payload["messages"]:
if message.get("role") == "system": if message.get("role") == "system":
message["content"] = ( message["content"] = system + message["content"]
model_info.params.get("system", None) + message["content"]
)
break break
else: else:
payload["messages"].insert( payload["messages"].insert(
0, 0,
{ {
"role": "system", "role": "system",
"content": model_info.params.get("system", None), "content": system,
}, },
) )
...@@ -418,7 +432,12 @@ async def generate_chat_completion( ...@@ -418,7 +432,12 @@ async def generate_chat_completion(
idx = model["urlIdx"] idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"): if "pipeline" in model and model.get("pipeline"):
payload["user"] = {"name": user.name, "id": user.id} payload["user"] = {
"name": user.name,
"id": user.id,
"email": user.email,
"role": user.role,
}
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000 # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model # This is a workaround until OpenAI fixes the issue with this model
...@@ -430,13 +449,11 @@ async def generate_chat_completion( ...@@ -430,13 +449,11 @@ async def generate_chat_completion(
# Convert the modified body back to JSON # Convert the modified body back to JSON
payload = json.dumps(payload) payload = json.dumps(payload)
print(payload) log.debug(payload)
url = app.state.config.OPENAI_API_BASE_URLS[idx] url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx] key = app.state.config.OPENAI_API_KEYS[idx]
print(payload)
headers = {} headers = {}
headers["Authorization"] = f"Bearer {key}" headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
......
...@@ -55,6 +55,9 @@ from apps.webui.models.documents import ( ...@@ -55,6 +55,9 @@ from apps.webui.models.documents import (
DocumentForm, DocumentForm,
DocumentResponse, DocumentResponse,
) )
from apps.webui.models.files import (
Files,
)
from apps.rag.utils import ( from apps.rag.utils import (
get_model_path, get_model_path,
...@@ -112,6 +115,7 @@ from config import ( ...@@ -112,6 +115,7 @@ from config import (
YOUTUBE_LOADER_LANGUAGE, YOUTUBE_LOADER_LANGUAGE,
ENABLE_RAG_WEB_SEARCH, ENABLE_RAG_WEB_SEARCH,
RAG_WEB_SEARCH_ENGINE, RAG_WEB_SEARCH_ENGINE,
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
SEARXNG_QUERY_URL, SEARXNG_QUERY_URL,
GOOGLE_PSE_API_KEY, GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID, GOOGLE_PSE_ENGINE_ID,
...@@ -165,6 +169,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None ...@@ -165,6 +169,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
...@@ -775,6 +780,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -775,6 +780,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SEARXNG_QUERY_URL, app.state.config.SEARXNG_QUERY_URL,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
) )
else: else:
raise Exception("No SEARXNG_QUERY_URL found in environment variables") raise Exception("No SEARXNG_QUERY_URL found in environment variables")
...@@ -788,6 +794,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -788,6 +794,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.GOOGLE_PSE_ENGINE_ID, app.state.config.GOOGLE_PSE_ENGINE_ID,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
) )
else: else:
raise Exception( raise Exception(
...@@ -799,6 +806,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -799,6 +806,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.BRAVE_SEARCH_API_KEY, app.state.config.BRAVE_SEARCH_API_KEY,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
) )
else: else:
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables") raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
...@@ -808,6 +816,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -808,6 +816,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPSTACK_API_KEY, app.state.config.SERPSTACK_API_KEY,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
https_enabled=app.state.config.SERPSTACK_HTTPS, https_enabled=app.state.config.SERPSTACK_HTTPS,
) )
else: else:
...@@ -818,6 +827,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -818,6 +827,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPER_API_KEY, app.state.config.SERPER_API_KEY,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
) )
else: else:
raise Exception("No SERPER_API_KEY found in environment variables") raise Exception("No SERPER_API_KEY found in environment variables")
...@@ -827,11 +837,16 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -827,11 +837,16 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
app.state.config.SERPLY_API_KEY, app.state.config.SERPLY_API_KEY,
query, query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
) )
else: else:
raise Exception("No SERPLY_API_KEY found in environment variables") raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo": elif engine == "duckduckgo":
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT) return search_duckduckgo(
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
)
elif engine == "tavily": elif engine == "tavily":
if app.state.config.TAVILY_API_KEY: if app.state.config.TAVILY_API_KEY:
return search_tavily( return search_tavily(
...@@ -1119,6 +1134,60 @@ def store_doc( ...@@ -1119,6 +1134,60 @@ def store_doc(
) )
class ProcessDocForm(BaseModel):
file_id: str
collection_name: Optional[str] = None
@app.post("/process/doc")
def process_doc(
form_data: ProcessDocForm,
user=Depends(get_current_user),
):
try:
file = Files.get_file_by_id(form_data.file_id)
file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")
f = open(file_path, "rb")
collection_name = form_data.collection_name
if collection_name == None:
collection_name = calculate_sha256(f)[:63]
f.close()
loader, known_type = get_loader(
file.filename, file.meta.get("content_type"), file_path
)
data = loader.load()
try:
result = store_data_in_vector_db(data, collection_name)
if result:
return {
"status": True,
"collection_name": collection_name,
"known_type": known_type,
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e,
)
except Exception as e:
log.exception(e)
if "No pandoc was found" in str(e):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
class TextRAGForm(BaseModel): class TextRAGForm(BaseModel):
name: str name: str
content: str content: str
......
import logging import logging
from typing import List, Optional
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]: def search_brave(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects. """Search using Brave's Search API and return the results as a list of SearchResult objects.
Args: Args:
...@@ -29,6 +31,9 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]: ...@@ -29,6 +31,9 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
json_response = response.json() json_response = response.json()
results = json_response.get("web", {}).get("results", []) results = json_response.get("web", {}).get("results", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet") link=result["url"], title=result.get("title"), snippet=result.get("snippet")
......
import logging import logging
from typing import List, Optional
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -8,7 +8,9 @@ log = logging.getLogger(__name__) ...@@ -8,7 +8,9 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(query: str, count: int) -> list[SearchResult]: def search_duckduckgo(
query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]:
""" """
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects. Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args: Args:
...@@ -41,6 +43,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]: ...@@ -41,6 +43,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
snippet=result.get("body"), snippet=result.get("body"),
) )
) )
print(results) if filter_list:
results = get_filtered_results(results, filter_list)
# Return the list of search results # Return the list of search results
return results return results
import json import json
import logging import logging
from typing import List, Optional
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ...@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse( def search_google_pse(
api_key: str, search_engine_id: str, query: str, count: int api_key: str,
search_engine_id: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects. """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
...@@ -35,6 +39,8 @@ def search_google_pse( ...@@ -35,6 +39,8 @@ def search_google_pse(
json_response = response.json() json_response = response.json()
results = json_response.get("items", []) results = json_response.get("items", [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["link"], link=result["link"],
......
from typing import Optional from typing import Optional
from urllib.parse import urlparse
from pydantic import BaseModel from pydantic import BaseModel
def get_filtered_results(results, filter_list):
if not filter_list:
return results
filtered_results = []
for result in results:
domain = urlparse(result["url"]).netloc
if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
filtered_results.append(result)
return filtered_results
class SearchResult(BaseModel): class SearchResult(BaseModel):
link: str link: str
title: Optional[str] title: Optional[str]
......
import logging import logging
import requests import requests
from typing import List from typing import List, Optional
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ...@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_searxng( def search_searxng(
query_url: str, query: str, count: int, **kwargs query_url: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
**kwargs,
) -> List[SearchResult]: ) -> List[SearchResult]:
""" """
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects. Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
...@@ -78,6 +82,8 @@ def search_searxng( ...@@ -78,6 +82,8 @@ def search_searxng(
json_response = response.json() json_response = response.json()
results = json_response.get("results", []) results = json_response.get("results", [])
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True) sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content") link=result["url"], title=result.get("title"), snippet=result.get("content")
......
import json import json
import logging import logging
from typing import List, Optional
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]: def search_serper(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects. """Search using serper.dev's API and return the results as a list of SearchResult objects.
Args: Args:
...@@ -29,6 +31,8 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]: ...@@ -29,6 +31,8 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
results = sorted( results = sorted(
json_response.get("organic", []), key=lambda x: x.get("position", 0) json_response.get("organic", []), key=lambda x: x.get("position", 0)
) )
if filter_list:
results = get_filtered_results(results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["link"], link=result["link"],
......
import json import json
import logging import logging
from typing import List, Optional
import requests import requests
from urllib.parse import urlencode from urllib.parse import urlencode
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -19,6 +19,7 @@ def search_serply( ...@@ -19,6 +19,7 @@ def search_serply(
limit: int = 10, limit: int = 10,
device_type: str = "desktop", device_type: str = "desktop",
proxy_location: str = "US", proxy_location: str = "US",
filter_list: Optional[List[str]] = None,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects. """Search using serper.dev's API and return the results as a list of SearchResult objects.
...@@ -57,7 +58,8 @@ def search_serply( ...@@ -57,7 +58,8 @@ def search_serply(
results = sorted( results = sorted(
json_response.get("results", []), key=lambda x: x.get("realPosition", 0) json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
) )
if filter_list:
results = get_filtered_results(results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["link"], link=result["link"],
......
import json import json
import logging import logging
from typing import List, Optional
import requests import requests
from apps.rag.search.main import SearchResult from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ...@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serpstack( def search_serpstack(
api_key: str, query: str, count: int, https_enabled: bool = True api_key: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
https_enabled: bool = True,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using serpstack.com's and return the results as a list of SearchResult objects. """Search using serpstack.com's and return the results as a list of SearchResult objects.
...@@ -35,6 +39,8 @@ def search_serpstack( ...@@ -35,6 +39,8 @@ def search_serpstack(
results = sorted( results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0) json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
) )
if filter_list:
results = get_filtered_results(results, filter_list)
return [ return [
SearchResult( SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet") link=result["url"], title=result.get("title"), snippet=result.get("snippet")
......
...@@ -237,7 +237,7 @@ def get_embedding_function( ...@@ -237,7 +237,7 @@ def get_embedding_function(
def get_rag_context( def get_rag_context(
docs, files,
messages, messages,
embedding_function, embedding_function,
k, k,
...@@ -245,29 +245,29 @@ def get_rag_context( ...@@ -245,29 +245,29 @@ def get_rag_context(
r, r,
hybrid_search, hybrid_search,
): ):
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}") log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
query = get_last_user_message(messages) query = get_last_user_message(messages)
extracted_collections = [] extracted_collections = []
relevant_contexts = [] relevant_contexts = []
for doc in docs: for file in files:
context = None context = None
collection_names = ( collection_names = (
doc["collection_names"] file["collection_names"]
if doc["type"] == "collection" if file["type"] == "collection"
else [doc["collection_name"]] else [file["collection_name"]]
) )
collection_names = set(collection_names).difference(extracted_collections) collection_names = set(collection_names).difference(extracted_collections)
if not collection_names: if not collection_names:
log.debug(f"skipping {doc} as it has already been extracted") log.debug(f"skipping {file} as it has already been extracted")
continue continue
try: try:
if doc["type"] == "text": if file["type"] == "text":
context = doc["content"] context = file["content"]
else: else:
if hybrid_search: if hybrid_search:
context = query_collection_with_hybrid_search( context = query_collection_with_hybrid_search(
...@@ -290,7 +290,7 @@ def get_rag_context( ...@@ -290,7 +290,7 @@ def get_rag_context(
context = None context = None
if context: if context:
relevant_contexts.append({**context, "source": doc}) relevant_contexts.append({**context, "source": file})
extracted_collections.extend(collection_names) extracted_collections.extend(collection_names)
......
import os
import logging
import json import json
from peewee import * from peewee import *
from peewee_migrate import Router from peewee_migrate import Router
from playhouse.db_url import connect
from apps.webui.internal.wrappers import register_connection
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
import os
import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"]) log.setLevel(SRC_LOG_LEVELS["DB"])
...@@ -28,12 +29,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"): ...@@ -28,12 +29,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
else: else:
pass pass
DB = connect(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.") # The `register_connection` function encapsulates the logic for setting up
# the database connection based on the connection string, while `connect`
# is a Peewee-specific method to manage the connection state and avoid errors
# when a connection is already open.
try:
DB = register_connection(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
raise
router = Router( router = Router(
DB, DB,
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations", migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
logger=log, logger=log,
) )
router.run() router.run()
DB.connect(reuse_if_open=True) try:
DB.connect(reuse_if_open=True)
except OperationalError as e:
log.info(f"Failed to connect to database again due to: {e}")
pass
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding fields info to the 'user' table
migrator.add_fields("user", info=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Remove the settings field
migrator.remove_fields("user", "info")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment