Unverified Commit 13b0e7d6 authored by Timothy Jaeryang Baek, committed by GitHub

Merge pull request #4434 from open-webui/dev

0.3.13
parents 8d257ed5 c8badfe2
......@@ -15,6 +15,13 @@ jobs:
name: Run Cypress Integration Tests
runs-on: ubuntu-latest
steps:
- name: Maximize build space
uses: AdityaGarg8/remove-unwanted-software@v4.1
with:
remove-android: 'true'
remove-haskell: 'true'
remove-codeql: 'true'
- name: Checkout Repository
uses: actions/checkout@v4
......
......@@ -5,6 +5,33 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.13] - 2024-08-14
### Added
- **🎨 Enhanced Markdown Rendering**: Significant improvements in rendering markdown, ensuring smooth and reliable display of LaTeX and Mermaid charts, enhancing user experience with more robust visual content.
- **🔄 Auto-Install Tools & Functions Python Dependencies**: For 'Tools' and 'Functions', Open WebUI now automatically installs the extra Python requirements specified in the frontmatter, streamlining setup and customization (see the sketch after this list).
- **🌀 OAuth Email Claim Customization**: Introduced an 'OAUTH_EMAIL_CLAIM' variable to allow customization of the default "email" claim within OAuth configurations, providing greater flexibility in authentication processes.
- **📶 Websocket Reconnection**: Enhanced reliability with the capability to automatically reconnect when a websocket is closed, ensuring consistent and stable communication.
- **🤳 Haptic Feedback on Supported Devices**: Android devices now support haptic feedback for an immersive tactile experience during certain interactions.
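A minimal sketch of how frontmatter-declared dependencies can be auto-installed, assuming a `requirements:` line in the Tool/Function source frontmatter (the helper name and exact frontmatter format here are illustrative, not the verbatim Open WebUI implementation):

import re
import subprocess
import sys

def install_frontmatter_requirements(source: str) -> None:
    """Find a 'requirements: pkg1, pkg2' line in the source's frontmatter
    and pip-install each listed package (illustrative sketch)."""
    match = re.search(r"^requirements:\s*(.+)$", source, re.MULTILINE)
    if not match:
        return
    packages = [p.strip() for p in match.group(1).split(",") if p.strip()]
    for package in packages:
        # Install into the same interpreter environment that runs the server.
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])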
### Fixed
- **🛠️ ComfyUI Performance Improvement**: Addressed an issue causing FastAPI to stall when ComfyUI image generation was active; generation now runs in a separate thread to prevent UI unresponsiveness.
- **🔀 Session Handling**: Fixed an issue mandating session_id on the client side, ensuring smoother session management and transitions.
- **🖋️ Minor Bug Fixes and Format Corrections**: Various minor fixes including typo corrections, backend formatting improvements, and test amendments enhancing overall system stability and performance.
### Changed
- **🚀 Migration to SvelteKit 2**: Upgraded the underlying framework to SvelteKit version 2, offering enhanced speed, better code structure, and improved deployment capabilities.
- **🧹 General Cleanup and Refactoring**: Performed broad cleanup and refactoring across the platform, improving code efficiency and maintaining high standards of code health.
- **🚧 Integration Testing Improvements**: Modified how Cypress integration tests detect chat messages and updated sharing tests for better reliability and accuracy.
- **📁 Standardized '.safetensors' File Extension**: Renamed the '.sft' file extension to '.safetensors' for ComfyUI workflows, standardizing file formats across the platform.
### Removed
- **🗑️ Deprecated Frontend Functions**: Removed frontend functions that were migrated to backend to declutter the codebase and reduce redundancy.
## [0.3.12] - 2024-08-07
### Added
......
......@@ -15,7 +15,7 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import uuid
import requests
import hashlib
......@@ -244,7 +244,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except:
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
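Throughout this commit, bare except: clauses are narrowed to except Exception:. A bare except also traps SystemExit and KeyboardInterrupt, which can make a process ignore shutdown signals; catching Exception keeps the error handling while letting those control-flow exceptions propagate.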
......@@ -299,7 +299,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except:
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
......@@ -353,7 +353,7 @@ def transcribe(
try:
model = WhisperModel(**whisper_kwargs)
except:
except Exception:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
......@@ -421,7 +421,7 @@ def transcribe(
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except:
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
......@@ -438,7 +438,7 @@ def transcribe(
)
def get_available_models() -> List[dict]:
def get_available_models() -> list[dict]:
if app.state.config.TTS_ENGINE == "openai":
return [{"id": "tts-1"}, {"id": "tts-1-hd"}]
elif app.state.config.TTS_ENGINE == "elevenlabs":
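The List[dict] -> list[dict] change here (and the similar List[str] -> list[str] substitutions throughout the commit) adopts PEP 585 built-in generic types, which drop the typing.List import but require Python 3.9 or newer.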
......@@ -466,7 +466,7 @@ async def get_models(user=Depends(get_verified_user)):
return {"models": get_available_models()}
def get_available_voices() -> List[dict]:
def get_available_voices() -> list[dict]:
if app.state.config.TTS_ENGINE == "openai":
return [
{"name": "alloy", "id": "alloy"},
......
......@@ -94,7 +94,7 @@ app.state.config.COMFYUI_FLUX_FP8_CLIP = COMFYUI_FLUX_FP8_CLIP
def get_automatic1111_api_auth():
if app.state.config.AUTOMATIC1111_API_AUTH == None:
if app.state.config.AUTOMATIC1111_API_AUTH is None:
return ""
else:
auth1111_byte_string = app.state.config.AUTOMATIC1111_API_AUTH.encode("utf-8")
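The == None comparisons replaced across this commit with is None follow PEP 8: None is a singleton, so an identity check is the correct test, and it cannot be fooled by a class overriding __eq__.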
......@@ -145,28 +145,30 @@ async def get_engine_url(user=Depends(get_admin_user)):
async def update_engine_url(
form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
if form_data.AUTOMATIC1111_BASE_URL == None:
if form_data.AUTOMATIC1111_BASE_URL is None:
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
try:
r = requests.head(url)
r.raise_for_status()
app.state.config.AUTOMATIC1111_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail="Invalid URL provided.")
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
if form_data.COMFYUI_BASE_URL == None:
if form_data.COMFYUI_BASE_URL is None:
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else:
url = form_data.COMFYUI_BASE_URL.strip("/")
try:
r = requests.head(url)
r.raise_for_status()
app.state.config.COMFYUI_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
if form_data.AUTOMATIC1111_API_AUTH == None:
if form_data.AUTOMATIC1111_API_AUTH is None:
app.state.config.AUTOMATIC1111_API_AUTH = AUTOMATIC1111_API_AUTH
else:
app.state.config.AUTOMATIC1111_API_AUTH = form_data.AUTOMATIC1111_API_AUTH
......@@ -514,7 +516,7 @@ async def image_generations(
data = ImageGenerationPayload(**data)
res = comfyui_generate_image(
res = await comfyui_generate_image(
app.state.config.MODEL,
data,
user.id,
......
import asyncio
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
......@@ -170,7 +170,7 @@ FLUX_DEFAULT_PROMPT = """
},
"10": {
"inputs": {
"vae_name": "ae.sft"
"vae_name": "ae.safetensors"
},
"class_type": "VAELoader"
},
......@@ -184,7 +184,7 @@ FLUX_DEFAULT_PROMPT = """
},
"12": {
"inputs": {
"unet_name": "flux1-dev.sft",
"unet_name": "flux1-dev.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader"
......@@ -328,7 +328,7 @@ class ImageGenerationPayload(BaseModel):
flux_fp8_clip: Optional[bool] = None
def comfyui_generate_image(
async def comfyui_generate_image(
model: str, payload: ImageGenerationPayload, client_id, base_url
):
ws_url = base_url.replace("http://", "ws://").replace("https://", "wss://")
......@@ -397,7 +397,9 @@ def comfyui_generate_image(
return None
try:
images = get_images(ws, comfyui_prompt, client_id, base_url)
images = await asyncio.to_thread(
get_images, ws, comfyui_prompt, client_id, base_url
)
except Exception as e:
log.exception(f"Error while receiving images: {e}")
images = None
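The asyncio.to_thread call above moves the blocking websocket polling in get_images onto a worker thread so the FastAPI event loop stays responsive during generation (the stall fix noted in the 0.3.13 changelog). A self-contained illustration of the pattern, with a stand-in for the blocking call:

import asyncio
import time

def blocking_poll() -> str:
    # Stand-in for the blocking websocket receive loop in get_images.
    time.sleep(2)
    return "images ready"

async def main() -> None:
    # The event loop remains free to serve other requests while the
    # poll runs in a default-executor thread.
    result = await asyncio.to_thread(blocking_poll)
    print(result)

asyncio.run(main())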
......
from fastapi import (
FastAPI,
Request,
Response,
HTTPException,
Depends,
status,
UploadFile,
File,
BackgroundTasks,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict
import os
import re
import copy
import random
import requests
import json
import uuid
import aiohttp
import asyncio
import logging
import time
from urllib.parse import urlparse
from typing import Optional, List, Union
from typing import Optional, Union
from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
get_current_user,
get_verified_user,
get_admin_user,
)
from utils.task import prompt_template
from config import (
SRC_LOG_LEVELS,
......@@ -53,7 +42,12 @@ from config import (
UPLOAD_DIR,
AppConfig,
)
from utils.misc import calculate_sha256, add_or_update_system_message
from utils.misc import (
calculate_sha256,
apply_model_params_to_body_ollama,
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
......@@ -120,7 +114,7 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)):
class UrlUpdateForm(BaseModel):
urls: List[str]
urls: list[str]
@app.post("/urls/update")
......@@ -183,7 +177,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True):
res = await r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
......@@ -238,7 +232,7 @@ async def get_all_models():
async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
if url_idx == None:
if url_idx is None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
......@@ -269,7 +263,7 @@ async def get_ollama_tags(
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
......@@ -282,8 +276,7 @@ async def get_ollama_tags(
@app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None):
if app.state.config.ENABLE_OLLAMA_API:
if url_idx == None:
if url_idx is None:
# returns lowest version
tasks = [
fetch_url(f"{url}/api/version")
......@@ -323,7 +316,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
......@@ -346,8 +339,6 @@ async def pull_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
# Admin should be able to pull models from any source
payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
......@@ -367,7 +358,7 @@ async def push_model(
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
if url_idx == None:
if url_idx is None:
if form_data.name in app.state.MODELS:
url_idx = app.state.MODELS[form_data.name]["urls"][0]
else:
......@@ -417,7 +408,7 @@ async def copy_model(
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
if url_idx == None:
if url_idx is None:
if form_data.source in app.state.MODELS:
url_idx = app.state.MODELS[form_data.source]["urls"][0]
else:
......@@ -428,13 +419,13 @@ async def copy_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
r = requests.request(
method="POST",
url=f"{url}/api/copy",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
r.raise_for_status()
log.debug(f"r.text: {r.text}")
......@@ -448,7 +439,7 @@ async def copy_model(
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
......@@ -464,7 +455,7 @@ async def delete_model(
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
if url_idx == None:
if url_idx is None:
if form_data.name in app.state.MODELS:
url_idx = app.state.MODELS[form_data.name]["urls"][0]
else:
......@@ -476,12 +467,12 @@ async def delete_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
r = requests.request(
method="DELETE",
url=f"{url}/api/delete",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
r.raise_for_status()
log.debug(f"r.text: {r.text}")
......@@ -495,7 +486,7 @@ async def delete_model(
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
......@@ -516,12 +507,12 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
r = requests.request(
method="POST",
url=f"{url}/api/show",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
r.raise_for_status()
return r.json()
......@@ -533,7 +524,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
......@@ -556,7 +547,7 @@ async def generate_embeddings(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
if url_idx == None:
if url_idx is None:
model = form_data.model
if ":" not in model:
......@@ -573,12 +564,12 @@ async def generate_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
r = requests.request(
method="POST",
url=f"{url}/api/embeddings",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
r.raise_for_status()
return r.json()
......@@ -590,7 +581,7 @@ async def generate_embeddings(
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
......@@ -603,10 +594,9 @@ def generate_ollama_embeddings(
form_data: GenerateEmbeddingsForm,
url_idx: Optional[int] = None,
):
log.info(f"generate_ollama_embeddings {form_data}")
if url_idx == None:
if url_idx is None:
model = form_data.model
if ":" not in model:
......@@ -623,12 +613,12 @@ def generate_ollama_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
r = requests.request(
method="POST",
url=f"{url}/api/embeddings",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try:
r.raise_for_status()
data = r.json()
......@@ -638,7 +628,7 @@ def generate_ollama_embeddings(
if "embedding" in data:
return data["embedding"]
else:
raise "Something went wrong :/"
raise Exception("Something went wrong :/")
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
......@@ -647,16 +637,16 @@ def generate_ollama_embeddings(
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
except Exception:
error_detail = f"Ollama: {e}"
raise error_detail
raise Exception(error_detail)
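Raising a plain string, as the removed lines did, has been invalid since Python 3: the interpreter itself raises "TypeError: exceptions must derive from BaseException", masking the original error. Wrapping the message in Exception(...) restores the intended failure path.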
class GenerateCompletionForm(BaseModel):
model: str
prompt: str
images: Optional[List[str]] = None
images: Optional[list[str]] = None
format: Optional[str] = None
options: Optional[dict] = None
system: Optional[str] = None
......@@ -674,8 +664,7 @@ async def generate_completion(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
if url_idx == None:
if url_idx is None:
model = form_data.model
if ":" not in model:
......@@ -700,12 +689,12 @@ async def generate_completion(
class ChatMessage(BaseModel):
role: str
content: str
images: Optional[List[str]] = None
images: Optional[list[str]] = None
class GenerateChatCompletionForm(BaseModel):
model: str
messages: List[ChatMessage]
messages: list[ChatMessage]
format: Optional[str] = None
options: Optional[dict] = None
template: Optional[str] = None
......@@ -713,6 +702,18 @@ class GenerateChatCompletionForm(BaseModel):
keep_alive: Optional[Union[int, str]] = None
def get_ollama_url(url_idx: Optional[int], model: str):
if url_idx is None:
if model not in app.state.MODELS:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
)
url_idx = random.choice(app.state.MODELS[model]["urls"])
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
return url
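This new helper consolidates the URL-selection logic previously duplicated in both chat-completion endpoints: when no index is supplied, it picks a random URL among the Ollama instances serving the requested model and raises a 400 if the model is unknown. Both endpoints below now reduce to a single call:

url = get_ollama_url(url_idx, payload["model"])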
@app.post("/api/chat")
@app.post("/api/chat/{url_idx}")
async def generate_chat_completion(
......@@ -720,12 +721,7 @@ async def generate_chat_completion(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
log.debug(
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
form_data.model_dump_json(exclude_none=True).encode()
)
)
log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
payload = {
**form_data.model_dump(exclude_none=True, exclude=["metadata"]),
......@@ -740,185 +736,21 @@ async def generate_chat_completion(
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
params = model_info.params.model_dump()
if model_info.params:
if params:
if payload.get("options") is None:
payload["options"] = {}
if (
model_info.params.get("mirostat", None)
and payload["options"].get("mirostat") is None
):
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
if (
model_info.params.get("mirostat_eta", None)
and payload["options"].get("mirostat_eta") is None
):
payload["options"]["mirostat_eta"] = model_info.params.get(
"mirostat_eta", None
)
if (
model_info.params.get("mirostat_tau", None)
and payload["options"].get("mirostat_tau") is None
):
payload["options"]["mirostat_tau"] = model_info.params.get(
"mirostat_tau", None
)
if (
model_info.params.get("num_ctx", None)
and payload["options"].get("num_ctx") is None
):
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
if (
model_info.params.get("num_batch", None)
and payload["options"].get("num_batch") is None
):
payload["options"]["num_batch"] = model_info.params.get(
"num_batch", None
)
if (
model_info.params.get("num_keep", None)
and payload["options"].get("num_keep") is None
):
payload["options"]["num_keep"] = model_info.params.get("num_keep", None)
if (
model_info.params.get("repeat_last_n", None)
and payload["options"].get("repeat_last_n") is None
):
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
)
if (
model_info.params.get("frequency_penalty", None)
and payload["options"].get("frequency_penalty") is None
):
payload["options"]["repeat_penalty"] = model_info.params.get(
"frequency_penalty", None
)
if (
model_info.params.get("temperature", None) is not None
and payload["options"].get("temperature") is None
):
payload["options"]["temperature"] = model_info.params.get(
"temperature", None
)
if (
model_info.params.get("seed", None) is not None
and payload["options"].get("seed") is None
):
payload["options"]["seed"] = model_info.params.get("seed", None)
if (
model_info.params.get("stop", None)
and payload["options"].get("stop") is None
):
payload["options"]["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if (
model_info.params.get("tfs_z", None)
and payload["options"].get("tfs_z") is None
):
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
if (
model_info.params.get("max_tokens", None)
and payload["options"].get("max_tokens") is None
):
payload["options"]["num_predict"] = model_info.params.get(
"max_tokens", None
)
if (
model_info.params.get("top_k", None)
and payload["options"].get("top_k") is None
):
payload["options"]["top_k"] = model_info.params.get("top_k", None)
if (
model_info.params.get("top_p", None)
and payload["options"].get("top_p") is None
):
payload["options"]["top_p"] = model_info.params.get("top_p", None)
if (
model_info.params.get("min_p", None)
and payload["options"].get("min_p") is None
):
payload["options"]["min_p"] = model_info.params.get("min_p", None)
if (
model_info.params.get("use_mmap", None)
and payload["options"].get("use_mmap") is None
):
payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
if (
model_info.params.get("use_mlock", None)
and payload["options"].get("use_mlock") is None
):
payload["options"]["use_mlock"] = model_info.params.get(
"use_mlock", None
)
if (
model_info.params.get("num_thread", None)
and payload["options"].get("num_thread") is None
):
payload["options"]["num_thread"] = model_info.params.get(
"num_thread", None
)
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
if payload.get("messages"):
payload["messages"] = add_or_update_system_message(
system, payload["messages"]
payload["options"] = apply_model_params_to_body_ollama(
params, payload["options"]
)
payload = apply_model_system_prompt_to_body(params, payload, user)
if url_idx == None:
if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest"
if payload["model"] in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}")
log.debug(payload)
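For reference, a condensed sketch of the parameter mapping that apply_model_params_to_body_ollama replaces, reconstructed from the removed block above (the actual utils.misc implementation may differ in detail):

OLLAMA_OPTION_KEYS = [
    "mirostat", "mirostat_eta", "mirostat_tau", "num_ctx", "num_batch",
    "num_keep", "repeat_last_n", "temperature", "seed", "tfs_z",
    "top_k", "top_p", "min_p", "use_mmap", "use_mlock", "num_thread",
]

def apply_model_params_to_body_ollama(params: dict, options: dict) -> dict:
    # Copy each configured model parameter unless the request already sets it.
    for key in OLLAMA_OPTION_KEYS:
        if params.get(key) is not None and options.get(key) is None:
            options[key] = params[key]
    # Renamed keys: the removed block mapped max_tokens -> num_predict and
    # frequency_penalty -> repeat_penalty for Ollama.
    if params.get("max_tokens") is not None and options.get("num_predict") is None:
        options["num_predict"] = params["max_tokens"]
    if params.get("frequency_penalty") is not None and options.get("repeat_penalty") is None:
        options["repeat_penalty"] = params["frequency_penalty"]
    # Stop strings are unescaped before being passed through.
    if params.get("stop") and options.get("stop") is None:
        options["stop"] = [
            bytes(s, "utf-8").decode("unicode_escape") for s in params["stop"]
        ]
    return options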
......@@ -940,7 +772,7 @@ class OpenAIChatMessage(BaseModel):
class OpenAIChatCompletionForm(BaseModel):
model: str
messages: List[OpenAIChatMessage]
messages: list[OpenAIChatMessage]
model_config = ConfigDict(extra="allow")
......@@ -952,83 +784,28 @@ async def generate_openai_chat_completion(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
form_data = OpenAIChatCompletionForm(**form_data)
payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
completion_form = OpenAIChatCompletionForm(**form_data)
payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
if "metadata" in payload:
del payload["metadata"]
model_id = form_data.model
model_id = completion_form.model
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
params = model_info.params.model_dump()
if model_info.params:
payload["temperature"] = model_info.params.get("temperature", None)
payload["top_p"] = model_info.params.get("top_p", None)
payload["max_tokens"] = model_info.params.get("max_tokens", None)
payload["frequency_penalty"] = model_info.params.get(
"frequency_penalty", None
)
payload["seed"] = model_info.params.get("seed", None)
payload["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if params:
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user)
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": system,
},
)
if url_idx == None:
if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest"
if payload["model"] in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
url = get_ollama_url(url_idx, payload["model"])
log.info(f"url: {url}")
return await post_streaming_url(
......@@ -1044,7 +821,7 @@ async def get_openai_models(
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
if url_idx == None:
if url_idx is None:
models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER:
......@@ -1099,7 +876,7 @@ async def get_openai_models(
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
except Exception:
error_detail = f"Ollama: {e}"
raise HTTPException(
......@@ -1125,7 +902,6 @@ def parse_huggingface_url(hf_url):
path_components = parsed_url.path.split("/")
# Extract the desired output
user_repo = "/".join(path_components[1:3])
model_file = path_components[-1]
return model_file
......@@ -1190,7 +966,6 @@ async def download_model(
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
if not any(form_data.url.startswith(host) for host in allowed_hosts):
......@@ -1199,7 +974,7 @@ async def download_model(
detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
)
if url_idx == None:
if url_idx is None:
url_idx = 0
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
......@@ -1222,7 +997,7 @@ def upload_model(
url_idx: Optional[int] = None,
user=Depends(get_admin_user),
):
if url_idx == None:
if url_idx is None:
url_idx = 0
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
......
......@@ -17,7 +17,10 @@ from utils.utils import (
get_verified_user,
get_admin_user,
)
from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
from utils.misc import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from config import (
SRC_LOG_LEVELS,
......@@ -30,7 +33,7 @@ from config import (
MODEL_FILTER_LIST,
AppConfig,
)
from typing import List, Optional, Literal, overload
from typing import Optional, Literal, overload
import hashlib
......@@ -86,11 +89,11 @@ async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user
class UrlsUpdateForm(BaseModel):
urls: List[str]
urls: list[str]
class KeysUpdateForm(BaseModel):
keys: List[str]
keys: list[str]
@app.get("/urls")
......@@ -368,7 +371,7 @@ async def generate_chat_completion(
payload["model"] = model_info.base_model_id
params = model_info.params.model_dump()
payload = apply_model_params_to_body(params, payload)
payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user)
model = app.state.MODELS[payload.get("model")]
......
......@@ -13,7 +13,7 @@ import os, shutil, logging, re
from datetime import datetime
from pathlib import Path
from typing import List, Union, Sequence, Iterator, Any
from typing import Union, Sequence, Iterator, Any
from chromadb.utils.batch_utils import create_batches
from langchain_core.documents import Document
......@@ -376,7 +376,7 @@ async def update_reranking_config(
try:
app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
update_reranking_model(app.state.config.RAG_RERANKING_MODEL), True
update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)
return {
"status": True,
......@@ -439,7 +439,7 @@ class ChunkParamUpdateForm(BaseModel):
class YoutubeLoaderConfig(BaseModel):
language: List[str]
language: list[str]
translation: Optional[str] = None
......@@ -642,7 +642,7 @@ def query_doc_handler(
class QueryCollectionsForm(BaseModel):
collection_names: List[str]
collection_names: list[str]
query: str
k: Optional[int] = None
r: Optional[float] = None
......@@ -1021,7 +1021,7 @@ class TikaLoader:
self.file_path = file_path
self.mime_type = mime_type
def load(self) -> List[Document]:
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
data = f.read()
......@@ -1185,7 +1185,7 @@ def store_doc(
f.close()
f = open(file_path, "rb")
if collection_name == None:
if collection_name is None:
collection_name = calculate_sha256(f)[:63]
f.close()
......@@ -1238,7 +1238,7 @@ def process_doc(
f = open(file_path, "rb")
collection_name = form_data.collection_name
if collection_name == None:
if collection_name is None:
collection_name = calculate_sha256(f)[:63]
f.close()
......@@ -1296,7 +1296,7 @@ def store_text(
):
collection_name = form_data.collection_name
if collection_name == None:
if collection_name is None:
collection_name = calculate_sha256_string(form_data.content)
result = store_text_in_vector_db(
......@@ -1339,7 +1339,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
sanitized_filename = sanitize_filename(filename)
doc = Documents.get_doc_by_name(sanitized_filename)
if doc == None:
if doc is None:
doc = Documents.insert_new_doc(
user.id,
DocumentForm(
......
import logging
from typing import List, Optional
from typing import Optional
import requests
from apps.rag.search.main import SearchResult, get_filtered_results
......@@ -10,7 +10,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
......
import logging
from typing import List, Optional
from typing import Optional
from apps.rag.search.main import SearchResult, get_filtered_results
from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS
......@@ -9,7 +9,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(
query: str, count: int, filter_list: Optional[List[str]] = None
query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
......@@ -18,7 +18,7 @@ def search_duckduckgo(
count (int): The number of results to return
Returns:
List[SearchResult]: A list of search results
list[SearchResult]: A list of search results
"""
# Use the DDGS context manager to create a DDGS object
with DDGS() as ddgs:
......
import json
import logging
from typing import List, Optional
from typing import Optional
import requests
from apps.rag.search.main import SearchResult, get_filtered_results
......@@ -15,7 +15,7 @@ def search_google_pse(
search_engine_id: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
......
......@@ -17,7 +17,7 @@ def search_jina(query: str, count: int) -> list[SearchResult]:
count (int): The number of results to return
Returns:
List[SearchResult]: A list of search results
list[SearchResult]: A list of search results
"""
jina_search_endpoint = "https://s.jina.ai/"
headers = {
......
import logging
import requests
from typing import List, Optional
from typing import Optional
from apps.rag.search.main import SearchResult, get_filtered_results
from config import SRC_LOG_LEVELS
......@@ -14,9 +14,9 @@ def search_searxng(
query_url: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
filter_list: Optional[list[str]] = None,
**kwargs,
) -> List[SearchResult]:
) -> list[SearchResult]:
"""
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
......@@ -31,10 +31,10 @@ def search_searxng(
language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
categories: (Optional[List[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
categories: (Optional[list[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
Returns:
List[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
Raise:
requests.exceptions.RequestException: If a request error occurs during the search process.
......
import json
import logging
from typing import List, Optional
from typing import Optional
import requests
from apps.rag.search.main import SearchResult, get_filtered_results
......@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serper(
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
......
import json
import logging
from typing import List, Optional
from typing import Optional
import requests
from urllib.parse import urlencode
......@@ -19,7 +19,7 @@ def search_serply(
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US",
filter_list: Optional[List[str]] = None,
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
......
import json
import logging
from typing import List, Optional
from typing import Optional
import requests
from apps.rag.search.main import SearchResult, get_filtered_results
......@@ -14,7 +14,7 @@ def search_serpstack(
api_key: str,
query: str,
count: int,
filter_list: Optional[List[str]] = None,
filter_list: Optional[list[str]] = None,
https_enabled: bool = True,
) -> list[SearchResult]:
"""Search using serpstack.com's and return the results as a list of SearchResult objects.
......
......@@ -17,7 +17,7 @@ def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
query (str): The query to search for
Returns:
List[SearchResult]: A list of search results
list[SearchResult]: A list of search results
"""
url = "https://api.tavily.com/search"
data = {"query": query, "api_key": api_key}
......
......@@ -2,7 +2,7 @@ import os
import logging
import requests
from typing import List, Union
from typing import Union
from apps.ollama.main import (
generate_ollama_embeddings,
......@@ -142,7 +142,7 @@ def merge_and_sort_query_results(query_results, k, reverse=False):
def query_collection(
collection_names: List[str],
collection_names: list[str],
query: str,
embedding_function,
k: int,
......@@ -157,13 +157,13 @@ def query_collection(
embedding_function=embedding_function,
)
results.append(result)
except:
except Exception:
pass
return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search(
collection_names: List[str],
collection_names: list[str],
query: str,
embedding_function,
k: int,
......@@ -182,7 +182,7 @@ def query_collection_with_hybrid_search(
r=r,
)
results.append(result)
except:
except Exception:
pass
return merge_and_sort_query_results(results, k=k, reverse=True)
......@@ -411,7 +411,7 @@ class ChromaRetriever(BaseRetriever):
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
) -> list[Document]:
query_embeddings = self.embedding_function(query)
results = self.collection.query(
......
......@@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id
from utils.misc import (
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
apply_model_params_to_body,
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
......@@ -46,6 +46,7 @@ from config import (
AppConfig,
OAUTH_USERNAME_CLAIM,
OAUTH_PICTURE_CLAIM,
OAUTH_EMAIL_CLAIM,
)
from apps.socket.main import get_event_call, get_event_emitter
......@@ -84,6 +85,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
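With the claim name now configurable, the email lookup during the OAuth callback plausibly reads along these lines (a hedged sketch; the helper name is illustrative and user_data stands for the provider's userinfo payload):

def get_email_from_claims(user_data: dict, email_claim: str = "email") -> str:
    # Read the configured claim, falling back to the standard "email" claim.
    return user_data.get(email_claim) or user_data.get("email", "")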
app.state.MODELS = {}
app.state.TOOLS = {}
......@@ -289,7 +291,7 @@ async def generate_function_chat_completion(form_data, user):
form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump()
form_data = apply_model_params_to_body(params, form_data)
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, user)
pipe_id = get_pipe_id(form_data)
......
......@@ -140,7 +140,7 @@ class AuthsTable:
return None
else:
return None
except:
except Exception:
return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
......@@ -152,7 +152,7 @@ class AuthsTable:
try:
user = Users.get_user_by_api_key(api_key)
return user if user else None
except:
except Exception:
return False
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
......@@ -163,7 +163,7 @@ class AuthsTable:
if auth:
user = Users.get_user_by_id(auth.id)
return user
except:
except Exception:
return None
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
......@@ -174,7 +174,7 @@ class AuthsTable:
)
db.commit()
return True if result == 1 else False
except:
except Exception:
return False
def update_email_by_id(self, id: str, email: str) -> bool:
......@@ -183,7 +183,7 @@ class AuthsTable:
result = db.query(Auth).filter_by(id=id).update({"email": email})
db.commit()
return True if result == 1 else False
except:
except Exception:
return False
def delete_auth_by_id(self, id: str) -> bool:
......@@ -200,7 +200,7 @@ class AuthsTable:
return True
else:
return False
except:
except Exception:
return False
......