Unverified Commit 96a004d4 authored by Timothy Jaeryang Baek, committed by GitHub

Merge pull request #2921 from open-webui/dev

0.3.0
parents a8d80f93 1fa16d73
@@ -5,6 +5,39 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.3.0] - 2024-06-09
### Added
- **📚 Knowledge Support for Models**: Attach documents directly to models from the models workspace, enhancing the information available to each model.
- **🎙️ Hands-Free Voice Call Feature**: Initiate voice calls without needing to use your hands, making interactions more seamless.
- **📹 Video Call Feature**: Enable video calls with supported vision models like Llava and GPT-4o, adding a visual dimension to your communications.
- **🎛️ Enhanced UI for Voice Recording**: Improved user interface for the voice recording feature, making it more intuitive and user-friendly.
- **🌐 External STT Support**: Added support for external Speech-To-Text services, giving you more flexibility in choosing an STT provider.
- **⚙️ Unified Settings**: Consolidated settings including document settings under a new admin settings section for easier management.
- **🌑 Dark Mode Splash Screen**: A new splash screen for dark mode, ensuring a consistent and visually appealing experience for dark mode users.
- **📥 Upload Pipeline**: Directly upload pipelines from the admin settings > pipelines section, streamlining the pipeline management process.
- **🌍 Improved Language Support**: Enhanced support for Chinese and Ukrainian languages, better catering to a global user base.
### Fixed
- **🛠️ Playground Issue**: Fixed the playground not functioning properly, ensuring a smoother user experience.
- **🔥 Temperature Parameter Issue**: Corrected the issue where the temperature value '0' was not being passed correctly.
- **📝 Prompt Input Clearing**: Resolved prompt input textarea not being cleared right away, ensuring a clean slate for new inputs.
- **✨ Various UI Styling Issues**: Fixed numerous user interface styling problems for a more cohesive look.
- **👥 Active Users Display**: Fixed active users showing active sessions instead of actual users, now reflecting accurate user activity.
- **🌐 Community Platform Compatibility**: The Community Platform is back online and fully compatible with Open WebUI.
### Changed
- **📝 RAG Implementation**: Updated the RAG (Retrieval-Augmented Generation) implementation to use a system prompt for context, instead of overriding the user's prompt.
- **🔄 Settings Relocation**: Moved Models, Connections, Audio, and Images settings to the admin settings for better organization.
- **✍️ Improved Title Generation**: Enhanced the default prompt for title generation, yielding better results.
- **🔧 Backend Task Management**: Tasks like title generation and search query generation are now managed on the backend side and controlled only by the admin.
- **🔍 Editable Search Query Prompt**: You can now edit the search query generation prompt, offering more control over how queries are generated.
- **📏 Prompt Length Threshold**: Set the prompt length threshold for search query generation from the admin settings, giving more customization options.
- **📣 Settings Consolidation**: Merged the Banners admin setting with the Interface admin setting for a more streamlined settings area.
## [0.2.5] - 2024-06-05

### Added

...
@@ -146,10 +146,19 @@ docker run --rm --volume /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower --run-once open-webui

In the last part of the command, replace `open-webui` with your container name if it is different.
### Moving from Ollama WebUI to Open WebUI
Check our Migration Guide available in our [Open WebUI Documentation](https://docs.openwebui.com/migration/).
### Using the Dev Branch 🌙
> [!WARNING]
> The `:dev` branch contains the latest unstable features and changes. Use it at your own risk as it may have bugs or incomplete features.
If you want to try out the latest bleeding-edge features and are okay with occasional instability, you can use the `:dev` tag like this:
```bash
docker run -d -p 3000:8080 -v open-webui:/app/backend/data --name open-webui --restart always ghcr.io/open-webui/open-webui:dev
```
## What's Next? 🌟

Discover upcoming features on our roadmap in the [Open WebUI Documentation](https://docs.openwebui.com/roadmap/).

...
@@ -17,13 +17,12 @@ from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel

import uuid
import requests
import hashlib
from pathlib import Path
import json

from constants import ERROR_MESSAGES
from utils.utils import (
    decode_token,
@@ -41,10 +40,15 @@ from config import (
    WHISPER_MODEL_DIR,
    WHISPER_MODEL_AUTO_UPDATE,
    DEVICE_TYPE,
    AUDIO_STT_OPENAI_API_BASE_URL,
    AUDIO_STT_OPENAI_API_KEY,
    AUDIO_TTS_OPENAI_API_BASE_URL,
    AUDIO_TTS_OPENAI_API_KEY,
    AUDIO_STT_ENGINE,
    AUDIO_STT_MODEL,
    AUDIO_TTS_ENGINE,
    AUDIO_TTS_MODEL,
    AUDIO_TTS_VOICE,
    AppConfig,
)
@@ -61,10 +65,17 @@ app.add_middleware(
)

app.state.config = AppConfig()

app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL

app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
@@ -74,41 +85,101 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
class TTSConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    ENGINE: str
    MODEL: str
    VOICE: str


class STTConfigForm(BaseModel):
    OPENAI_API_BASE_URL: str
    OPENAI_API_KEY: str
    ENGINE: str
    MODEL: str


class AudioConfigUpdateForm(BaseModel):
    tts: TTSConfigForm
    stt: STTConfigForm
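For orientation, a minimal client sketch against the reworked nested config schema might look like the following. The `/audio` prefix, host, token, and key values are placeholders assumed for illustration, not taken from this diff:

```python
# Hypothetical admin request against the new nested audio config endpoint.
import requests

payload = {
    "tts": {
        "OPENAI_API_BASE_URL": "https://api.openai.com/v1",
        "OPENAI_API_KEY": "sk-placeholder",
        "ENGINE": "openai",
        "MODEL": "tts-1",
        "VOICE": "alloy",
    },
    "stt": {
        "OPENAI_API_BASE_URL": "https://api.openai.com/v1",
        "OPENAI_API_KEY": "sk-placeholder",
        "ENGINE": "openai",
        "MODEL": "whisper-1",
    },
}

r = requests.post(
    "http://localhost:8080/audio/config/update",
    json=payload,
    headers={"Authorization": "Bearer <admin-token>"},
)
print(r.json())
```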
from pydub import AudioSegment
from pydub.utils import mediainfo


def is_mp4_audio(file_path):
    """Check if the given file is an MP4 audio file."""
    if not os.path.isfile(file_path):
        print(f"File not found: {file_path}")
        return False

    info = mediainfo(file_path)
    if (
        info.get("codec_name") == "aac"
        and info.get("codec_type") == "audio"
        and info.get("codec_tag_string") == "mp4a"
    ):
        return True
    return False


def convert_mp4_to_wav(file_path, output_path):
    """Convert MP4 audio file to WAV format."""
    audio = AudioSegment.from_file(file_path, format="mp4")
    audio.export(output_path, format="wav")
    print(f"Converted {file_path} to {output_path}")
@app.get("/config") @app.get("/config")
async def get_openai_config(user=Depends(get_admin_user)): async def get_audio_config(user=Depends(get_admin_user)):
return { return {
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, "tts": {
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
"OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, "ENGINE": app.state.config.TTS_ENGINE,
"MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
},
} }
@app.post("/config/update") @app.post("/config/update")
async def update_openai_config( async def update_audio_config(
form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user) form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
): ):
if form_data.key == "": app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
app.state.config.TTS_ENGINE = form_data.tts.ENGINE
app.state.config.TTS_MODEL = form_data.tts.MODEL
app.state.config.TTS_VOICE = form_data.tts.VOICE
app.state.config.OPENAI_API_BASE_URL = form_data.url app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = form_data.key app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
app.state.config.OPENAI_API_MODEL = form_data.model app.state.config.STT_ENGINE = form_data.stt.ENGINE
app.state.config.OPENAI_API_VOICE = form_data.speaker app.state.config.STT_MODEL = form_data.stt.MODEL
return { return {
"status": True, "tts": {
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY, "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL, "ENGINE": app.state.config.TTS_ENGINE,
"OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE, "MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
},
} }
@@ -125,13 +196,21 @@ async def speech(request: Request, user=Depends(get_verified_user)):
        return FileResponse(file_path)

    headers = {}
    headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
    headers["Content-Type"] = "application/json"

    try:
        body = body.decode("utf-8")
        body = json.loads(body)
        body["model"] = app.state.config.TTS_MODEL
        body = json.dumps(body).encode("utf-8")
    except Exception as e:
        # body was not valid JSON; forward it unchanged
        pass

    r = None
    try:
        r = requests.post(
            url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
@@ -181,41 +260,110 @@ def transcribe(
    )
    try:
        ext = file.filename.split(".")[-1]

        id = uuid.uuid4()
        filename = f"{id}.{ext}"

        file_dir = f"{CACHE_DIR}/audio/transcriptions"
        os.makedirs(file_dir, exist_ok=True)
        file_path = f"{file_dir}/{filename}"

        print(filename)

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)
            f.close()
        if app.state.config.STT_ENGINE == "":
            whisper_kwargs = {
                "model_size_or_path": WHISPER_MODEL,
                "device": whisper_device_type,
                "compute_type": "int8",
                "download_root": WHISPER_MODEL_DIR,
                "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
            }

            log.debug(f"whisper_kwargs: {whisper_kwargs}")

            try:
                model = WhisperModel(**whisper_kwargs)
            except:
                log.warning(
                    "WhisperModel initialization failed, attempting download with local_files_only=False"
                )
                whisper_kwargs["local_files_only"] = False
                model = WhisperModel(**whisper_kwargs)

            segments, info = model.transcribe(file_path, beam_size=5)
            log.info(
                "Detected language '%s' with probability %f"
                % (info.language, info.language_probability)
            )

            transcript = "".join([segment.text for segment in list(segments)])

            data = {"text": transcript.strip()}

            # save the transcript to a json file
            transcript_file = f"{file_dir}/{id}.json"
            with open(transcript_file, "w") as f:
                json.dump(data, f)

            print(data)
            return data
        elif app.state.config.STT_ENGINE == "openai":
            if is_mp4_audio(file_path):
                print("is_mp4_audio")
                os.rename(file_path, file_path.replace(".wav", ".mp4"))
                # Convert MP4 audio file to WAV format
                convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)

            headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}

            files = {"file": (filename, open(file_path, "rb"))}
            data = {"model": "whisper-1"}

            print(files, data)

            r = None
            try:
                r = requests.post(
                    url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
                    headers=headers,
                    files=files,
                    data=data,
                )

                r.raise_for_status()

                data = r.json()

                # save the transcript to a json file
                transcript_file = f"{file_dir}/{id}.json"
                with open(transcript_file, "w") as f:
                    json.dump(data, f)

                print(data)
                return data
            except Exception as e:
                log.exception(e)
                error_detail = "Open WebUI: Server Connection Error"
                if r is not None:
                    try:
                        res = r.json()
                        if "error" in res:
                            error_detail = f"External: {res['error']['message']}"
                    except:
                        error_detail = f"External: {e}"

                raise HTTPException(
                    status_code=r.status_code if r is not None else 500,
                    detail=error_detail,
                )
    except Exception as e:
        log.exception(e)
...
@@ -41,8 +41,6 @@ from utils.utils import (
    get_admin_user,
)

from config import (
    SRC_LOG_LEVELS,
@@ -728,7 +726,6 @@ async def generate_chat_completion(
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id
@@ -764,7 +761,7 @@ async def generate_chat_completion(
                    "frequency_penalty", None
                )

            if model_info.params.get("temperature", None) is not None:
                payload["options"]["temperature"] = model_info.params.get(
                    "temperature", None
                )
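The one-character change above is the changelog's temperature fix: `0` is falsy in Python, so the old truthiness check silently dropped an explicit temperature of zero. A minimal illustration:

```python
params = {"temperature": 0}

if params.get("temperature"):  # old check: 0 is falsy, so this branch is skipped
    print("never reached for temperature=0")

if params.get("temperature") is not None:  # new check: only a missing value is skipped
    print("temperature=0 is forwarded")  # printed
```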
@@ -849,9 +846,14 @@ async def generate_chat_completion(
# TODO: we should update this part once Ollama supports other types
class OpenAIChatMessageContent(BaseModel):
    type: str
    model_config = ConfigDict(extra="allow")


class OpenAIChatMessage(BaseModel):
    role: str
    content: Union[str, OpenAIChatMessageContent]

    model_config = ConfigDict(extra="allow")
@@ -879,7 +881,6 @@ async def generate_openai_chat_completion(
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id
...
@@ -345,108 +345,155 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
    )


@app.post("/chat/completions")
@app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    idx = 0
    payload = {**form_data}

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        model_info.params = model_info.params.model_dump()

        if model_info.params:
            if model_info.params.get("temperature", None) is not None:
                payload["temperature"] = float(model_info.params.get("temperature"))

            if model_info.params.get("top_p", None):
                payload["top_p"] = int(model_info.params.get("top_p", None))

            if model_info.params.get("max_tokens", None):
                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))

            if model_info.params.get("frequency_penalty", None):
                payload["frequency_penalty"] = int(
                    model_info.params.get("frequency_penalty", None)
                )

            if model_info.params.get("seed", None):
                payload["seed"] = model_info.params.get("seed", None)

            if model_info.params.get("stop", None):
                payload["stop"] = (
                    [
                        bytes(stop, "utf-8").decode("unicode_escape")
                        for stop in model_info.params["stop"]
                    ]
                    if model_info.params.get("stop", None)
                    else None
                )

            if model_info.params.get("system", None):
                # Check if the payload already has a system message
                # If not, add a system message to the payload
                if payload.get("messages"):
                    for message in payload["messages"]:
                        if message.get("role") == "system":
                            message["content"] = (
                                model_info.params.get("system", None)
                                + message["content"]
                            )
                            break
                    else:
                        payload["messages"].insert(
                            0,
                            {
                                "role": "system",
                                "content": model_info.params.get("system", None),
                            },
                        )
    else:
        pass

    model = app.state.MODELS[payload.get("model")]
    idx = model["urlIdx"]

    if "pipeline" in model and model.get("pipeline"):
        payload["user"] = {"name": user.name, "id": user.id}

    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
    # This is a workaround until OpenAI fixes the issue with this model
    if payload.get("model") == "gpt-4-vision-preview":
        if "max_tokens" not in payload:
            payload["max_tokens"] = 4000
        log.debug("Modified payload:", payload)

    # Convert the modified body back to JSON
    payload = json.dumps(payload)

    print(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]

    print(payload)

    headers = {}
    headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None
    session = None
    streaming = False

    try:
        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method="POST",
            url=f"{url}/chat/completions",
            data=payload,
            headers=headers,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = await r.json()
                print(res)
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except:
                error_detail = f"External: {e}"
        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
    finally:
        if not streaming and session:
            if r:
                r.close()
            await session.close()


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    idx = 0

    body = await request.body()

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]
@@ -466,7 +513,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
        r = await session.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
        )
...
@@ -9,6 +9,7 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
import os, shutil, logging, re

from datetime import datetime
from pathlib import Path
from typing import List, Union, Sequence
@@ -30,6 +31,7 @@ from langchain_community.document_loaders import (
    UnstructuredExcelLoader,
    UnstructuredPowerPointLoader,
    YoutubeLoader,
    OutlookMessageLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -879,6 +881,13 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    # ChromaDB does not like datetime formats
    # for meta-data so convert them to string.
    for metadata in metadatas:
        for key, value in metadata.items():
            if isinstance(value, datetime):
                metadata[key] = str(value)

    try:
        if overwrite:
            for collection in CHROMA_CLIENT.list_collections():
@@ -965,6 +974,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
        "swift",
        "vue",
        "svelte",
        "msg",
    ]

    if file_ext == "pdf":
@@ -999,6 +1009,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    ] or file_ext in ["ppt", "pptx"]:
        loader = UnstructuredPowerPointLoader(file_path)
    elif file_ext == "msg":
        loader = OutlookMessageLoader(file_path)
    elif file_ext in known_source_ext or (
        file_content_type and file_content_type.find("text/") >= 0
    ):
...
@@ -20,7 +20,7 @@ from langchain.retrievers import (
from typing import Optional

from utils.misc import get_last_user_message, add_or_update_system_message
from config import SRC_LOG_LEVELS, CHROMA_CLIENT

log = logging.getLogger(__name__)
@@ -247,31 +247,7 @@ def rag_messages(
    hybrid_search,
):
    log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")

    query = get_last_user_message(messages)

    extracted_collections = []
    relevant_contexts = []
@@ -349,24 +325,7 @@ def rag_messages(
    )

    log.debug(f"ra_content: {ra_content}")

    messages = add_or_update_system_message(ra_content, messages)

    return messages, citations
...
@@ -10,7 +10,7 @@ app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")

# Dictionary to maintain the user pool
SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {}

# Timeout duration in seconds
@@ -29,7 +29,12 @@ async def connect(sid, environ, auth):
        user = Users.get_user_by_id(data["id"])

    if user:
        SESSION_POOL[sid] = user.id
        if user.id in USER_POOL:
            USER_POOL[user.id].append(sid)
        else:
            USER_POOL[user.id] = [sid]

        print(f"user {user.name}({user.id}) connected with session ID {sid}")
        print(len(set(USER_POOL)))
@@ -50,7 +55,13 @@ async def user_join(sid, data):
    user = Users.get_user_by_id(data["id"])

    if user:
        SESSION_POOL[sid] = user.id
        if user.id in USER_POOL:
            USER_POOL[user.id].append(sid)
        else:
            USER_POOL[user.id] = [sid]

        print(f"user {user.name}({user.id}) connected with session ID {sid}")
        print(len(set(USER_POOL)))
@@ -123,9 +134,17 @@ async def remove_after_timeout(sid, model_id):
@sio.event
async def disconnect(sid):
    if sid in SESSION_POOL:
        user_id = SESSION_POOL[sid]
        del SESSION_POOL[sid]

        USER_POOL[user_id].remove(sid)
        if len(USER_POOL[user_id]) == 0:
            del USER_POOL[user_id]

        print(f"user {user_id} disconnected with session ID {sid}")
        print(USER_POOL)

        await sio.emit("user-count", {"count": len(USER_POOL)})
    else:
...
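This two-pool layout is what fixes the active-users count called out in the changelog: `SESSION_POOL` maps each socket session to a user, while `USER_POOL` groups sessions per user, so `len(USER_POOL)` now counts distinct users rather than sessions. A small illustration with made-up ids:

```python
# Made-up ids: one user with two open tabs, another with one.
SESSION_POOL = {"sid-1": "user-a", "sid-2": "user-a", "sid-3": "user-b"}
USER_POOL = {"user-a": ["sid-1", "sid-2"], "user-b": ["sid-3"]}

assert len(SESSION_POOL) == 3  # active sessions
assert len(USER_POOL) == 2     # active users, as now reported
```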
@@ -306,7 +306,10 @@ STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()

frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png"
if frontend_favicon.exists():
    try:
        shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
    except PermissionError:
        logging.error(f"No write permission to {STATIC_DIR / 'favicon.png'}")
else:
    logging.warning(f"Frontend favicon not found at {frontend_favicon}")
@@ -615,6 +618,66 @@ ADMIN_EMAIL = PersistentConfig(
)
####################################
# TASKS
####################################

TASK_MODEL = PersistentConfig(
    "TASK_MODEL",
    "task.model.default",
    os.environ.get("TASK_MODEL", ""),
)

TASK_MODEL_EXTERNAL = PersistentConfig(
    "TASK_MODEL_EXTERNAL",
    "task.model.external",
    os.environ.get("TASK_MODEL_EXTERNAL", ""),
)

TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "TITLE_GENERATION_PROMPT_TEMPLATE",
    "task.title.prompt_template",
    os.environ.get(
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        """Here is the query:
{{prompt:middletruncate:8000}}

Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.

Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights""",
    ),
)

SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
    "task.search.prompt_template",
    os.environ.get(
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
        """You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.

Question:
{{prompt:end:4000}}""",
    ),
)

SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
    "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
    "task.search.prompt_length_threshold",
    os.environ.get(
        "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
        100,
    ),
)
####################################
# WEBUI_SECRET_KEY
####################################
@@ -933,25 +996,59 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
####################################
# Audio
####################################

AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_STT_OPENAI_API_BASE_URL",
    "audio.stt.openai.api_base_url",
    os.getenv("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)

AUDIO_STT_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_STT_OPENAI_API_KEY",
    "audio.stt.openai.api_key",
    os.getenv("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY),
)

AUDIO_STT_ENGINE = PersistentConfig(
    "AUDIO_STT_ENGINE",
    "audio.stt.engine",
    os.getenv("AUDIO_STT_ENGINE", ""),
)

AUDIO_STT_MODEL = PersistentConfig(
    "AUDIO_STT_MODEL",
    "audio.stt.model",
    os.getenv("AUDIO_STT_MODEL", "whisper-1"),
)

AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_BASE_URL",
    "audio.tts.openai.api_base_url",
    os.getenv("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)

AUDIO_TTS_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_KEY",
    "audio.tts.openai.api_key",
    os.getenv("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY),
)

AUDIO_TTS_ENGINE = PersistentConfig(
    "AUDIO_TTS_ENGINE",
    "audio.tts.engine",
    os.getenv("AUDIO_TTS_ENGINE", ""),
)

AUDIO_TTS_MODEL = PersistentConfig(
    "AUDIO_TTS_MODEL",
    "audio.tts.model",
    os.getenv("AUDIO_TTS_MODEL", "tts-1"),
)

AUDIO_TTS_VOICE = PersistentConfig(
    "AUDIO_TTS_VOICE",
    "audio.tts.voice",
    os.getenv("AUDIO_TTS_VOICE", "alloy"),
)
...
@@ -9,8 +9,11 @@ import logging
import aiohttp
import requests
import mimetypes
import shutil
import os
import asyncio

from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi import HTTPException
@@ -22,15 +25,24 @@ from starlette.responses import StreamingResponse, Response

from apps.socket.main import app as socket_app
from apps.ollama.main import (
    app as ollama_app,
    OpenAIChatCompletionForm,
    get_all_models as get_ollama_models,
    generate_openai_chat_completion as generate_ollama_chat_completion,
)
from apps.openai.main import (
    app as openai_app,
    get_all_models as get_openai_models,
    generate_chat_completion as generate_openai_chat_completion,
)

from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.webui.main import app as webui_app

from pydantic import BaseModel
from typing import List, Optional
@@ -41,6 +53,8 @@ from utils.utils import (
    get_current_user,
    get_http_authorization_cred,
)
from utils.task import title_generation_template, search_query_generation_template

from apps.rag.utils import rag_messages
from config import (
@@ -62,8 +76,13 @@ from config import (
    SRC_LOG_LEVELS,
    WEBHOOK_URL,
    ENABLE_ADMIN_EXPORT,
    WEBUI_BUILD_HASH,
    TASK_MODEL,
    TASK_MODEL_EXTERNAL,
    TITLE_GENERATION_PROMPT_TEMPLATE,
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
    AppConfig,
)
from constants import ERROR_MESSAGES
@@ -117,10 +136,19 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.config.WEBHOOK_URL = WEBHOOK_URL

app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)

app.state.MODELS = {}

origins = ["*"]
@@ -228,6 +256,78 @@ class RAGMiddleware(BaseHTTPMiddleware):
app.add_middleware(RAGMiddleware)
def filter_pipeline(payload, user):
    user = {"id": user.id, "name": user.name, "role": user.role}
    model_id = payload["model"]
    filters = [
        model
        for model in app.state.MODELS.values()
        if "pipeline" in model
        and "type" in model["pipeline"]
        and model["pipeline"]["type"] == "filter"
        and (
            model["pipeline"]["pipelines"] == ["*"]
            or any(
                model_id == target_model_id
                for target_model_id in model["pipeline"]["pipelines"]
            )
        )
    ]
    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

    model = app.state.MODELS[model_id]

    if "pipeline" in model:
        sorted_filters.append(model)

    for filter in sorted_filters:
        r = None
        try:
            urlIdx = filter["urlIdx"]

            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

            if key != "":
                headers = {"Authorization": f"Bearer {key}"}
                r = requests.post(
                    f"{url}/{filter['id']}/filter/inlet",
                    headers=headers,
                    json={
                        "user": user,
                        "body": payload,
                    },
                )

                r.raise_for_status()
                payload = r.json()
        except Exception as e:
            # Handle connection error here
            print(f"Connection error: {e}")
            if r is not None:
                try:
                    res = r.json()
                    if "detail" in res:
                        return JSONResponse(
                            status_code=r.status_code,
                            content=res,
                        )
                except:
                    pass
            else:
                pass

    if "pipeline" not in app.state.MODELS[model_id]:
        if "chat_id" in payload:
            del payload["chat_id"]

        if "title" in payload:
            del payload["title"]

    return payload
class PipelineMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        if request.method == "POST" and (
@@ -243,85 +343,10 @@ class PipelineMiddleware(BaseHTTPMiddleware):
            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}

            user = get_current_user(
                get_http_authorization_cred(request.headers.get("Authorization"))
            )
            data = filter_pipeline(data, user)

            modified_body_bytes = json.dumps(data).encode("utf-8")
            # Replace the request body with the modified one
@@ -482,6 +507,178 @@ async def get_models(user=Depends(get_verified_user)):
    return {"data": models}
@app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)):
return {
"TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
}
class TaskConfigForm(BaseModel):
TASK_MODEL: Optional[str]
TASK_MODEL_EXTERNAL: Optional[str]
TITLE_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
@app.post("/api/task/config/update")
async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
app.state.config.TASK_MODEL = form_data.TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
return {
"TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
}
@app.post("/api/task/title/completions")
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
print("generate_title")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
content = title_generation_template(
template, form_data["prompt"], user.model_dump()
)
payload = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 50,
"chat_id": form_data.get("chat_id", None),
"title": True,
}
print(payload)
payload = filter_pipeline(payload, user)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
print("generate_search_query")
if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
)
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
content = search_query_generation_template(
template, form_data["prompt"], user.model_dump()
)
payload = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 30,
}
print(payload)
payload = filter_pipeline(payload, user)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
print(model)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**form_data), user=user
)
else:
return await generate_openai_chat_completion(form_data, user=user)
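For reference, a client could exercise this unified endpoint as sketched below; the endpoint routes to Ollama or an OpenAI-compatible backend based on the model's `owned_by` field. Host, token, and model name are placeholders:

```python
# Hypothetical client call to the new unified completions endpoint.
import requests

r = requests.post(
    "http://localhost:8080/api/chat/completions",
    headers={"Authorization": "Bearer <token>"},
    json={
        "model": "llama3",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    },
)
print(r.json())
```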
@app.post("/api/chat/completed") @app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)): async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data data = form_data
...@@ -574,6 +771,63 @@ async def get_pipelines_list(user=Depends(get_admin_user)): ...@@ -574,6 +771,63 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
} }
@app.post("/api/pipelines/upload")
async def upload_pipeline(
urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
):
print("upload_pipeline", urlIdx, file.filename)
# Check if the uploaded file is a python file
if not file.filename.endswith(".py"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only Python (.py) files are allowed.",
)
upload_folder = f"{CACHE_DIR}/pipelines"
os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename)
try:
# Save the uploaded file
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
headers = {"Authorization": f"Bearer {key}"}
with open(file_path, "rb") as f:
files = {"file": f}
r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
detail=detail,
)
finally:
# Ensure the file is deleted after the upload is completed or on failure
if os.path.exists(file_path):
os.remove(file_path)
class AddPipelineForm(BaseModel):
    url: str
    urlIdx: int
@@ -840,6 +1094,15 @@ async def get_app_config():
"enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING, "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_admin_export": ENABLE_ADMIN_EXPORT, "enable_admin_export": ENABLE_ADMIN_EXPORT,
}, },
"audio": {
"tts": {
"engine": audio_app.state.config.TTS_ENGINE,
"voice": audio_app.state.config.TTS_VOICE,
},
"stt": {
"engine": audio_app.state.config.STT_ENGINE,
},
},
} }
@@ -902,7 +1165,7 @@ async def get_app_changelog():
@app.get("/api/version/updates")
async def get_app_latest_release_version():
    try:
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(
                "https://api.github.com/repos/open-webui/open-webui/releases/latest"
            ) as response:
...
@@ -56,4 +56,7 @@ PyJWT[crypto]==2.8.0
black==24.4.2
langfuse==2.33.0
youtube-transcript-api==0.6.2
pytube==15.0.0

extract_msg
pydub
\ No newline at end of file
@@ -20,12 +20,12 @@ if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
    WEBUI_SECRET_KEY=$(cat "$KEY_FILE")
fi

if [[ "${USE_OLLAMA_DOCKER,,}" == "true" ]]; then
    echo "USE_OLLAMA is set to true, starting ollama serve."
    ollama serve &
fi

if [[ "${USE_CUDA_DOCKER,,}" == "true" ]]; then
    echo "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
    export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib"
fi
...
@@ -3,7 +3,48 @@ import hashlib
import json
import re
from datetime import timedelta
from typing import Optional, List
def get_last_user_message(messages: List[dict]) -> str:
    for message in reversed(messages):
        if message["role"] == "user":
            if isinstance(message["content"], list):
                for item in message["content"]:
                    if item["type"] == "text":
                        return item["text"]
            return message["content"]
    return None


def get_last_assistant_message(messages: List[dict]) -> str:
    for message in reversed(messages):
        if message["role"] == "assistant":
            if isinstance(message["content"], list):
                for item in message["content"]:
                    if item["type"] == "text":
                        return item["text"]
            return message["content"]
    return None
def add_or_update_system_message(content: str, messages: List[dict]):
    """
    Adds a new system message at the beginning of the messages list
    or prepends the content to an existing system message there.

    :param content: The content to be added or prepended.
    :param messages: The list of message dictionaries.
    :return: The updated list of message dictionaries.
    """

    if messages and messages[0].get("role") == "system":
        # Prepend the new content to the existing system message
        # (assignment, not +=, which would duplicate the original content)
        messages[0]["content"] = f"{content}\n{messages[0]['content']}"
    else:
        # Insert at the beginning
        messages.insert(0, {"role": "system", "content": content})

    return messages
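This helper is what the reworked RAG flow uses to carry retrieved context in a system message instead of overwriting the user's prompt, per the changelog. A minimal sketch:

```python
# Minimal sketch: inject retrieved context without touching the user prompt.
messages = [{"role": "user", "content": "What does the contract say about renewal?"}]
messages = add_or_update_system_message("Context:\n<retrieved chunks>", messages)
# messages[0] is now a system message carrying the context;
# messages[1] is the untouched user question.
```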
def get_gravatar_url(email):
@@ -193,8 +234,14 @@ def parse_ollama_modelfile(model_text):
    system_desc_match = re.search(
        r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
    )
    system_desc_match_single = re.search(
        r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
    )

    if system_desc_match:
        data["params"]["system"] = system_desc_match.group(1).strip()
    elif system_desc_match_single:
        data["params"]["system"] = system_desc_match_single.group(1).strip()

    # Parse messages
    messages = []
...
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse


def get_model_id_from_custom_model_id(id: str):
    model = Models.get_model_by_id(id)

    if model:
        return model.id
    else:
        return id
import re
import math

from datetime import datetime
from typing import Optional


def prompt_template(
    template: str, user_name: str = None, current_location: str = None
) -> str:
    # Get the current date
    current_date = datetime.now()

    # Format the date to YYYY-MM-DD
    formatted_date = current_date.strftime("%Y-%m-%d")

    # Replace {{CURRENT_DATE}} in the template with the formatted date
    template = template.replace("{{CURRENT_DATE}}", formatted_date)

    if user_name:
        # Replace {{USER_NAME}} in the template with the user's name
        template = template.replace("{{USER_NAME}}", user_name)

    if current_location:
        # Replace {{CURRENT_LOCATION}} in the template with the current location
        template = template.replace("{{CURRENT_LOCATION}}", current_location)

    return template


def title_generation_template(
    template: str, prompt: str, user: Optional[dict] = None
) -> str:
    def replacement_function(match):
        full_match = match.group(0)
        start_length = match.group(1)
        end_length = match.group(2)
        middle_length = match.group(3)

        if full_match == "{{prompt}}":
            return prompt
        elif start_length is not None:
            return prompt[: int(start_length)]
        elif end_length is not None:
            return prompt[-int(end_length) :]
        elif middle_length is not None:
            middle_length = int(middle_length)
            if len(prompt) <= middle_length:
                return prompt
            start = prompt[: math.ceil(middle_length / 2)]
            end = prompt[-math.floor(middle_length / 2) :]
            return f"{start}...{end}"
        return ""

    template = re.sub(
        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
        replacement_function,
        template,
    )

    template = prompt_template(
        template,
        **(
            {"user_name": user.get("name"), "current_location": user.get("location")}
            if user
            else {}
        ),
    )

    return template
def search_query_generation_template(
    template: str, prompt: str, user: Optional[dict] = None
) -> str:
    def replacement_function(match):
        full_match = match.group(0)
        start_length = match.group(1)
        end_length = match.group(2)
        middle_length = match.group(3)

        if full_match == "{{prompt}}":
            return prompt
        elif start_length is not None:
            return prompt[: int(start_length)]
        elif end_length is not None:
            return prompt[-int(end_length) :]
        elif middle_length is not None:
            middle_length = int(middle_length)
            if len(prompt) <= middle_length:
                return prompt
            start = prompt[: math.ceil(middle_length / 2)]
            end = prompt[-math.floor(middle_length / 2) :]
            return f"{start}...{end}"
        return ""

    template = re.sub(
        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
        replacement_function,
        template,
    )

    template = prompt_template(
        template,
        **(
            {"user_name": user.get("name"), "current_location": user.get("location")}
            if user
            else {}
        ),
    )

    return template
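To make the placeholder grammar concrete, here is a small worked example. The template and prompt are hypothetical, and the printed date depends on the current day:

```python
template = "Query ({{CURRENT_DATE}}): {{prompt:middletruncate:10}}"
print(title_generation_template(template, "The quick brown fox jumps"))
# -> e.g. "Query (2024-06-09): The q...jumps"
# middletruncate:10 keeps ceil(10/2)=5 leading and floor(10/2)=5 trailing characters.
```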
@@ -41,7 +41,7 @@ Looking to contribute? Great! Here's how you can help:
We welcome pull requests. Before submitting one, please:

1. Open a discussion regarding your ideas [here](https://github.com/open-webui/open-webui/discussions/new/choose).
2. Follow the project's coding standards and include tests for new features.
3. Update documentation as necessary.
4. Write clear, descriptive commit messages.
...
{
	"name": "open-webui",
	"version": "0.3.0",
	"lockfileVersion": 3,
	"requires": true,
	"packages": {
		"": {
			"name": "open-webui",
			"version": "0.3.0",
			"dependencies": {
				"@pyscript/core": "^0.4.32",
				"@sveltejs/adapter-node": "^1.3.1",
...
{
	"name": "open-webui",
	"version": "0.3.0",
	"private": true,
	"scripts": {
		"dev": "npm run pyodide:fetch && vite dev --host",
...
@@ -59,15 +59,7 @@
		<div
			id="splash-screen"
			style="position: fixed; z-index: 100; top: 0; left: 0; width: 100%; height: 100%"
		>
			<style type="text/css" nonce="">
				html {
@@ -93,3 +85,20 @@
		</div>
	</body>
</html>
<style type="text/css" nonce="">
html {
overflow-y: hidden !important;
}
#splash-screen {
background: #fff;
}
html.dark #splash-screen {
background: #000;
}
html.dark #splash-screen img {
filter: invert(1);
}
</style>
@@ -98,7 +98,7 @@ export const synthesizeOpenAISpeech = async (
	token: string = '',
	speaker: string = 'alloy',
	text: string = '',
	model?: string
) => {
	let error = null;
@@ -109,9 +109,9 @@ export const synthesizeOpenAISpeech = async (
			'Content-Type': 'application/json'
		},
		body: JSON.stringify({
			input: text,
			voice: speaker,
			...(model && { model })
		})
	})
		.then(async (res) => {
...