Unverified Commit 13b0e7d6 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #4434 from open-webui/dev

0.3.13
parents 8d257ed5 c8badfe2
...@@ -15,6 +15,13 @@ jobs: ...@@ -15,6 +15,13 @@ jobs:
name: Run Cypress Integration Tests name: Run Cypress Integration Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Maximize build space
uses: AdityaGarg8/remove-unwanted-software@v4.1
with:
remove-android: 'true'
remove-haskell: 'true'
remove-codeql: 'true'
- name: Checkout Repository - name: Checkout Repository
uses: actions/checkout@v4 uses: actions/checkout@v4
......
...@@ -5,6 +5,33 @@ All notable changes to this project will be documented in this file. ...@@ -5,6 +5,33 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.13] - 2024-08-14
### Added
- **🎨 Enhanced Markdown Rendering**: Significant improvements in rendering markdown, ensuring smooth and reliable display of LaTeX and Mermaid charts, enhancing user experience with more robust visual content.
- **🔄 Auto-Install Tools & Functions Python Dependencies**: For 'Tools' and 'Functions', Open WebUI now automatically installs extra Python requirements specified in the frontmatter, streamlining setup processes and customization.
- **🌀 OAuth Email Claim Customization**: Introduced an 'OAUTH_EMAIL_CLAIM' variable to allow customization of the default "email" claim within OAuth configurations, providing greater flexibility in authentication processes.
- **📶 Websocket Reconnection**: Enhanced reliability with the capability to automatically reconnect when a websocket is closed, ensuring consistent and stable communication.
- **🤳 Haptic Feedback on Supported Devices**: Android devices now support haptic feedback for an immersive tactile experience during certain interactions.
### Fixed
- **🛠️ ComfyUI Performance Improvement**: Addressed an issue causing FastAPI to stall when ComfyUI image generation was active; now runs in a separate thread to prevent UI unresponsiveness.
- **🔀 Session Handling**: Fixed an issue mandating session_id on client-side to ensure smoother session management and transitions.
- **🖋️ Minor Bug Fixes and Format Corrections**: Various minor fixes including typo corrections, backend formatting improvements, and test amendments enhancing overall system stability and performance.
### Changed
- **🚀 Migration to SvelteKit 2**: Upgraded the underlying framework to SvelteKit version 2, offering enhanced speed, better code structure, and improved deployment capabilities.
- **🧹 General Cleanup and Refactoring**: Performed broad cleanup and refactoring across the platform, improving code efficiency and maintaining high standards of code health.
- **🚧 Integration Testing Improvements**: Modified how Cypress integration tests detect chat messages and updated sharing tests for better reliability and accuracy.
- **📁 Standardized '.safetensors' File Extension**: Renamed the '.sft' file extension to '.safetensors' for ComfyUI workflows, standardizing file formats across the platform.
### Removed
- **🗑️ Deprecated Frontend Functions**: Removed frontend functions that were migrated to backend to declutter the codebase and reduce redundancy.
## [0.3.12] - 2024-08-07 ## [0.3.12] - 2024-08-07
### Added ### Added
......
...@@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse ...@@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel from pydantic import BaseModel
from typing import List
import uuid import uuid
import requests import requests
import hashlib import hashlib
...@@ -244,7 +244,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): ...@@ -244,7 +244,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"External: {res['error']['message']}" error_detail = f"External: {res['error']['message']}"
except: except Exception:
error_detail = f"External: {e}" error_detail = f"External: {e}"
raise HTTPException( raise HTTPException(
...@@ -299,7 +299,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): ...@@ -299,7 +299,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"External: {res['error']['message']}" error_detail = f"External: {res['error']['message']}"
except: except Exception:
error_detail = f"External: {e}" error_detail = f"External: {e}"
raise HTTPException( raise HTTPException(
...@@ -353,7 +353,7 @@ def transcribe( ...@@ -353,7 +353,7 @@ def transcribe(
try: try:
model = WhisperModel(**whisper_kwargs) model = WhisperModel(**whisper_kwargs)
except: except Exception:
log.warning( log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False" "WhisperModel initialization failed, attempting download with local_files_only=False"
) )
...@@ -421,7 +421,7 @@ def transcribe( ...@@ -421,7 +421,7 @@ def transcribe(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"External: {res['error']['message']}" error_detail = f"External: {res['error']['message']}"
except: except Exception:
error_detail = f"External: {e}" error_detail = f"External: {e}"
raise HTTPException( raise HTTPException(
...@@ -438,7 +438,7 @@ def transcribe( ...@@ -438,7 +438,7 @@ def transcribe(
) )
def get_available_models() -> List[dict]: def get_available_models() -> list[dict]:
if app.state.config.TTS_ENGINE == "openai": if app.state.config.TTS_ENGINE == "openai":
return [{"id": "tts-1"}, {"id": "tts-1-hd"}] return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
elif app.state.config.TTS_ENGINE == "elevenlabs": elif app.state.config.TTS_ENGINE == "elevenlabs":
...@@ -466,7 +466,7 @@ async def get_models(user=Depends(get_verified_user)): ...@@ -466,7 +466,7 @@ async def get_models(user=Depends(get_verified_user)):
return {"models": get_available_models()} return {"models": get_available_models()}
def get_available_voices() -> List[dict]: def get_available_voices() -> list[dict]:
if app.state.config.TTS_ENGINE == "openai": if app.state.config.TTS_ENGINE == "openai":
return [ return [
{"name": "alloy", "id": "alloy"}, {"name": "alloy", "id": "alloy"},
......
...@@ -94,7 +94,7 @@ app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP ...@@ -94,7 +94,7 @@ app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP
def get_automatic1111_api_auth(): def get_automatic1111_api_auth():
if app.state.config.AUTOMATIC1111_API_AUTH == None: if app.state.config.AUTOMATIC1111_API_AUTH is None:
return "" return ""
else: else:
auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8") auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
...@@ -145,28 +145,30 @@ async def get_engine_url(user=Depends(get_admin_user)): ...@@ -145,28 +145,30 @@ async def get_engine_url(user=Depends(get_admin_user)):
async def update_engine_url( async def update_engine_url(
form_data: EngineUrlUpdateForm, user=Depends(get_admin_user) form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
): ):
if form_data.AUTOMATIC1111_BASE_URL == None: if form_data.AUTOMATIC1111_BASE_URL is None:
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else: else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/") url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
try: try:
r = requests.head(url) r = requests.head(url)
r.raise_for_status()
app.state.config.AUTOMATIC1111_BASE_URL = url app.state.config.AUTOMATIC1111_BASE_URL = url
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail="Invalid URL provided.") raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
if form_data.COMFYUI_BASE_URL == None: if form_data.COMFYUI_BASE_URL is None:
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else: else:
url = form_data.COMFYUI_BASE_URL.strip("/") url = form_data.COMFYUI_BASE_URL.strip("/")
try: try:
r = requests.head(url) r = requests.head(url)
r.raise_for_status()
app.state.config.COMFYUI_BASE_URL = url app.state.config.COMFYUI_BASE_URL = url
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
if form_data.AUTOMATIC1111_API_AUTH == None: if form_data.AUTOMATIC1111_API_AUTH is None:
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
else: else:
app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
...@@ -514,7 +516,7 @@ async def image_generations( ...@@ -514,7 +516,7 @@ async def image_generations(
data = ImageGenerationPayload(**data) data = ImageGenerationPayload(**data)
res = comfyui_generate_image( res = await comfyui_generate_image(
app.state.config.MODEL, app.state.config.MODEL,
data, data,
user.id, user.id,
......
import asyncio
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json import json
import urllib.request import urllib.request
import urllib.parse import urllib.parse
...@@ -170,7 +170,7 @@ FLUX_DEFAULT_PROMPT = """ ...@@ -170,7 +170,7 @@ FLUX_DEFAULT_PROMPT = """
}, },
"10": { "10": {
"inputs": { "inputs": {
"vae_name": "ae.sft" "vae_name": "ae.safetensors"
}, },
"class_type": "VAELoader" "class_type": "VAELoader"
}, },
...@@ -184,7 +184,7 @@ FLUX_DEFAULT_PROMPT = """ ...@@ -184,7 +184,7 @@ FLUX_DEFAULT_PROMPT = """
}, },
"12": { "12": {
"inputs": { "inputs": {
"unet_name": "flux1-dev.sft", "unet_name": "flux1-dev.safetensors",
"weight_dtype": "default" "weight_dtype": "default"
}, },
"class_type": "UNETLoader" "class_type": "UNETLoader"
...@@ -328,7 +328,7 @@ class ImageGenerationPayload(BaseModel): ...@@ -328,7 +328,7 @@ class ImageGenerationPayload(BaseModel):
flux_fp8_clip: Optional[bool] = None flux_fp8_clip: Optional[bool] = None
def comfyui_generate_image( async def comfyui_generate_image(
model: str, payload: ImageGenerationPayload, client_id, base_url model: str, payload: ImageGenerationPayload, client_id, base_url
): ):
ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://") ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
...@@ -397,7 +397,9 @@ def comfyui_generate_image( ...@@ -397,7 +397,9 @@ def comfyui_generate_image(
return None return None
try: try:
images = get_images(ws, comfyui_prompt, client_id, base_url) images = await asyncio.to_thread(
get_images, ws, comfyui_prompt, client_id, base_url
)
except Exception as e: except Exception as e:
log.exception(f"Error while receiving images: {e}") log.exception(f"Error while receiving images: {e}")
images = None images = None
......
from fastapi import ( from fastapi import (
FastAPI, FastAPI,
Request, Request,
Response,
HTTPException, HTTPException,
Depends, Depends,
status,
UploadFile, UploadFile,
File, File,
BackgroundTasks,
) )
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
import os import os
import re import re
import copy
import random import random
import requests import requests
import json import json
import uuid
import aiohttp import aiohttp
import asyncio import asyncio
import logging import logging
import time import time
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import Optional, List, Union from typing import Optional, Union
from starlette.background import BackgroundTask from starlette.background import BackgroundTask
from apps.webui.models.models import Models from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import ( from utils.utils import (
decode_token,
get_current_user,
get_verified_user, get_verified_user,
get_admin_user, get_admin_user,
) )
from utils.task import prompt_template
from config import ( from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
...@@ -53,7 +42,12 @@ from config import ( ...@@ -53,7 +42,12 @@ from config import (
UPLOAD_DIR, UPLOAD_DIR,
AppConfig, AppConfig,
) )
from utils.misc import calculate_sha256, add_or_update_system_message from utils.misc import (
calculate_sha256,
apply_model_params_to_body_ollama,
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
...@@ -120,7 +114,7 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)): ...@@ -120,7 +114,7 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)):
class UrlUpdateForm(BaseModel): class UrlUpdateForm(BaseModel):
urls: List[str] urls: list[str]
@app.post("/urls/update") @app.post("/urls/update")
...@@ -183,7 +177,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True): ...@@ -183,7 +177,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True):
res = await r.json() res = await r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
...@@ -238,7 +232,7 @@ async def get_all_models(): ...@@ -238,7 +232,7 @@ async def get_all_models():
async def get_ollama_tags( async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_verified_user) url_idx: Optional[int] = None, user=Depends(get_verified_user)
): ):
if url_idx == None: if url_idx is None:
models = await get_all_models() models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
...@@ -269,7 +263,7 @@ async def get_ollama_tags( ...@@ -269,7 +263,7 @@ async def get_ollama_tags(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
...@@ -282,8 +276,7 @@ async def get_ollama_tags( ...@@ -282,8 +276,7 @@ async def get_ollama_tags(
@app.get("/api/version/{url_idx}") @app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None): async def get_ollama_versions(url_idx: Optional[int] = None):
if app.state.config.ENABLE_OLLAMA_API: if app.state.config.ENABLE_OLLAMA_API:
if url_idx == None: if url_idx is None:
# returns lowest version # returns lowest version
tasks = [ tasks = [
fetch_url(f"{url}/api/version") fetch_url(f"{url}/api/version")
...@@ -323,7 +316,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): ...@@ -323,7 +316,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
...@@ -346,8 +339,6 @@ async def pull_model( ...@@ -346,8 +339,6 @@ async def pull_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None
# Admin should be able to pull models from any source # Admin should be able to pull models from any source
payload = {**form_data.model_dump(exclude_none=True), "insecure": True} payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
...@@ -367,7 +358,7 @@ async def push_model( ...@@ -367,7 +358,7 @@ async def push_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx == None: if url_idx is None:
if form_data.name in app.state.MODELS: if form_data.name in app.state.MODELS:
url_idx = app.state.MODELS[form_data.name]["urls"][0] url_idx = app.state.MODELS[form_data.name]["urls"][0]
else: else:
...@@ -417,7 +408,7 @@ async def copy_model( ...@@ -417,7 +408,7 @@ async def copy_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx == None: if url_idx is None:
if form_data.source in app.state.MODELS: if form_data.source in app.state.MODELS:
url_idx = app.state.MODELS[form_data.source]["urls"][0] url_idx = app.state.MODELS[form_data.source]["urls"][0]
else: else:
...@@ -428,13 +419,13 @@ async def copy_model( ...@@ -428,13 +419,13 @@ async def copy_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try:
r = requests.request( r = requests.request(
method="POST", method="POST",
url=f"{url}/api/copy", url=f"{url}/api/copy",
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
try:
r.raise_for_status() r.raise_for_status()
log.debug(f"r.text: {r.text}") log.debug(f"r.text: {r.text}")
...@@ -448,7 +439,7 @@ async def copy_model( ...@@ -448,7 +439,7 @@ async def copy_model(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
...@@ -464,7 +455,7 @@ async def delete_model( ...@@ -464,7 +455,7 @@ async def delete_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx == None: if url_idx is None:
if form_data.name in app.state.MODELS: if form_data.name in app.state.MODELS:
url_idx = app.state.MODELS[form_data.name]["urls"][0] url_idx = app.state.MODELS[form_data.name]["urls"][0]
else: else:
...@@ -476,12 +467,12 @@ async def delete_model( ...@@ -476,12 +467,12 @@ async def delete_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try:
r = requests.request( r = requests.request(
method="DELETE", method="DELETE",
url=f"{url}/api/delete", url=f"{url}/api/delete",
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
try:
r.raise_for_status() r.raise_for_status()
log.debug(f"r.text: {r.text}") log.debug(f"r.text: {r.text}")
...@@ -495,7 +486,7 @@ async def delete_model( ...@@ -495,7 +486,7 @@ async def delete_model(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
...@@ -516,12 +507,12 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us ...@@ -516,12 +507,12 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try:
r = requests.request( r = requests.request(
method="POST", method="POST",
url=f"{url}/api/show", url=f"{url}/api/show",
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
try:
r.raise_for_status() r.raise_for_status()
return r.json() return r.json()
...@@ -533,7 +524,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us ...@@ -533,7 +524,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
...@@ -556,7 +547,7 @@ async def generate_embeddings( ...@@ -556,7 +547,7 @@ async def generate_embeddings(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
if url_idx == None: if url_idx is None:
model = form_data.model model = form_data.model
if ":" not in model: if ":" not in model:
...@@ -573,12 +564,12 @@ async def generate_embeddings( ...@@ -573,12 +564,12 @@ async def generate_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try:
r = requests.request( r = requests.request(
method="POST", method="POST",
url=f"{url}/api/embeddings", url=f"{url}/api/embeddings",
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
try:
r.raise_for_status() r.raise_for_status()
return r.json() return r.json()
...@@ -590,7 +581,7 @@ async def generate_embeddings( ...@@ -590,7 +581,7 @@ async def generate_embeddings(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
...@@ -603,10 +594,9 @@ def generate_ollama_embeddings( ...@@ -603,10 +594,9 @@ def generate_ollama_embeddings(
form_data: GenerateEmbeddingsForm, form_data: GenerateEmbeddingsForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
): ):
log.info(f"generate_ollama_embeddings {form_data}") log.info(f"generate_ollama_embeddings {form_data}")
if url_idx == None: if url_idx is None:
model = form_data.model model = form_data.model
if ":" not in model: if ":" not in model:
...@@ -623,12 +613,12 @@ def generate_ollama_embeddings( ...@@ -623,12 +613,12 @@ def generate_ollama_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try:
r = requests.request( r = requests.request(
method="POST", method="POST",
url=f"{url}/api/embeddings", url=f"{url}/api/embeddings",
data=form_data.model_dump_json(exclude_none=True).encode(), data=form_data.model_dump_json(exclude_none=True).encode(),
) )
try:
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
...@@ -638,7 +628,7 @@ def generate_ollama_embeddings( ...@@ -638,7 +628,7 @@ def generate_ollama_embeddings(
if "embedding" in data: if "embedding" in data:
return data["embedding"] return data["embedding"]
else: else:
raise "Something went wrong :/" raise Exception("Something went wrong :/")
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
...@@ -647,16 +637,16 @@ def generate_ollama_embeddings( ...@@ -647,16 +637,16 @@ def generate_ollama_embeddings(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise error_detail raise Exception(error_detail)
class GenerateCompletionForm(BaseModel): class GenerateCompletionForm(BaseModel):
model: str model: str
prompt: str prompt: str
images: Optional[List[str]] = None images: Optional[list[str]] = None
format: Optional[str] = None format: Optional[str] = None
options: Optional[dict] = None options: Optional[dict] = None
system: Optional[str] = None system: Optional[str] = None
...@@ -674,8 +664,7 @@ async def generate_completion( ...@@ -674,8 +664,7 @@ async def generate_completion(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
if url_idx is None:
if url_idx == None:
model = form_data.model model = form_data.model
if ":" not in model: if ":" not in model:
...@@ -700,12 +689,12 @@ async def generate_completion( ...@@ -700,12 +689,12 @@ async def generate_completion(
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: str role: str
content: str content: str
images: Optional[List[str]] = None images: Optional[list[str]] = None
class GenerateChatCompletionForm(BaseModel): class GenerateChatCompletionForm(BaseModel):
model: str model: str
messages: List[ChatMessage] messages: list[ChatMessage]
format: Optional[str] = None format: Optional[str] = None
options: Optional[dict] = None options: Optional[dict] = None
template: Optional[str] = None template: Optional[str] = None
...@@ -713,6 +702,18 @@ class GenerateChatCompletionForm(BaseModel): ...@@ -713,6 +702,18 @@ class GenerateChatCompletionForm(BaseModel):
keep_alive: Optional[Union[int, str]] = None keep_alive: Optional[Union[int, str]] = None
def get_ollama_url(url_idx: Optional[int], model: str) -> str:
    """Resolve the Ollama base URL to use for *model*.

    If ``url_idx`` is given, it is used directly as an index into the
    configured ``OLLAMA_BASE_URLS``. If it is ``None``, a backend is chosen
    at random from the URLs registered for *model* in ``app.state.MODELS``.

    Raises:
        HTTPException: 400 if ``url_idx`` is ``None`` and *model* is not a
            known model.
    """
    if url_idx is None:
        if model not in app.state.MODELS:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
            )
        # A model may be served by several backends; pick one at random
        # for simple load distribution.
        url_idx = random.choice(app.state.MODELS[model]["urls"])
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    return url
@app.post("/api/chat") @app.post("/api/chat")
@app.post("/api/chat/{url_idx}") @app.post("/api/chat/{url_idx}")
async def generate_chat_completion( async def generate_chat_completion(
...@@ -720,12 +721,7 @@ async def generate_chat_completion( ...@@ -720,12 +721,7 @@ async def generate_chat_completion(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
log.debug(
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
form_data.model_dump_json(exclude_none=True).encode()
)
)
payload = { payload = {
**form_data.model_dump(exclude_none=True, exclude=["metadata"]), **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
...@@ -740,185 +736,21 @@ async def generate_chat_completion( ...@@ -740,185 +736,21 @@ async def generate_chat_completion(
if model_info.base_model_id: if model_info.base_model_id:
payload["model"] = model_info.base_model_id payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump() params = model_info.params.model_dump()
if model_info.params: if params:
if payload.get("options") is None: if payload.get("options") is None:
payload["options"] = {} payload["options"] = {}
if ( payload["options"] = apply_model_params_to_body_ollama(
model_info.params.get("mirostat", None) params, payload["options"]
and payload["options"].get("mirostat") is None
):
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
if (
model_info.params.get("mirostat_eta", None)
and payload["options"].get("mirostat_eta") is None
):
payload["options"]["mirostat_eta"] = model_info.params.get(
"mirostat_eta", None
)
if (
model_info.params.get("mirostat_tau", None)
and payload["options"].get("mirostat_tau") is None
):
payload["options"]["mirostat_tau"] = model_info.params.get(
"mirostat_tau", None
)
if (
model_info.params.get("num_ctx", None)
and payload["options"].get("num_ctx") is None
):
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
if (
model_info.params.get("num_batch", None)
and payload["options"].get("num_batch") is None
):
payload["options"]["num_batch"] = model_info.params.get(
"num_batch", None
)
if (
model_info.params.get("num_keep", None)
and payload["options"].get("num_keep") is None
):
payload["options"]["num_keep"] = model_info.params.get("num_keep", None)
if (
model_info.params.get("repeat_last_n", None)
and payload["options"].get("repeat_last_n") is None
):
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
)
if (
model_info.params.get("frequency_penalty", None)
and payload["options"].get("frequency_penalty") is None
):
payload["options"]["repeat_penalty"] = model_info.params.get(
"frequency_penalty", None
)
if (
model_info.params.get("temperature", None) is not None
and payload["options"].get("temperature") is None
):
payload["options"]["temperature"] = model_info.params.get(
"temperature", None
)
if (
model_info.params.get("seed", None) is not None
and payload["options"].get("seed") is None
):
payload["options"]["seed"] = model_info.params.get("seed", None)
if (
model_info.params.get("stop", None)
and payload["options"].get("stop") is None
):
payload["options"]["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if (
model_info.params.get("tfs_z", None)
and payload["options"].get("tfs_z") is None
):
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
if (
model_info.params.get("max_tokens", None)
and payload["options"].get("max_tokens") is None
):
payload["options"]["num_predict"] = model_info.params.get(
"max_tokens", None
)
if (
model_info.params.get("top_k", None)
and payload["options"].get("top_k") is None
):
payload["options"]["top_k"] = model_info.params.get("top_k", None)
if (
model_info.params.get("top_p", None)
and payload["options"].get("top_p") is None
):
payload["options"]["top_p"] = model_info.params.get("top_p", None)
if (
model_info.params.get("min_p", None)
and payload["options"].get("min_p") is None
):
payload["options"]["min_p"] = model_info.params.get("min_p", None)
if (
model_info.params.get("use_mmap", None)
and payload["options"].get("use_mmap") is None
):
payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
if (
model_info.params.get("use_mlock", None)
and payload["options"].get("use_mlock") is None
):
payload["options"]["use_mlock"] = model_info.params.get(
"use_mlock", None
)
if (
model_info.params.get("num_thread", None)
and payload["options"].get("num_thread") is None
):
payload["options"]["num_thread"] = model_info.params.get(
"num_thread", None
)
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
if payload.get("messages"):
payload["messages"] = add_or_update_system_message(
system, payload["messages"]
) )
payload = apply_model_system_prompt_to_body(params, payload, user)
if url_idx == None:
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
if payload["model"] in app.state.MODELS: url = get_ollama_url(url_idx, payload["model"])
url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
log.debug(payload) log.debug(payload)
...@@ -940,7 +772,7 @@ class OpenAIChatMessage(BaseModel): ...@@ -940,7 +772,7 @@ class OpenAIChatMessage(BaseModel):
class OpenAIChatCompletionForm(BaseModel): class OpenAIChatCompletionForm(BaseModel):
model: str model: str
messages: List[OpenAIChatMessage] messages: list[OpenAIChatMessage]
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
...@@ -952,83 +784,28 @@ async def generate_openai_chat_completion( ...@@ -952,83 +784,28 @@ async def generate_openai_chat_completion(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
form_data = OpenAIChatCompletionForm(**form_data) completion_form = OpenAIChatCompletionForm(**form_data)
payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])} payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
if "metadata" in payload: if "metadata" in payload:
del payload["metadata"] del payload["metadata"]
model_id = form_data.model model_id = completion_form.model
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
payload["model"] = model_info.base_model_id payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump() params = model_info.params.model_dump()
if model_info.params: if params:
payload["temperature"] = model_info.params.get("temperature", None) payload = apply_model_params_to_body_openai(params, payload)
payload["top_p"] = model_info.params.get("top_p", None) payload = apply_model_system_prompt_to_body(params, payload, user)
payload["max_tokens"] = model_info.params.get("max_tokens", None)
payload["frequency_penalty"] = model_info.params.get(
"frequency_penalty", None
)
payload["seed"] = model_info.params.get("seed", None)
payload["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": system,
},
)
if url_idx == None:
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
if payload["model"] in app.state.MODELS: url = get_ollama_url(url_idx, payload["model"])
url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
return await post_streaming_url( return await post_streaming_url(
...@@ -1044,7 +821,7 @@ async def get_openai_models( ...@@ -1044,7 +821,7 @@ async def get_openai_models(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
if url_idx == None: if url_idx is None:
models = await get_all_models() models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
...@@ -1099,7 +876,7 @@ async def get_openai_models( ...@@ -1099,7 +876,7 @@ async def get_openai_models(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
...@@ -1125,7 +902,6 @@ def parse_huggingface_url(hf_url): ...@@ -1125,7 +902,6 @@ def parse_huggingface_url(hf_url):
path_components = parsed_url.path.split("/") path_components = parsed_url.path.split("/")
# Extract the desired output # Extract the desired output
user_repo = "/".join(path_components[1:3])
model_file = path_components[-1] model_file = path_components[-1]
return model_file return model_file
...@@ -1190,7 +966,6 @@ async def download_model( ...@@ -1190,7 +966,6 @@ async def download_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
allowed_hosts = ["https://huggingface.co/", "https://github.com/"] allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
if not any(form_data.url.startswith(host) for host in allowed_hosts): if not any(form_data.url.startswith(host) for host in allowed_hosts):
...@@ -1199,7 +974,7 @@ async def download_model( ...@@ -1199,7 +974,7 @@ async def download_model(
detail="Invalid file_url. Only URLs from allowed hosts are permitted.", detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
) )
if url_idx == None: if url_idx is None:
url_idx = 0 url_idx = 0
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
...@@ -1222,7 +997,7 @@ def upload_model( ...@@ -1222,7 +997,7 @@ def upload_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx == None: if url_idx is None:
url_idx = 0 url_idx = 0
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
......
...@@ -17,7 +17,10 @@ from utils.utils import ( ...@@ -17,7 +17,10 @@ from utils.utils import (
get_verified_user, get_verified_user,
get_admin_user, get_admin_user,
) )
from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body from utils.misc import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from config import ( from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
...@@ -30,7 +33,7 @@ from config import ( ...@@ -30,7 +33,7 @@ from config import (
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
AppConfig, AppConfig,
) )
from typing import List, Optional, Literal, overload from typing import Optional, Literal, overload
import hashlib import hashlib
...@@ -86,11 +89,11 @@ async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user ...@@ -86,11 +89,11 @@ async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user
class UrlsUpdateForm(BaseModel): class UrlsUpdateForm(BaseModel):
urls: List[str] urls: list[str]
class KeysUpdateForm(BaseModel): class KeysUpdateForm(BaseModel):
keys: List[str] keys: list[str]
@app.get("/urls") @app.get("/urls")
...@@ -368,7 +371,7 @@ async def generate_chat_completion( ...@@ -368,7 +371,7 @@ async def generate_chat_completion(
payload["model"] = model_info.base_model_id payload["model"] = model_info.base_model_id
params = model_info.params.model_dump() params = model_info.params.model_dump()
payload = apply_model_params_to_body(params, payload) payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user) payload = apply_model_system_prompt_to_body(params, payload, user)
model = app.state.MODELS[payload.get("model")] model = app.state.MODELS[payload.get("model")]
......
...@@ -13,7 +13,7 @@ import os, shutil, logging, re ...@@ -13,7 +13,7 @@ import os, shutil, logging, re
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import List, Union, Sequence, Iterator, Any from typing import Union, Sequence, Iterator, Any
from chromadb.utils.batch_utils import create_batches from chromadb.utils.batch_utils import create_batches
from langchain_core.documents import Document from langchain_core.documents import Document
...@@ -376,7 +376,7 @@ async def update_reranking_config( ...@@ -376,7 +376,7 @@ async def update_reranking_config(
try: try:
app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
update_reranking_model(app.state.config.RAG_RERANKING_MODEL), True update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)
return { return {
"status": True, "status": True,
...@@ -439,7 +439,7 @@ class ChunkParamUpdateForm(BaseModel): ...@@ -439,7 +439,7 @@ class ChunkParamUpdateForm(BaseModel):
class YoutubeLoaderConfig(BaseModel): class YoutubeLoaderConfig(BaseModel):
language: List[str] language: list[str]
translation: Optional[str] = None translation: Optional[str] = None
...@@ -642,7 +642,7 @@ def query_doc_handler( ...@@ -642,7 +642,7 @@ def query_doc_handler(
class QueryCollectionsForm(BaseModel): class QueryCollectionsForm(BaseModel):
collection_names: List[str] collection_names: list[str]
query: str query: str
k: Optional[int] = None k: Optional[int] = None
r: Optional[float] = None r: Optional[float] = None
...@@ -1021,7 +1021,7 @@ class TikaLoader: ...@@ -1021,7 +1021,7 @@ class TikaLoader:
self.file_path = file_path self.file_path = file_path
self.mime_type = mime_type self.mime_type = mime_type
def load(self) -> List[Document]: def load(self) -> list[Document]:
with open(self.file_path, "rb") as f: with open(self.file_path, "rb") as f:
data = f.read() data = f.read()
...@@ -1185,7 +1185,7 @@ def store_doc( ...@@ -1185,7 +1185,7 @@ def store_doc(
f.close() f.close()
f = open(file_path, "rb") f = open(file_path, "rb")
if collection_name == None: if collection_name is None:
collection_name = calculate_sha256(f)[:63] collection_name = calculate_sha256(f)[:63]
f.close() f.close()
...@@ -1238,7 +1238,7 @@ def process_doc( ...@@ -1238,7 +1238,7 @@ def process_doc(
f = open(file_path, "rb") f = open(file_path, "rb")
collection_name = form_data.collection_name collection_name = form_data.collection_name
if collection_name == None: if collection_name is None:
collection_name = calculate_sha256(f)[:63] collection_name = calculate_sha256(f)[:63]
f.close() f.close()
...@@ -1296,7 +1296,7 @@ def store_text( ...@@ -1296,7 +1296,7 @@ def store_text(
): ):
collection_name = form_data.collection_name collection_name = form_data.collection_name
if collection_name == None: if collection_name is None:
collection_name = calculate_sha256_string(form_data.content) collection_name = calculate_sha256_string(form_data.content)
result = store_text_in_vector_db( result = store_text_in_vector_db(
...@@ -1339,7 +1339,7 @@ def scan_docs_dir(user=Depends(get_admin_user)): ...@@ -1339,7 +1339,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
sanitized_filename = sanitize_filename(filename) sanitized_filename = sanitize_filename(filename)
doc = Documents.get_doc_by_name(sanitized_filename) doc = Documents.get_doc_by_name(sanitized_filename)
if doc == None: if doc is None:
doc = Documents.insert_new_doc( doc = Documents.insert_new_doc(
user.id, user.id,
DocumentForm( DocumentForm(
......
import logging import logging
from typing import List, Optional from typing import Optional
import requests import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
...@@ -10,7 +10,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ...@@ -10,7 +10,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave( def search_brave(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects. """Search using Brave's Search API and return the results as a list of SearchResult objects.
......
import logging import logging
from typing import List, Optional from typing import Optional
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -9,7 +9,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ...@@ -9,7 +9,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo( def search_duckduckgo(
query: str, count: int, filter_list: Optional[List[str]] = None query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]: ) -> list[SearchResult]:
""" """
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects. Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
...@@ -18,7 +18,7 @@ def search_duckduckgo( ...@@ -18,7 +18,7 @@ def search_duckduckgo(
count (int): The number of results to return count (int): The number of results to return
Returns: Returns:
List[SearchResult]: A list of search results list[SearchResult]: A list of search results
""" """
# Use the DDGS context manager to create a DDGS object # Use the DDGS context manager to create a DDGS object
with DDGS() as ddgs: with DDGS() as ddgs:
......
import json import json
import logging import logging
from typing import List, Optional from typing import Optional
import requests import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
...@@ -15,7 +15,7 @@ def search_google_pse( ...@@ -15,7 +15,7 @@ def search_google_pse(
search_engine_id: str, search_engine_id: str,
query: str, query: str,
count: int, count: int,
filter_list: Optional[List[str]] = None, filter_list: Optional[list[str]] = None,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects. """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
......
...@@ -17,7 +17,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]: ...@@ -17,7 +17,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]:
count (int): The number of results to return count (int): The number of results to return
Returns: Returns:
List[SearchResult]: A list of search results list[SearchResult]: A list of search results
""" """
jina_search_endpoint = "https://s.jina.ai/" jina_search_endpoint = "https://s.jina.ai/"
headers = { headers = {
......
import logging import logging
import requests import requests
from typing import List, Optional from typing import Optional
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -14,9 +14,9 @@ def search_searxng( ...@@ -14,9 +14,9 @@ def search_searxng(
query_url: str, query_url: str,
query: str, query: str,
count: int, count: int,
filter_list: Optional[List[str]] = None, filter_list: Optional[list[str]] = None,
**kwargs, **kwargs,
) -> List[SearchResult]: ) -> list[SearchResult]:
""" """
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects. Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
...@@ -31,10 +31,10 @@ def search_searxng( ...@@ -31,10 +31,10 @@ def search_searxng(
language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string. language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate). safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''. time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
categories: (Optional[List[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided. categories: (Optional[list[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
Returns: Returns:
List[SearchResult]: A list of SearchResults sorted by relevance score in descending order. list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
Raise: Raise:
requests.exceptions.RequestException: If a request error occurs during the search process. requests.exceptions.RequestException: If a request error occurs during the search process.
......
import json import json
import logging import logging
from typing import List, Optional from typing import Optional
import requests import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
...@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ...@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serper( def search_serper(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects. """Search using serper.dev's API and return the results as a list of SearchResult objects.
......
import json import json
import logging import logging
from typing import List, Optional from typing import Optional
import requests import requests
from urllib.parse import urlencode from urllib.parse import urlencode
...@@ -19,7 +19,7 @@ def search_serply( ...@@ -19,7 +19,7 @@ def search_serply(
limit: int = 10, limit: int = 10,
device_type: str = "desktop", device_type: str = "desktop",
proxy_location: str = "US", proxy_location: str = "US",
filter_list: Optional[List[str]] = None, filter_list: Optional[list[str]] = None,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects. """Search using serper.dev's API and return the results as a list of SearchResult objects.
......
import json import json
import logging import logging
from typing import List, Optional from typing import Optional
import requests import requests
from apps.rag.search.main import SearchResult, get_filtered_results from apps.rag.search.main import SearchResult, get_filtered_results
...@@ -14,7 +14,7 @@ def search_serpstack( ...@@ -14,7 +14,7 @@ def search_serpstack(
api_key: str, api_key: str,
query: str, query: str,
count: int, count: int,
filter_list: Optional[List[str]] = None, filter_list: Optional[list[str]] = None,
https_enabled: bool = True, https_enabled: bool = True,
) -> list[SearchResult]: ) -> list[SearchResult]:
"""Search using serpstack.com's and return the results as a list of SearchResult objects. """Search using serpstack.com's and return the results as a list of SearchResult objects.
......
...@@ -17,7 +17,7 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]: ...@@ -17,7 +17,7 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
query (str): The query to search for query (str): The query to search for
Returns: Returns:
List[SearchResult]: A list of search results list[SearchResult]: A list of search results
""" """
url = "https://api.tavily.com/search" url = "https://api.tavily.com/search"
data = {"query": query, "api_key": api_key} data = {"query": query, "api_key": api_key}
......
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
import logging import logging
import requests import requests
from typing import List, Union from typing import Union
from apps.ollama.main import ( from apps.ollama.main import (
generate_ollama_embeddings, generate_ollama_embeddings,
...@@ -142,7 +142,7 @@ def merge_and_sort_query_results(query_results, k, reverse=False): ...@@ -142,7 +142,7 @@ def merge_and_sort_query_results(query_results, k, reverse=False):
def query_collection( def query_collection(
collection_names: List[str], collection_names: list[str],
query: str, query: str,
embedding_function, embedding_function,
k: int, k: int,
...@@ -157,13 +157,13 @@ def query_collection( ...@@ -157,13 +157,13 @@ def query_collection(
embedding_function=embedding_function, embedding_function=embedding_function,
) )
results.append(result) results.append(result)
except: except Exception:
pass pass
return merge_and_sort_query_results(results, k=k) return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search( def query_collection_with_hybrid_search(
collection_names: List[str], collection_names: list[str],
query: str, query: str,
embedding_function, embedding_function,
k: int, k: int,
...@@ -182,7 +182,7 @@ def query_collection_with_hybrid_search( ...@@ -182,7 +182,7 @@ def query_collection_with_hybrid_search(
r=r, r=r,
) )
results.append(result) results.append(result)
except: except Exception:
pass pass
return merge_and_sort_query_results(results, k=k, reverse=True) return merge_and_sort_query_results(results, k=k, reverse=True)
...@@ -411,7 +411,7 @@ class ChromaRetriever(BaseRetriever): ...@@ -411,7 +411,7 @@ class ChromaRetriever(BaseRetriever):
query: str, query: str,
*, *,
run_manager: CallbackManagerForRetrieverRun, run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]: ) -> list[Document]:
query_embeddings = self.embedding_function(query) query_embeddings = self.embedding_function(query)
results = self.collection.query( results = self.collection.query(
......
...@@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id ...@@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id
from utils.misc import ( from utils.misc import (
openai_chat_chunk_message_template, openai_chat_chunk_message_template,
openai_chat_completion_message_template, openai_chat_completion_message_template,
apply_model_params_to_body, apply_model_params_to_body_openai,
apply_model_system_prompt_to_body, apply_model_system_prompt_to_body,
) )
...@@ -46,6 +46,7 @@ from config import ( ...@@ -46,6 +46,7 @@ from config import (
AppConfig, AppConfig,
OAUTH_USERNAME_CLAIM, OAUTH_USERNAME_CLAIM,
OAUTH_PICTURE_CLAIM, OAUTH_PICTURE_CLAIM,
OAUTH_EMAIL_CLAIM,
) )
from apps.socket.main import get_event_call, get_event_emitter from apps.socket.main import get_event_call, get_event_emitter
...@@ -84,6 +85,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING ...@@ -84,6 +85,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
app.state.MODELS = {} app.state.MODELS = {}
app.state.TOOLS = {} app.state.TOOLS = {}
...@@ -289,7 +291,7 @@ async def generate_function_chat_completion(form_data, user): ...@@ -289,7 +291,7 @@ async def generate_function_chat_completion(form_data, user):
form_data["model"] = model_info.base_model_id form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump() params = model_info.params.model_dump()
form_data = apply_model_params_to_body(params, form_data) form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, user) form_data = apply_model_system_prompt_to_body(params, form_data, user)
pipe_id = get_pipe_id(form_data) pipe_id = get_pipe_id(form_data)
......
...@@ -140,7 +140,7 @@ class AuthsTable: ...@@ -140,7 +140,7 @@ class AuthsTable:
return None return None
else: else:
return None return None
except: except Exception:
return None return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
...@@ -152,7 +152,7 @@ class AuthsTable: ...@@ -152,7 +152,7 @@ class AuthsTable:
try: try:
user = Users.get_user_by_api_key(api_key) user = Users.get_user_by_api_key(api_key)
return user if user else None return user if user else None
except: except Exception:
return False return False
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
...@@ -163,7 +163,7 @@ class AuthsTable: ...@@ -163,7 +163,7 @@ class AuthsTable:
if auth: if auth:
user = Users.get_user_by_id(auth.id) user = Users.get_user_by_id(auth.id)
return user return user
except: except Exception:
return None return None
def update_user_password_by_id(self, id: str, new_password: str) -> bool: def update_user_password_by_id(self, id: str, new_password: str) -> bool:
...@@ -174,7 +174,7 @@ class AuthsTable: ...@@ -174,7 +174,7 @@ class AuthsTable:
) )
db.commit() db.commit()
return True if result == 1 else False return True if result == 1 else False
except: except Exception:
return False return False
def update_email_by_id(self, id: str, email: str) -> bool: def update_email_by_id(self, id: str, email: str) -> bool:
...@@ -183,7 +183,7 @@ class AuthsTable: ...@@ -183,7 +183,7 @@ class AuthsTable:
result = db.query(Auth).filter_by(id=id).update({"email": email}) result = db.query(Auth).filter_by(id=id).update({"email": email})
db.commit() db.commit()
return True if result == 1 else False return True if result == 1 else False
except: except Exception:
return False return False
def delete_auth_by_id(self, id: str) -> bool: def delete_auth_by_id(self, id: str) -> bool:
...@@ -200,7 +200,7 @@ class AuthsTable: ...@@ -200,7 +200,7 @@ class AuthsTable:
return True return True
else: else:
return False return False
except: except Exception:
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment