Commit 4ff17acc authored by Jun Siang Cheah's avatar Jun Siang Cheah
Browse files

Merge remote-tracking branch 'upstream/dev' into feat/oauth

parents f49d814d 9928114c
...@@ -5,6 +5,81 @@ All notable changes to this project will be documented in this file. ...@@ -5,6 +5,81 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.4] - 2024-06-12
### Fixed
- **🔒 Mixed Content with HTTPS Issue**: Resolved a problem where mixed content (HTTP and HTTPS) was causing security warnings and blocking resources on HTTPS sites.
- **🔍 Web Search Issue**: Addressed the problem where web search functionality was not working correctly. The 'ENABLE_RAG_LOCAL_WEB_FETCH' option has been reintroduced to restore proper web searching capabilities.
- **💾 RAG Template Not Being Saved**: Fixed an issue where the RAG template was not being saved correctly, ensuring your custom templates are now preserved as expected.
## [0.3.3] - 2024-06-12
### Added
- **🛠️ Native Python Function Calling**: Introducing native Python function calling within Open WebUI. We’ve also included a built-in code editor to seamlessly develop and integrate function code within the 'Tools' workspace. With this, you can significantly enhance your LLM’s capabilities by creating custom RAG pipelines, web search tools, and even agent-like features such as sending Discord messages.
- **🌐 DuckDuckGo Integration**: Added DuckDuckGo as a web search provider, giving you more search options.
- **🌏 Enhanced Translations**: Improved translations for Vietnamese and Chinese languages, making the interface more accessible.
### Fixed
- **🔗 Web Search URL Error Handling**: Fixed the issue where a single URL error would disrupt the data loading process in Web Search mode. Now, such errors will be handled gracefully to ensure uninterrupted data loading.
- **🖥️ Frontend Responsiveness**: Resolved the problem where the frontend would stop responding if the backend encounters an error while downloading a model. Improved error handling to maintain frontend stability.
- **🔧 Dependency Issues in pip**: Fixed issues related to pip installations, ensuring all dependencies are correctly managed to prevent installation errors.
## [0.3.2] - 2024-06-10
### Added
- **🔍 Web Search Query Status**: The web search query will now persist in the results section to aid in easier debugging and tracking of search queries.
- **🌐 New Web Search Provider**: We have added Serply as a new option for web search providers, giving you more choices for your search needs.
- **🌏 Improved Translations**: We've enhanced translations for Chinese and Portuguese.
### Fixed
- **🎤 Audio File Upload Issue**: The bug that prevented audio files from being uploaded in chat input has been fixed, ensuring smooth communication.
- **💬 Message Input Handling**: Improved the handling of message inputs by instantly clearing images and text after sending, along with immediate visual indications when a response message is loading, enhancing user feedback.
- **⚙️ Parameter Registration and Validation**: Fixed the issue where parameters were not registering in certain cases and addressed the problem where users were unable to save due to invalid input errors.
## [0.3.1] - 2024-06-09
### Fixed
- **💬 Chat Functionality**: Resolved the issue where chat functionality was not working for specific models.
## [0.3.0] - 2024-06-09
### Added
- **📚 Knowledge Support for Models**: Attach documents directly to models from the models workspace, enhancing the information available to each model.
- **🎙️ Hands-Free Voice Call Feature**: Initiate voice calls without needing to use your hands, making interactions more seamless.
- **📹 Video Call Feature**: Enable video calls with supported vision models like Llava and GPT-4o, adding a visual dimension to your communications.
- **🎛️ Enhanced UI for Voice Recording**: Improved user interface for the voice recording feature, making it more intuitive and user-friendly.
- **🌐 External STT Support**: Now support for external Speech-To-Text services, providing more flexibility in choosing your STT provider.
- **⚙️ Unified Settings**: Consolidated settings including document settings under a new admin settings section for easier management.
- **🌑 Dark Mode Splash Screen**: A new splash screen for dark mode, ensuring a consistent and visually appealing experience for dark mode users.
- **📥 Upload Pipeline**: Directly upload pipelines from the admin settings > pipelines section, streamlining the pipeline management process.
- **🌍 Improved Language Support**: Enhanced support for Chinese and Ukrainian languages, better catering to a global user base.
### Fixed
- **🛠️ Playground Issue**: Fixed the playground not functioning properly, ensuring a smoother user experience.
- **🔥 Temperature Parameter Issue**: Corrected the issue where the temperature value '0' was not being passed correctly.
- **📝 Prompt Input Clearing**: Resolved prompt input textarea not being cleared right away, ensuring a clean slate for new inputs.
- **✨ Various UI Styling Issues**: Fixed numerous user interface styling problems for a more cohesive look.
- **👥 Active Users Display**: Fixed active users showing active sessions instead of actual users, now reflecting accurate user activity.
- **🌐 Community Platform Compatibility**: The Community Platform is back online and fully compatible with Open WebUI.
### Changed
- **📝 RAG Implementation**: Updated the RAG (Retrieval-Augmented Generation) implementation to use a system prompt for context, instead of overriding the user's prompt.
- **🔄 Settings Relocation**: Moved Models, Connections, Audio, and Images settings to the admin settings for better organization.
- **✍️ Improved Title Generation**: Enhanced the default prompt for title generation, yielding better results.
- **🔧 Backend Task Management**: Tasks like title generation and search query generation are now managed on the backend side and controlled only by the admin.
- **🔍 Editable Search Query Prompt**: You can now edit the search query generation prompt, offering more control over how queries are generated.
- **📏 Prompt Length Threshold**: Set the prompt length threshold for search query generation from the admin settings, giving more customization options.
- **📣 Settings Consolidation**: Merged the Banners admin setting with the Interface admin setting for a more streamlined settings area.
## [0.2.5] - 2024-06-05 ## [0.2.5] - 2024-06-05
### Added ### Added
......
...@@ -29,11 +29,15 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature- ...@@ -29,11 +29,15 @@ Open WebUI is an [extensible](https://github.com/open-webui/pipelines), feature-
- ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction. - ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
- 🎤📹 **Hands-Free Voice/Video Call**: Experience seamless communication with integrated hands-free voice and video call features, allowing for a more dynamic and interactive chat environment.
- 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration. - 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
- 🐍 **Native Python Function Calling Tool**: Enhance your LLMs with built-in code editor support in the tools workspace. Bring Your Own Function (BYOF) by simply adding your pure Python functions, enabling seamless integration with LLMs.
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query. - 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, and `serper`, and inject the results directly into your chat experience. - 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, `serper`, `Serply`, `DuckDuckGo` and `TavilySearch` and inject the results directly into your chat experience.
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions. - 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
...@@ -146,10 +150,19 @@ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/wa ...@@ -146,10 +150,19 @@ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/wa
In the last part of the command, replace `open-webui` with your container name if it is different. In the last part of the command, replace `open-webui` with your container name if it is different.
### Moving from Ollama WebUI to Open WebUI
Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/migration/). Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/migration/).
### Using the Dev Branch 🌙
> [!WARNING]
> The `:dev` branch contains the latest unstable features and changes. Use it at your own risk as it may have bugs or incomplete features.
If you want to try out the latest bleeding-edge features and are okay with occasional instability, you can use the `:dev` tag like this:
```bash
docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:dev
```
## What's Next? 🌟 ## What's Next? 🌟
Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/). Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/).
......
...@@ -18,6 +18,10 @@ If you're experiencing connection issues, it’s often due to the WebUI docker c ...@@ -18,6 +18,10 @@ If you're experiencing connection issues, it’s often due to the WebUI docker c
docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
``` ```
### Error on Slow Reponses for Ollama
Open WebUI has a default timeout of 5 minutes for Ollama to finish generating the response. If needed, this can be adjusted via the environment variable AIOHTTP_CLIENT_TIMEOUT, which sets the timeout in seconds.
### General Connection Errors ### General Connection Errors
**Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.com/) for the latest updates. **Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.com/) for the latest updates.
......
...@@ -17,13 +17,12 @@ from fastapi.middleware.cors import CORSMiddleware ...@@ -17,13 +17,12 @@ from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
from pydantic import BaseModel from pydantic import BaseModel
import uuid
import requests import requests
import hashlib import hashlib
from pathlib import Path from pathlib import Path
import json import json
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import ( from utils.utils import (
decode_token, decode_token,
...@@ -41,10 +40,15 @@ from config import ( ...@@ -41,10 +40,15 @@ from config import (
WHISPER_MODEL_DIR, WHISPER_MODEL_DIR,
WHISPER_MODEL_AUTO_UPDATE, WHISPER_MODEL_AUTO_UPDATE,
DEVICE_TYPE, DEVICE_TYPE,
AUDIO_OPENAI_API_BASE_URL, AUDIO_STT_OPENAI_API_BASE_URL,
AUDIO_OPENAI_API_KEY, AUDIO_STT_OPENAI_API_KEY,
AUDIO_OPENAI_API_MODEL, AUDIO_TTS_OPENAI_API_BASE_URL,
AUDIO_OPENAI_API_VOICE, AUDIO_TTS_OPENAI_API_KEY,
AUDIO_STT_ENGINE,
AUDIO_STT_MODEL,
AUDIO_TTS_ENGINE,
AUDIO_TTS_MODEL,
AUDIO_TTS_VOICE,
AppConfig, AppConfig,
) )
...@@ -61,10 +65,17 @@ app.add_middleware( ...@@ -61,10 +65,17 @@ app.add_middleware(
) )
app.state.config = AppConfig() app.state.config = AppConfig()
app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
# setting device type for whisper model # setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
...@@ -74,41 +85,101 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") ...@@ -74,41 +85,101 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
class OpenAIConfigUpdateForm(BaseModel): class TTSConfigForm(BaseModel):
url: str OPENAI_API_BASE_URL: str
key: str OPENAI_API_KEY: str
model: str ENGINE: str
speaker: str MODEL: str
VOICE: str
class STTConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
ENGINE: str
MODEL: str
class AudioConfigUpdateForm(BaseModel):
tts: TTSConfigForm
stt: STTConfigForm
from pydub import AudioSegment
from pydub.utils import mediainfo
def is_mp4_audio(file_path):
"""Check if the given file is an MP4 audio file."""
if not os.path.isfile(file_path):
print(f"File not found: {file_path}")
return False
info = mediainfo(file_path)
if (
info.get("codec_name") == "aac"
and info.get("codec_type") == "audio"
and info.get("codec_tag_string") == "mp4a"
):
return True
return False
def convert_mp4_to_wav(file_path, output_path):
"""Convert MP4 audio file to WAV format."""
audio = AudioSegment.from_file(file_path, format="mp4")
audio.export(output_path, format="wav")
print(f"Converted {file_path} to {output_path}")
@app.get("/config") @app.get("/config")
async def get_openai_config(user=Depends(get_admin_user)): async def get_audio_config(user=Depends(get_admin_user)):
return { return {
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, "tts": {
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
"OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, "ENGINE": app.state.config.TTS_ENGINE,
"MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
},
} }
@app.post("/config/update") @app.post("/config/update")
async def update_openai_config( async def update_audio_config(
form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user) form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
): ):
if form_data.key == "": app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
app.state.config.TTS_ENGINE = form_data.tts.ENGINE
app.state.config.TTS_MODEL = form_data.tts.MODEL
app.state.config.TTS_VOICE = form_data.tts.VOICE
app.state.config.OPENAI_API_BASE_URL = form_data.url app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = form_data.key app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
app.state.config.OPENAI_API_MODEL = form_data.model app.state.config.STT_ENGINE = form_data.stt.ENGINE
app.state.config.OPENAI_API_VOICE = form_data.speaker app.state.config.STT_MODEL = form_data.stt.MODEL
return { return {
"status": True, "tts": {
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, "ENGINE": app.state.config.TTS_ENGINE,
"OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, "MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
},
} }
...@@ -125,13 +196,21 @@ async def speech(request: Request, user=Depends(get_verified_user)): ...@@ -125,13 +196,21 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}" headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
try:
body = body.decode("utf-8")
body = json.loads(body)
body["model"] = app.state.config.TTS_MODEL
body = json.dumps(body).encode("utf-8")
except Exception as e:
pass
r = None r = None
try: try:
r = requests.post( r = requests.post(
url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech", url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
data=body, data=body,
headers=headers, headers=headers,
stream=True, stream=True,
...@@ -181,13 +260,23 @@ def transcribe( ...@@ -181,13 +260,23 @@ def transcribe(
) )
try: try:
filename = file.filename ext = file.filename.split(".")[-1]
file_path = f"{UPLOAD_DIR}/{filename}"
id = uuid.uuid4()
filename = f"{id}.{ext}"
file_dir = f"{CACHE_DIR}/audio/transcriptions"
os.makedirs(file_dir, exist_ok=True)
file_path = f"{file_dir}/{filename}"
print(filename)
contents = file.file.read() contents = file.file.read()
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(contents) f.write(contents)
f.close() f.close()
if app.state.config.STT_ENGINE == "":
whisper_kwargs = { whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL, "model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type, "device": whisper_device_type,
...@@ -215,7 +304,66 @@ def transcribe( ...@@ -215,7 +304,66 @@ def transcribe(
transcript = "".join([segment.text for segment in list(segments)]) transcript = "".join([segment.text for segment in list(segments)])
return {"text": transcript.strip()} data = {"text": transcript.strip()}
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
print(data)
return data
elif app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path):
print("is_mp4_audio")
os.rename(file_path, file_path.replace(".wav", ".mp4"))
# Convert MP4 audio file to WAV format
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
files = {"file": (filename, open(file_path, "rb"))}
data = {"model": "whisper-1"}
print(files, data)
r = None
try:
r = requests.post(
url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
headers=headers,
files=files,
data=data,
)
r.raise_for_status()
data = r.json()
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
print(data)
return data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r != None else 500,
detail=error_detail,
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
......
...@@ -41,13 +41,12 @@ from utils.utils import ( ...@@ -41,13 +41,12 @@ from utils.utils import (
get_admin_user, get_admin_user,
) )
from utils.models import get_model_id_from_custom_model_id
from config import ( from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
OLLAMA_BASE_URLS, OLLAMA_BASE_URLS,
ENABLE_OLLAMA_API, ENABLE_OLLAMA_API,
AIOHTTP_CLIENT_TIMEOUT,
ENABLE_MODEL_FILTER, ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
UPLOAD_DIR, UPLOAD_DIR,
...@@ -156,7 +155,9 @@ async def cleanup_response( ...@@ -156,7 +155,9 @@ async def cleanup_response(
async def post_streaming_url(url: str, payload: str): async def post_streaming_url(url: str, payload: str):
r = None r = None
try: try:
session = aiohttp.ClientSession(trust_env=True) session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
r = await session.post(url, data=payload) r = await session.post(url, data=payload)
r.raise_for_status() r.raise_for_status()
...@@ -728,7 +729,6 @@ async def generate_chat_completion( ...@@ -728,7 +729,6 @@ async def generate_chat_completion(
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
print(model_info)
if model_info.base_model_id: if model_info.base_model_id:
payload["model"] = model_info.base_model_id payload["model"] = model_info.base_model_id
...@@ -754,6 +754,14 @@ async def generate_chat_completion( ...@@ -754,6 +754,14 @@ async def generate_chat_completion(
if model_info.params.get("num_ctx", None): if model_info.params.get("num_ctx", None):
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None) payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
if model_info.params.get("num_batch", None):
payload["options"]["num_batch"] = model_info.params.get(
"num_batch", None
)
if model_info.params.get("num_keep", None):
payload["options"]["num_keep"] = model_info.params.get("num_keep", None)
if model_info.params.get("repeat_last_n", None): if model_info.params.get("repeat_last_n", None):
payload["options"]["repeat_last_n"] = model_info.params.get( payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None "repeat_last_n", None
...@@ -764,7 +772,7 @@ async def generate_chat_completion( ...@@ -764,7 +772,7 @@ async def generate_chat_completion(
"frequency_penalty", None "frequency_penalty", None
) )
if model_info.params.get("temperature", None): if model_info.params.get("temperature", None) is not None:
payload["options"]["temperature"] = model_info.params.get( payload["options"]["temperature"] = model_info.params.get(
"temperature", None "temperature", None
) )
...@@ -849,9 +857,14 @@ async def generate_chat_completion( ...@@ -849,9 +857,14 @@ async def generate_chat_completion(
# TODO: we should update this part once Ollama supports other types # TODO: we should update this part once Ollama supports other types
class OpenAIChatMessageContent(BaseModel):
type: str
model_config = ConfigDict(extra="allow")
class OpenAIChatMessage(BaseModel): class OpenAIChatMessage(BaseModel):
role: str role: str
content: str content: Union[str, OpenAIChatMessageContent]
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra="allow")
...@@ -879,7 +892,6 @@ async def generate_openai_chat_completion( ...@@ -879,7 +892,6 @@ async def generate_openai_chat_completion(
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
print(model_info)
if model_info.base_model_id: if model_info.base_model_id:
payload["model"] = model_info.base_model_id payload["model"] = model_info.base_model_id
......
...@@ -345,46 +345,34 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use ...@@ -345,46 +345,34 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
) )
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.post("/chat/completions")
async def proxy(path: str, request: Request, user=Depends(get_verified_user)): @app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
idx = 0 idx = 0
payload = {**form_data}
body = await request.body() model_id = form_data.get("model")
# TODO: Remove below after gpt-4-vision fix from Open AI
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
payload = None
try:
if "chat/completions" in path:
body = body.decode("utf-8")
body = json.loads(body)
payload = {**body}
model_id = body.get("model")
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
print(model_info)
if model_info.base_model_id: if model_info.base_model_id:
payload["model"] = model_info.base_model_id payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump() model_info.params = model_info.params.model_dump()
if model_info.params: if model_info.params:
if model_info.params.get("temperature", None): if model_info.params.get("temperature", None) is not None:
payload["temperature"] = int( payload["temperature"] = float(model_info.params.get("temperature"))
model_info.params.get("temperature")
)
if model_info.params.get("top_p", None): if model_info.params.get("top_p", None):
payload["top_p"] = int(model_info.params.get("top_p", None)) payload["top_p"] = int(model_info.params.get("top_p", None))
if model_info.params.get("max_tokens", None): if model_info.params.get("max_tokens", None):
payload["max_tokens"] = int( payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
model_info.params.get("max_tokens", None)
)
if model_info.params.get("frequency_penalty", None): if model_info.params.get("frequency_penalty", None):
payload["frequency_penalty"] = int( payload["frequency_penalty"] = int(
...@@ -411,8 +399,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -411,8 +399,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
for message in payload["messages"]: for message in payload["messages"]:
if message.get("role") == "system": if message.get("role") == "system":
message["content"] = ( message["content"] = (
model_info.params.get("system", None) model_info.params.get("system", None) + message["content"]
+ message["content"]
) )
break break
else: else:
...@@ -423,11 +410,11 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -423,11 +410,11 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
"content": model_info.params.get("system", None), "content": model_info.params.get("system", None),
}, },
) )
else: else:
pass pass
model = app.state.MODELS[payload.get("model")] model = app.state.MODELS[payload.get("model")]
idx = model["urlIdx"] idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"): if "pipeline" in model and model.get("pipeline"):
...@@ -443,11 +430,71 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -443,11 +430,71 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
# Convert the modified body back to JSON # Convert the modified body back to JSON
payload = json.dumps(payload) payload = json.dumps(payload)
except json.JSONDecodeError as e: print(payload)
log.error("Error loading request body into a dictionary:", e)
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
print(payload) print(payload)
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
session = None
streaming = False
try:
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method="POST",
url=f"{url}/chat/completions",
data=payload,
headers=headers,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
)
else:
response_data = await r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = await r.json()
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:
if not streaming and session:
if r:
r.close()
await session.close()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
idx = 0
body = await request.body()
url = app.state.config.OPENAI_API_BASE_URLS[idx] url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx] key = app.state.config.OPENAI_API_KEYS[idx]
...@@ -466,7 +513,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -466,7 +513,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
r = await session.request( r = await session.request(
method=request.method, method=request.method,
url=target_url, url=target_url,
data=payload if payload else body, data=body,
headers=headers, headers=headers,
) )
......
...@@ -8,12 +8,15 @@ from fastapi import ( ...@@ -8,12 +8,15 @@ from fastapi import (
Form, Form,
) )
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import requests
import os, shutil, logging, re import os, shutil, logging, re
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import List, Union, Sequence from typing import List, Union, Sequence, Iterator, Any
from chromadb.utils.batch_utils import create_batches from chromadb.utils.batch_utils import create_batches
from langchain_core.documents import Document
from langchain_community.document_loaders import ( from langchain_community.document_loaders import (
WebBaseLoader, WebBaseLoader,
...@@ -30,6 +33,7 @@ from langchain_community.document_loaders import ( ...@@ -30,6 +33,7 @@ from langchain_community.document_loaders import (
UnstructuredExcelLoader, UnstructuredExcelLoader,
UnstructuredPowerPointLoader, UnstructuredPowerPointLoader,
YoutubeLoader, YoutubeLoader,
OutlookMessageLoader,
) )
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
...@@ -67,7 +71,9 @@ from apps.rag.search.main import SearchResult ...@@ -67,7 +71,9 @@ from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.serply import search_serply
from apps.rag.search.duckduckgo import search_duckduckgo
from apps.rag.search.tavily import search_tavily
from utils.misc import ( from utils.misc import (
calculate_sha256, calculate_sha256,
...@@ -113,6 +119,8 @@ from config import ( ...@@ -113,6 +119,8 @@ from config import (
SERPSTACK_API_KEY, SERPSTACK_API_KEY,
SERPSTACK_HTTPS, SERPSTACK_HTTPS,
SERPER_API_KEY, SERPER_API_KEY,
SERPLY_API_KEY,
TAVILY_API_KEY,
RAG_WEB_SEARCH_RESULT_COUNT, RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS, RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
RAG_EMBEDDING_OPENAI_BATCH_SIZE, RAG_EMBEDDING_OPENAI_BATCH_SIZE,
...@@ -165,6 +173,8 @@ app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY ...@@ -165,6 +173,8 @@ app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
...@@ -392,6 +402,8 @@ async def get_rag_config(user=Depends(get_admin_user)): ...@@ -392,6 +402,8 @@ async def get_rag_config(user=Depends(get_admin_user)):
"serpstack_api_key": app.state.config.SERPSTACK_API_KEY, "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
"serpstack_https": app.state.config.SERPSTACK_HTTPS, "serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY, "serper_api_key": app.state.config.SERPER_API_KEY,
"serply_api_key": app.state.config.SERPLY_API_KEY,
"tavily_api_key": app.state.config.TAVILY_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
}, },
...@@ -419,6 +431,8 @@ class WebSearchConfig(BaseModel): ...@@ -419,6 +431,8 @@ class WebSearchConfig(BaseModel):
serpstack_api_key: Optional[str] = None serpstack_api_key: Optional[str] = None
serpstack_https: Optional[bool] = None serpstack_https: Optional[bool] = None
serper_api_key: Optional[str] = None serper_api_key: Optional[str] = None
serply_api_key: Optional[str] = None
tavily_api_key: Optional[str] = None
result_count: Optional[int] = None result_count: Optional[int] = None
concurrent_requests: Optional[int] = None concurrent_requests: Optional[int] = None
...@@ -469,6 +483,8 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ ...@@ -469,6 +483,8 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests form_data.web.search.concurrent_requests
...@@ -497,6 +513,8 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ ...@@ -497,6 +513,8 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"serpstack_api_key": app.state.config.SERPSTACK_API_KEY, "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
"serpstack_https": app.state.config.SERPSTACK_HTTPS, "serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY, "serper_api_key": app.state.config.SERPER_API_KEY,
"serply_api_key": app.state.config.SERPLY_API_KEY,
"tavily_api_key": app.state.config.TAVILY_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
}, },
...@@ -693,7 +711,7 @@ def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True): ...@@ -693,7 +711,7 @@ def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
# Check if the URL is valid # Check if the URL is valid
if not validate_url(url): if not validate_url(url):
raise ValueError(ERROR_MESSAGES.INVALID_URL) raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader( return SafeWebBaseLoader(
url, url,
verify_ssl=verify_ssl, verify_ssl=verify_ssl,
requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS, requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
...@@ -744,7 +762,8 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -744,7 +762,8 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
- BRAVE_SEARCH_API_KEY - BRAVE_SEARCH_API_KEY
- SERPSTACK_API_KEY - SERPSTACK_API_KEY
- SERPER_API_KEY - SERPER_API_KEY
- SERPLY_API_KEY
- TAVILY_API_KEY
Args: Args:
query (str): The query to search for query (str): The query to search for
""" """
...@@ -802,6 +821,26 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -802,6 +821,26 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
) )
else: else:
raise Exception("No SERPER_API_KEY found in environment variables") raise Exception("No SERPER_API_KEY found in environment variables")
elif engine == "serply":
if app.state.config.SERPLY_API_KEY:
return search_serply(
app.state.config.SERPLY_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo":
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
elif engine == "tavily":
if app.state.config.TAVILY_API_KEY:
return search_tavily(
app.state.config.TAVILY_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
else: else:
raise Exception("No search engine API key found in environment variables") raise Exception("No search engine API key found in environment variables")
...@@ -809,6 +848,9 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -809,6 +848,9 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
@app.post("/web/search") @app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)): def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
try: try:
logging.info(
f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
)
web_results = search_web( web_results = search_web(
app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
) )
...@@ -879,6 +921,13 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b ...@@ -879,6 +921,13 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
texts = [doc.page_content for doc in docs] texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs] metadatas = [doc.metadata for doc in docs]
# ChromaDB does not like datetime formats
# for meta-data so convert them to string.
for metadata in metadatas:
for key, value in metadata.items():
if isinstance(value, datetime):
metadata[key] = str(value)
try: try:
if overwrite: if overwrite:
for collection in CHROMA_CLIENT.list_collections(): for collection in CHROMA_CLIENT.list_collections():
...@@ -965,6 +1014,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str): ...@@ -965,6 +1014,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"swift", "swift",
"vue", "vue",
"svelte", "svelte",
"msg",
] ]
if file_ext == "pdf": if file_ext == "pdf":
...@@ -999,6 +1049,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str): ...@@ -999,6 +1049,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"application/vnd.openxmlformats-officedocument.presentationml.presentation", "application/vnd.openxmlformats-officedocument.presentationml.presentation",
] or file_ext in ["ppt", "pptx"]: ] or file_ext in ["ppt", "pptx"]:
loader = UnstructuredPowerPointLoader(file_path) loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg":
loader = OutlookMessageLoader(file_path)
elif file_ext in known_source_ext or ( elif file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0 file_content_type and file_content_type.find("text/") >= 0
): ):
...@@ -1209,6 +1261,33 @@ def reset(user=Depends(get_admin_user)) -> bool: ...@@ -1209,6 +1261,33 @@ def reset(user=Depends(get_admin_user)) -> bool:
return True return True
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
def lazy_load(self) -> Iterator[Document]:
"""Lazy load text from the url(s) in web_path with error handling."""
for path in self.web_paths:
try:
soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
text = soup.get_text(**self.bs_get_text_kwargs)
# Build metadata
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
yield Document(page_content=text, metadata=metadata)
except Exception as e:
# Log the error and continue with the next URL
log.error(f"Error loading {path}: {e}")
if ENV == "dev": if ENV == "dev":
@app.get("/ef") @app.get("/ef")
......
import logging
from apps.rag.search.main import SearchResult
from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args:
query (str): The query to search for
count (int): The number of results to return
Returns:
List[SearchResult]: A list of search results
"""
# Use the DDGS context manager to create a DDGS object
with DDGS() as ddgs:
# Use the ddgs.text() method to perform the search
ddgs_gen = ddgs.text(
query, safesearch="moderate", max_results=count, backend="api"
)
# Check if there are search results
if ddgs_gen:
# Convert the search results into a list
search_results = [r for r in ddgs_gen]
# Create an empty list to store the SearchResult objects
results = []
# Iterate over each search result
for result in search_results:
# Create a SearchResult object and append it to the results list
results.append(
SearchResult(
link=result["href"],
title=result.get("title"),
snippet=result.get("body"),
)
)
print(results)
# Return the list of search results
return results
...@@ -25,6 +25,7 @@ def search_searxng( ...@@ -25,6 +25,7 @@ def search_searxng(
Keyword Args: Keyword Args:
language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string. language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''. time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
categories: (Optional[List[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided. categories: (Optional[List[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
...@@ -37,6 +38,7 @@ def search_searxng( ...@@ -37,6 +38,7 @@ def search_searxng(
# Default values for optional parameters are provided as empty strings or None when not specified. # Default values for optional parameters are provided as empty strings or None when not specified.
language = kwargs.get("language", "en-US") language = kwargs.get("language", "en-US")
safesearch = kwargs.get("safesearch", "1")
time_range = kwargs.get("time_range", "") time_range = kwargs.get("time_range", "")
categories = "".join(kwargs.get("categories", [])) categories = "".join(kwargs.get("categories", []))
...@@ -44,6 +46,7 @@ def search_searxng( ...@@ -44,6 +46,7 @@ def search_searxng(
"q": query, "q": query,
"format": "json", "format": "json",
"pageno": 1, "pageno": 1,
"safesearch": safesearch,
"language": language, "language": language,
"time_range": time_range, "time_range": time_range,
"categories": categories, "categories": categories,
......
import json
import logging
import requests
from urllib.parse import urlencode
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serply(
api_key: str,
query: str,
count: int,
hl: str = "us",
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US",
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
api_key (str): A serply.io API key
query (str): The query to search for
hl (str): Host Language code to display results in (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
limit (int): The maximum number of results to return [10-100, defaults to 10]
"""
log.info("Searching with Serply")
url = "https://api.serply.io/v1/search/"
query_payload = {
"q": query,
"language": "en",
"num": limit,
"gl": proxy_location.upper(),
"hl": hl.lower(),
}
url = f"{url}{urlencode(query_payload)}"
headers = {
"X-API-KEY": api_key,
"X-User-Agent": device_type,
"User-Agent": "open-webui",
"X-Proxy-Location": proxy_location,
}
response = requests.request("GET", url, headers=headers)
response.raise_for_status()
json_response = response.json()
log.info(f"results from serply search: {json_response}")
results = sorted(
json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("description"),
)
for result in results[:count]
]
import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
"""Search using Tavily's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Tavily Search API key
query (str): The query to search for
Returns:
List[SearchResult]: A list of search results
"""
url = "https://api.tavily.com/search"
data = {"query": query, "api_key": api_key}
response = requests.post(url, json=data)
response.raise_for_status()
json_response = response.json()
raw_search_results = json_response.get("results", [])
return [
SearchResult(
link=result["url"],
title=result.get("title", ""),
snippet=result.get("content"),
)
for result in raw_search_results[:count]
]
{
"ads": [],
"ads_count": 0,
"answers": [],
"results": [
{
"title": "Apple",
"link": "https://www.apple.com/",
"description": "Discover the innovative world of Apple and shop everything iPhone, iPad, Apple Watch, Mac, and Apple TV, plus explore accessories, entertainment, ...",
"additional_links": [
{
"text": "AppleApplehttps://www.apple.com",
"href": "https://www.apple.com/"
}
],
"cite": {},
"subdomains": [
{
"title": "Support",
"link": "https://support.apple.com/",
"description": "SupportContact - iPhone Support - Billing and Subscriptions - Apple Repair"
},
{
"title": "Store",
"link": "https://www.apple.com/store",
"description": "StoreShop iPhone - Shop iPad - App Store - Shop Mac - ..."
},
{
"title": "Mac",
"link": "https://www.apple.com/mac/",
"description": "MacMacBook Air - MacBook Pro - iMac - Compare Mac models - Mac mini"
},
{
"title": "iPad",
"link": "https://www.apple.com/ipad/",
"description": "iPadShop iPad - iPad Pro - iPad Air - Compare iPad models - ..."
},
{
"title": "Watch",
"link": "https://www.apple.com/watch/",
"description": "WatchShop Apple Watch - Series 9 - SE - Ultra 2 - Nike - Hermès - ..."
}
],
"realPosition": 1
},
{
"title": "Apple",
"link": "https://www.apple.com/",
"description": "Discover the innovative world of Apple and shop everything iPhone, iPad, Apple Watch, Mac, and Apple TV, plus explore accessories, entertainment, ...",
"additional_links": [
{
"text": "AppleApplehttps://www.apple.com",
"href": "https://www.apple.com/"
}
],
"cite": {},
"realPosition": 2
},
{
"title": "Apple Inc.",
"link": "https://en.wikipedia.org/wiki/Apple_Inc.",
"description": "Apple Inc. (formerly Apple Computer, Inc.) is an American multinational corporation and technology company headquartered in Cupertino, California, ...",
"additional_links": [
{
"text": "Apple Inc.Wikipediahttps://en.wikipedia.org › wiki › Apple_Inc",
"href": "https://en.wikipedia.org/wiki/Apple_Inc."
},
{
"text": "",
"href": "https://en.wikipedia.org/wiki/Apple_Inc."
},
{
"text": "History",
"href": "https://en.wikipedia.org/wiki/History_of_Apple_Inc."
},
{
"text": "List of Apple products",
"href": "https://en.wikipedia.org/wiki/List_of_Apple_products"
},
{
"text": "Litigation involving Apple Inc.",
"href": "https://en.wikipedia.org/wiki/Litigation_involving_Apple_Inc."
},
{
"text": "Apple Park",
"href": "https://en.wikipedia.org/wiki/Apple_Park"
}
],
"cite": {
"domain": "https://en.wikipedia.org › wiki › Apple_Inc",
"span": " › wiki › Apple_Inc"
},
"realPosition": 3
},
{
"title": "Apple Inc. (AAPL) Company Profile & Facts",
"link": "https://finance.yahoo.com/quote/AAPL/profile/",
"description": "Apple Inc. designs, manufactures, and markets smartphones, personal computers, tablets, wearables, and accessories worldwide. The company offers iPhone, a line ...",
"additional_links": [
{
"text": "Apple Inc. (AAPL) Company Profile & FactsYahoo Financehttps://finance.yahoo.com › quote › AAPL › profile",
"href": "https://finance.yahoo.com/quote/AAPL/profile/"
}
],
"cite": {
"domain": "https://finance.yahoo.com › quote › AAPL › profile",
"span": " › quote › AAPL › profile"
},
"realPosition": 4
},
{
"title": "Apple Inc - Company Profile and News",
"link": "https://www.bloomberg.com/profile/company/AAPL:US",
"description": "Apple Inc. Apple Inc. designs, manufactures, and markets smartphones, personal computers, tablets, wearables and accessories, and sells a variety of related ...",
"additional_links": [
{
"text": "Apple Inc - Company Profile and NewsBloomberghttps://www.bloomberg.com › company › AAPL:US",
"href": "https://www.bloomberg.com/profile/company/AAPL:US"
},
{
"text": "",
"href": "https://www.bloomberg.com/profile/company/AAPL:US"
}
],
"cite": {
"domain": "https://www.bloomberg.com › company › AAPL:US",
"span": " › company › AAPL:US"
},
"realPosition": 5
},
{
"title": "Apple Inc. | History, Products, Headquarters, & Facts",
"link": "https://www.britannica.com/money/Apple-Inc",
"description": "May 22, 2024 — Apple Inc. is an American multinational technology company that revolutionized the technology sector through its innovation of computer ...",
"additional_links": [
{
"text": "Apple Inc. | History, Products, Headquarters, & FactsBritannicahttps://www.britannica.com › money › Apple-Inc",
"href": "https://www.britannica.com/money/Apple-Inc"
},
{
"text": "",
"href": "https://www.britannica.com/money/Apple-Inc"
}
],
"cite": {
"domain": "https://www.britannica.com › money › Apple-Inc",
"span": " › money › Apple-Inc"
},
"realPosition": 6
}
],
"shopping_ads": [],
"places": [
{
"title": "Apple Inc."
},
{
"title": "Apple Inc"
},
{
"title": "Apple Inc"
}
],
"related_searches": {
"images": [],
"text": [
{
"title": "apple inc full form",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Inc+full+form&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhPEAE"
},
{
"title": "apple company history",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+company+history&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhOEAE"
},
{
"title": "apple store",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Store&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhQEAE"
},
{
"title": "apple id",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+id&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhSEAE"
},
{
"title": "apple inc industry",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Inc+industry&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhREAE"
},
{
"title": "apple login",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+login&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhTEAE"
}
]
},
"image_results": [],
"carousel": [],
"total": 2450000000,
"knowledge_graph": "",
"related_questions": [
"What does the Apple Inc do?",
"Why did Apple change to Apple Inc?",
"Who owns Apple Inc.?",
"What is Apple Inc best known for?"
],
"carousel_count": 0,
"ts": 2.491065263748169,
"device_type": null
}
...@@ -20,7 +20,7 @@ from langchain.retrievers import ( ...@@ -20,7 +20,7 @@ from langchain.retrievers import (
from typing import Optional from typing import Optional
from utils.misc import get_last_user_message, add_or_update_system_message
from config import SRC_LOG_LEVELS, CHROMA_CLIENT from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -236,10 +236,9 @@ def get_embedding_function( ...@@ -236,10 +236,9 @@ def get_embedding_function(
return lambda query: generate_multiple(query, func) return lambda query: generate_multiple(query, func)
def rag_messages( def get_rag_context(
docs, docs,
messages, messages,
template,
embedding_function, embedding_function,
k, k,
reranking_function, reranking_function,
...@@ -247,31 +246,7 @@ def rag_messages( ...@@ -247,31 +246,7 @@ def rag_messages(
hybrid_search, hybrid_search,
): ):
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}") log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
query = get_last_user_message(messages)
last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1):
if messages[i]["role"] == "user":
last_user_message_idx = i
break
user_message = messages[last_user_message_idx]
if isinstance(user_message["content"], list):
# Handle list content input
content_type = "list"
query = ""
for content_item in user_message["content"]:
if content_item["type"] == "text":
query = content_item["text"]
break
elif isinstance(user_message["content"], str):
# Handle text content input
content_type = "text"
query = user_message["content"]
else:
# Fallback in case the input does not match expected types
content_type = None
query = ""
extracted_collections = [] extracted_collections = []
relevant_contexts = [] relevant_contexts = []
...@@ -342,33 +317,7 @@ def rag_messages( ...@@ -342,33 +317,7 @@ def rag_messages(
context_string = context_string.strip() context_string = context_string.strip()
ra_content = rag_template( return context_string, citations
template=template,
context=context_string,
query=query,
)
log.debug(f"ra_content: {ra_content}")
if content_type == "list":
new_content = []
for content_item in user_message["content"]:
if content_item["type"] == "text":
# Update the text item's content with ra_content
new_content.append({"type": "text", "text": ra_content})
else:
# Keep other types of content as they are
new_content.append(content_item)
new_user_message = {**user_message, "content": new_content}
else:
new_user_message = {
**user_message,
"content": ra_content,
}
messages[last_user_message_idx] = new_user_message
return messages, citations
def get_model_path(model: str, update_model: bool = False): def get_model_path(model: str, update_model: bool = False):
......
...@@ -10,7 +10,7 @@ app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io") ...@@ -10,7 +10,7 @@ app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
# Dictionary to maintain the user pool # Dictionary to maintain the user pool
SESSION_POOL = {}
USER_POOL = {} USER_POOL = {}
USAGE_POOL = {} USAGE_POOL = {}
# Timeout duration in seconds # Timeout duration in seconds
...@@ -19,8 +19,6 @@ TIMEOUT_DURATION = 3 ...@@ -19,8 +19,6 @@ TIMEOUT_DURATION = 3
@sio.event @sio.event
async def connect(sid, environ, auth): async def connect(sid, environ, auth):
print("connect ", sid)
user = None user = None
if auth and "token" in auth: if auth and "token" in auth:
data = decode_token(auth["token"]) data = decode_token(auth["token"])
...@@ -29,10 +27,14 @@ async def connect(sid, environ, auth): ...@@ -29,10 +27,14 @@ async def connect(sid, environ, auth):
user = Users.get_user_by_id(data["id"]) user = Users.get_user_by_id(data["id"])
if user: if user:
USER_POOL[sid] = user.id SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
print(f"user {user.name}({user.id}) connected with session ID {sid}") print(f"user {user.name}({user.id}) connected with session ID {sid}")
print(len(set(USER_POOL)))
await sio.emit("user-count", {"count": len(set(USER_POOL))}) await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("usage", {"models": get_models_in_use()}) await sio.emit("usage", {"models": get_models_in_use()})
...@@ -50,16 +52,20 @@ async def user_join(sid, data): ...@@ -50,16 +52,20 @@ async def user_join(sid, data):
user = Users.get_user_by_id(data["id"]) user = Users.get_user_by_id(data["id"])
if user: if user:
USER_POOL[sid] = user.id
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
print(f"user {user.name}({user.id}) connected with session ID {sid}") print(f"user {user.name}({user.id}) connected with session ID {sid}")
print(len(set(USER_POOL)))
await sio.emit("user-count", {"count": len(set(USER_POOL))}) await sio.emit("user-count", {"count": len(set(USER_POOL))})
@sio.on("user-count") @sio.on("user-count")
async def user_count(sid): async def user_count(sid):
print("user-count", sid)
await sio.emit("user-count", {"count": len(set(USER_POOL))}) await sio.emit("user-count", {"count": len(set(USER_POOL))})
...@@ -68,14 +74,12 @@ def get_models_in_use(): ...@@ -68,14 +74,12 @@ def get_models_in_use():
models_in_use = [] models_in_use = []
for model_id, data in USAGE_POOL.items(): for model_id, data in USAGE_POOL.items():
models_in_use.append(model_id) models_in_use.append(model_id)
print(f"Models in use: {models_in_use}")
return models_in_use return models_in_use
@sio.on("usage") @sio.on("usage")
async def usage(sid, data): async def usage(sid, data):
print(f'Received "usage" event from {sid}: {data}')
model_id = data["model"] model_id = data["model"]
...@@ -103,7 +107,6 @@ async def usage(sid, data): ...@@ -103,7 +107,6 @@ async def usage(sid, data):
async def remove_after_timeout(sid, model_id): async def remove_after_timeout(sid, model_id):
try: try:
print("remove_after_timeout", sid, model_id)
await asyncio.sleep(TIMEOUT_DURATION) await asyncio.sleep(TIMEOUT_DURATION)
if model_id in USAGE_POOL: if model_id in USAGE_POOL:
print(USAGE_POOL[model_id]["sids"]) print(USAGE_POOL[model_id]["sids"])
...@@ -113,7 +116,6 @@ async def remove_after_timeout(sid, model_id): ...@@ -113,7 +116,6 @@ async def remove_after_timeout(sid, model_id):
if len(USAGE_POOL[model_id]["sids"]) == 0: if len(USAGE_POOL[model_id]["sids"]) == 0:
del USAGE_POOL[model_id] del USAGE_POOL[model_id]
print(f"Removed usage data for {model_id} due to timeout")
# Broadcast the usage data to all clients # Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()}) await sio.emit("usage", {"models": get_models_in_use()})
except asyncio.CancelledError: except asyncio.CancelledError:
...@@ -123,9 +125,14 @@ async def remove_after_timeout(sid, model_id): ...@@ -123,9 +125,14 @@ async def remove_after_timeout(sid, model_id):
@sio.event @sio.event
async def disconnect(sid): async def disconnect(sid):
if sid in USER_POOL: if sid in SESSION_POOL:
disconnected_user = USER_POOL.pop(sid) user_id = SESSION_POOL[sid]
print(f"user {disconnected_user} disconnected with session ID {sid}") del SESSION_POOL[sid]
USER_POOL[user_id].remove(sid)
if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id]
await sio.emit("user-count", {"count": len(USER_POOL)}) await sio.emit("user-count", {"count": len(USER_POOL)})
else: else:
......
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Tool(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
content = pw.TextField()
specs = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "tool"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("tool")
"""Peewee migrations -- 011_add_user_oauth_sub.py. """Peewee migrations -- 013_add_user_oauth_sub.py.
Some examples (model - class or model name):: Some examples (model - class or model name)::
......
...@@ -8,6 +8,7 @@ from apps.webui.routers import ( ...@@ -8,6 +8,7 @@ from apps.webui.routers import (
users, users,
chats, chats,
documents, documents,
tools,
models, models,
prompts, prompts,
configs, configs,
...@@ -27,9 +28,9 @@ from config import ( ...@@ -27,9 +28,9 @@ from config import (
WEBHOOK_URL, WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN, JWT_EXPIRES_IN,
AppConfig,
ENABLE_COMMUNITY_SHARING,
WEBUI_BANNERS, WEBUI_BANNERS,
ENABLE_COMMUNITY_SHARING,
AppConfig,
) )
app = FastAPI() app = FastAPI()
...@@ -40,6 +41,7 @@ app.state.config = AppConfig() ...@@ -40,6 +41,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
...@@ -56,7 +58,7 @@ app.state.config.BANNERS = WEBUI_BANNERS ...@@ -56,7 +58,7 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {} app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.TOOLS = {}
app.add_middleware( app.add_middleware(
...@@ -72,6 +74,7 @@ app.include_router(users.router, prefix="/users", tags=["users"]) ...@@ -72,6 +74,7 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"]) app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"]) app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"]) app.include_router(memories.router, prefix="/memories", tags=["memories"])
......
...@@ -65,6 +65,20 @@ class MemoriesTable: ...@@ -65,6 +65,20 @@ class MemoriesTable:
else: else:
return None return None
def update_memory_by_id(
self,
id: str,
content: str,
) -> Optional[MemoryModel]:
try:
memory = Memory.get(Memory.id == id)
memory.content = content
memory.updated_at = int(time.time())
memory.save()
return MemoryModel(**model_to_dict(memory))
except:
return None
def get_memories(self) -> List[MemoryModel]: def get_memories(self) -> List[MemoryModel]:
try: try:
memories = Memory.select() memories = Memory.select()
......
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tools DB Schema
####################
class Tool(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
class ToolMeta(BaseModel):
description: Optional[str] = None
class ToolModel(BaseModel):
id: str
user_id: str
name: str
content: str
specs: List[dict]
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class ToolResponse(BaseModel):
id: str
user_id: str
name: str
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ToolForm(BaseModel):
id: str
name: str
content: str
meta: ToolMeta
class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool.create(**tool.model_dump())
if result:
return tool
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try:
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
**updated,
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def delete_tool_by_id(self, id: str) -> bool:
try:
query = Tool.delete().where((Tool.id == id))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Tools = ToolsTable(DB)
...@@ -161,7 +161,7 @@ async def get_archived_session_user_chat_list( ...@@ -161,7 +161,7 @@ async def get_archived_session_user_chat_list(
############################ ############################
@router.post("/archive/all", response_model=List[ChatTitleIdResponse]) @router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user)): async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(user.id) return Chats.archive_all_chats_by_user_id(user.id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment