Unverified Commit a280fed4 authored by Simon's avatar Simon Committed by GitHub
Browse files

Merge branch 'open-webui:dev' into dev

parents aa8d2649 cf9b5241
...@@ -52,7 +52,6 @@ async def user_join(sid, data): ...@@ -52,7 +52,6 @@ async def user_join(sid, data):
user = Users.get_user_by_id(data["id"]) user = Users.get_user_by_id(data["id"])
if user: if user:
SESSION_POOL[sid] = user.id SESSION_POOL[sid] = user.id
if user.id in USER_POOL: if user.id in USER_POOL:
USER_POOL[user.id].append(sid) USER_POOL[user.id].append(sid)
...@@ -80,7 +79,6 @@ def get_models_in_use(): ...@@ -80,7 +79,6 @@ def get_models_in_use():
@sio.on("usage") @sio.on("usage")
async def usage(sid, data): async def usage(sid, data):
model_id = data["model"] model_id = data["model"]
# Cancel previous callback if there is one # Cancel previous callback if there is one
...@@ -139,7 +137,7 @@ async def disconnect(sid): ...@@ -139,7 +137,7 @@ async def disconnect(sid):
print(f"Unknown session ID {sid} disconnected") print(f"Unknown session ID {sid} disconnected")
async def get_event_emitter(request_info): def get_event_emitter(request_info):
async def __event_emitter__(event_data): async def __event_emitter__(event_data):
await sio.emit( await sio.emit(
"chat-events", "chat-events",
...@@ -154,7 +152,7 @@ async def get_event_emitter(request_info): ...@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
return __event_emitter__ return __event_emitter__
async def get_event_call(request_info): def get_event_call(request_info):
async def __event_call__(event_data): async def __event_call__(event_data):
response = await sio.call( response = await sio.call(
"chat-events", "chat-events",
......
This diff is collapsed.
import json
import logging import logging
from typing import Optional from typing import Optional, List
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger, Text from sqlalchemy import Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.internal.db import Base, JSONField, get_db
from typing import List, Union, Optional
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
import time import time
...@@ -113,7 +111,6 @@ class ModelForm(BaseModel): ...@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
class ModelsTable: class ModelsTable:
def insert_new_model( def insert_new_model(
self, form_data: ModelForm, user_id: str self, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]: ) -> Optional[ModelModel]:
...@@ -126,9 +123,7 @@ class ModelsTable: ...@@ -126,9 +123,7 @@ class ModelsTable:
} }
) )
try: try:
with get_db() as db: with get_db() as db:
result = Model(**model.model_dump()) result = Model(**model.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
...@@ -144,13 +139,11 @@ class ModelsTable: ...@@ -144,13 +139,11 @@ class ModelsTable:
def get_all_models(self) -> List[ModelModel]: def get_all_models(self) -> List[ModelModel]:
with get_db() as db: with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()] return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
with get_db() as db: with get_db() as db:
model = db.get(Model, id) model = db.get(Model, id)
return ModelModel.model_validate(model) return ModelModel.model_validate(model)
except: except:
...@@ -178,7 +171,6 @@ class ModelsTable: ...@@ -178,7 +171,6 @@ class ModelsTable:
def delete_model_by_id(self, id: str) -> bool: def delete_model_by_id(self, id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
db.query(Model).filter_by(id=id).delete() db.query(Model).filter_by(id=id).delete()
db.commit() db.commit()
......
from pathlib import Path
import site
from fastapi import APIRouter, UploadFile, File, Response from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from starlette.responses import StreamingResponse, FileResponse from starlette.responses import StreamingResponse, FileResponse
...@@ -64,8 +67,18 @@ async def download_chat_as_pdf( ...@@ -64,8 +67,18 @@ async def download_chat_as_pdf(
pdf = FPDF() pdf = FPDF()
pdf.add_page() pdf.add_page()
STATIC_DIR = "./static" # When running in docker, workdir is /app/backend, so fonts is in /app/backend/static/fonts
FONTS_DIR = f"{STATIC_DIR}/fonts" FONTS_DIR = Path("./static/fonts")
# Non Docker Installation
# When running using `pip install` the static directory is in the site packages.
if not FONTS_DIR.exists():
FONTS_DIR = Path(site.getsitepackages()[0]) / "static/fonts"
# When running using `pip install -e .` the static directory is in the site packages.
# This path only works if `open-webui serve` is run from the root of this project.
if not FONTS_DIR.exists():
FONTS_DIR = Path("./backend/static/fonts")
pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf") pdf.add_font("NotoSans", "", f"{FONTS_DIR}/NotoSans-Regular.ttf")
pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf") pdf.add_font("NotoSans", "b", f"{FONTS_DIR}/NotoSans-Bold.ttf")
......
...@@ -349,6 +349,12 @@ GOOGLE_OAUTH_SCOPE = PersistentConfig( ...@@ -349,6 +349,12 @@ GOOGLE_OAUTH_SCOPE = PersistentConfig(
os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"), os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
) )
GOOGLE_REDIRECT_URI = PersistentConfig(
"GOOGLE_REDIRECT_URI",
"oauth.google.redirect_uri",
os.environ.get("GOOGLE_REDIRECT_URI", ""),
)
MICROSOFT_CLIENT_ID = PersistentConfig( MICROSOFT_CLIENT_ID = PersistentConfig(
"MICROSOFT_CLIENT_ID", "MICROSOFT_CLIENT_ID",
"oauth.microsoft.client_id", "oauth.microsoft.client_id",
...@@ -373,6 +379,12 @@ MICROSOFT_OAUTH_SCOPE = PersistentConfig( ...@@ -373,6 +379,12 @@ MICROSOFT_OAUTH_SCOPE = PersistentConfig(
os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"), os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
) )
MICROSOFT_REDIRECT_URI = PersistentConfig(
"MICROSOFT_REDIRECT_URI",
"oauth.microsoft.redirect_uri",
os.environ.get("MICROSOFT_REDIRECT_URI", ""),
)
OAUTH_CLIENT_ID = PersistentConfig( OAUTH_CLIENT_ID = PersistentConfig(
"OAUTH_CLIENT_ID", "OAUTH_CLIENT_ID",
"oauth.oidc.client_id", "oauth.oidc.client_id",
...@@ -391,6 +403,12 @@ OPENID_PROVIDER_URL = PersistentConfig( ...@@ -391,6 +403,12 @@ OPENID_PROVIDER_URL = PersistentConfig(
os.environ.get("OPENID_PROVIDER_URL", ""), os.environ.get("OPENID_PROVIDER_URL", ""),
) )
OPENID_REDIRECT_URI = PersistentConfig(
"OPENID_REDIRECT_URI",
"oauth.oidc.redirect_uri",
os.environ.get("OPENID_REDIRECT_URI", ""),
)
OAUTH_SCOPES = PersistentConfig( OAUTH_SCOPES = PersistentConfig(
"OAUTH_SCOPES", "OAUTH_SCOPES",
"oauth.oidc.scopes", "oauth.oidc.scopes",
...@@ -424,6 +442,7 @@ def load_oauth_providers(): ...@@ -424,6 +442,7 @@ def load_oauth_providers():
"client_secret": GOOGLE_CLIENT_SECRET.value, "client_secret": GOOGLE_CLIENT_SECRET.value,
"server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration", "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
"scope": GOOGLE_OAUTH_SCOPE.value, "scope": GOOGLE_OAUTH_SCOPE.value,
"redirect_uri": GOOGLE_REDIRECT_URI.value,
} }
if ( if (
...@@ -436,6 +455,7 @@ def load_oauth_providers(): ...@@ -436,6 +455,7 @@ def load_oauth_providers():
"client_secret": MICROSOFT_CLIENT_SECRET.value, "client_secret": MICROSOFT_CLIENT_SECRET.value,
"server_metadata_url": f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration", "server_metadata_url": f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration",
"scope": MICROSOFT_OAUTH_SCOPE.value, "scope": MICROSOFT_OAUTH_SCOPE.value,
"redirect_uri": MICROSOFT_REDIRECT_URI.value,
} }
if ( if (
...@@ -449,6 +469,7 @@ def load_oauth_providers(): ...@@ -449,6 +469,7 @@ def load_oauth_providers():
"server_metadata_url": OPENID_PROVIDER_URL.value, "server_metadata_url": OPENID_PROVIDER_URL.value,
"scope": OAUTH_SCOPES.value, "scope": OAUTH_SCOPES.value,
"name": OAUTH_PROVIDER_NAME.value, "name": OAUTH_PROVIDER_NAME.value,
"redirect_uri": OPENID_REDIRECT_URI.value,
} }
......
...@@ -13,8 +13,6 @@ import aiohttp ...@@ -13,8 +13,6 @@ import aiohttp
import requests import requests
import mimetypes import mimetypes
import shutil import shutil
import os
import uuid
import inspect import inspect
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
...@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware ...@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call from apps.socket.main import app as socket_app, get_event_emitter, get_event_call
from apps.ollama.main import ( from apps.ollama.main import (
app as ollama_app, app as ollama_app,
get_all_models as get_ollama_models, get_all_models as get_ollama_models,
...@@ -619,32 +617,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -619,32 +617,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
content={"detail": str(e)}, content={"detail": str(e)},
) )
# Extract valves from the request body metadata = {
valves = None "chat_id": body.pop("chat_id", None),
if "valves" in body: "message_id": body.pop("id", None),
valves = body["valves"] "session_id": body.pop("session_id", None),
del body["valves"] "valves": body.pop("valves", None),
}
# Extract session_id, chat_id and message_id from the request body
session_id = None __event_emitter__ = get_event_emitter(metadata)
if "session_id" in body: __event_call__ = get_event_call(metadata)
session_id = body["session_id"]
del body["session_id"]
chat_id = None
if "chat_id" in body:
chat_id = body["chat_id"]
del body["chat_id"]
message_id = None
if "id" in body:
message_id = body["id"]
del body["id"]
__event_emitter__ = await get_event_emitter(
{"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
)
__event_call__ = await get_event_call(
{"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
)
# Initialize data_items to store additional data to be sent to the client # Initialize data_items to store additional data to be sent to the client
data_items = [] data_items = []
...@@ -709,13 +690,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -709,13 +690,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if len(citations) > 0: if len(citations) > 0:
data_items.append({"citations": citations}) data_items.append({"citations": citations})
body["metadata"] = { body["metadata"] = metadata
"session_id": session_id,
"chat_id": chat_id,
"message_id": message_id,
"valves": valves,
}
modified_body_bytes = json.dumps(body).encode("utf-8") modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one # Replace the request body with the modified one
request._body = modified_body_bytes request._body = modified_body_bytes
...@@ -1191,13 +1166,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -1191,13 +1166,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
status_code=r.status_code, status_code=r.status_code,
content=res, content=res,
) )
except: except Exception:
pass pass
else: else:
pass pass
__event_emitter__ = await get_event_emitter( __event_emitter__ = get_event_emitter(
{ {
"chat_id": data["chat_id"], "chat_id": data["chat_id"],
"message_id": data["id"], "message_id": data["id"],
...@@ -1205,7 +1180,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -1205,7 +1180,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
} }
) )
__event_call__ = await get_event_call( __event_call__ = get_event_call(
{ {
"chat_id": data["chat_id"], "chat_id": data["chat_id"],
"message_id": data["id"], "message_id": data["id"],
...@@ -1334,14 +1309,14 @@ async def chat_completed( ...@@ -1334,14 +1309,14 @@ async def chat_completed(
) )
model = app.state.MODELS[model_id] model = app.state.MODELS[model_id]
__event_emitter__ = await get_event_emitter( __event_emitter__ = get_event_emitter(
{ {
"chat_id": data["chat_id"], "chat_id": data["chat_id"],
"message_id": data["id"], "message_id": data["id"],
"session_id": data["session_id"], "session_id": data["session_id"],
} }
) )
__event_call__ = await get_event_call( __event_call__ = get_event_call(
{ {
"chat_id": data["chat_id"], "chat_id": data["chat_id"],
"message_id": data["id"], "message_id": data["id"],
...@@ -1770,7 +1745,6 @@ class AddPipelineForm(BaseModel): ...@@ -1770,7 +1745,6 @@ class AddPipelineForm(BaseModel):
@app.post("/api/pipelines/add") @app.post("/api/pipelines/add")
async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)): async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
r = None r = None
try: try:
urlIdx = form_data.urlIdx urlIdx = form_data.urlIdx
...@@ -1813,7 +1787,6 @@ class DeletePipelineForm(BaseModel): ...@@ -1813,7 +1787,6 @@ class DeletePipelineForm(BaseModel):
@app.delete("/api/pipelines/delete") @app.delete("/api/pipelines/delete")
async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)): async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
r = None r = None
try: try:
urlIdx = form_data.urlIdx urlIdx = form_data.urlIdx
...@@ -1891,7 +1864,6 @@ async def get_pipeline_valves( ...@@ -1891,7 +1864,6 @@ async def get_pipeline_valves(
models = await get_all_models() models = await get_all_models()
r = None r = None
try: try:
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
...@@ -2143,6 +2115,7 @@ for provider_name, provider_config in OAUTH_PROVIDERS.items(): ...@@ -2143,6 +2115,7 @@ for provider_name, provider_config in OAUTH_PROVIDERS.items():
client_kwargs={ client_kwargs={
"scope": provider_config["scope"], "scope": provider_config["scope"],
}, },
redirect_uri=provider_config["redirect_uri"],
) )
# SessionMiddleware is used by authlib for oauth # SessionMiddleware is used by authlib for oauth
...@@ -2160,7 +2133,10 @@ if len(OAUTH_PROVIDERS) > 0: ...@@ -2160,7 +2133,10 @@ if len(OAUTH_PROVIDERS) > 0:
async def oauth_login(provider: str, request: Request): async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS: if provider not in OAUTH_PROVIDERS:
raise HTTPException(404) raise HTTPException(404)
redirect_uri = request.url_for("oauth_callback", provider=provider) # If the provider has a custom redirect URL, use that, otherwise automatically generate one
redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for(
"oauth_callback", provider=provider
)
return await oauth.create_client(provider).authorize_redirect(request, redirect_uri) return await oauth.create_client(provider).authorize_redirect(request, redirect_uri)
......
from pathlib import Path from pathlib import Path
import hashlib import hashlib
import json
import re import re
from datetime import timedelta from datetime import timedelta
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
...@@ -8,37 +7,39 @@ import uuid ...@@ -8,37 +7,39 @@ import uuid
import time import time
def get_last_user_message_item(messages: List[dict]) -> str: def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
for message in reversed(messages): for message in reversed(messages):
if message["role"] == "user": if message["role"] == "user":
return message return message
return None return None
def get_last_user_message(messages: List[dict]) -> str: def get_content_from_message(message: dict) -> Optional[str]:
message = get_last_user_message_item(messages) if isinstance(message["content"], list):
for item in message["content"]:
if message is not None: if item["type"] == "text":
if isinstance(message["content"], list): return item["text"]
for item in message["content"]: else:
if item["type"] == "text":
return item["text"]
return message["content"] return message["content"]
return None return None
def get_last_assistant_message(messages: List[dict]) -> str: def get_last_user_message(messages: List[dict]) -> Optional[str]:
message = get_last_user_message_item(messages)
if message is None:
return None
return get_content_from_message(message)
def get_last_assistant_message(messages: List[dict]) -> Optional[str]:
for message in reversed(messages): for message in reversed(messages):
if message["role"] == "assistant": if message["role"] == "assistant":
if isinstance(message["content"], list): return get_content_from_message(message)
for item in message["content"]:
if item["type"] == "text":
return item["text"]
return message["content"]
return None return None
def get_system_message(messages: List[dict]) -> dict: def get_system_message(messages: List[dict]) -> Optional[dict]:
for message in messages: for message in messages:
if message["role"] == "system": if message["role"] == "system":
return message return message
...@@ -49,7 +50,7 @@ def remove_system_message(messages: List[dict]) -> List[dict]: ...@@ -49,7 +50,7 @@ def remove_system_message(messages: List[dict]) -> List[dict]:
return [message for message in messages if message["role"] != "system"] return [message for message in messages if message["role"] != "system"]
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]: def pop_system_message(messages: List[dict]) -> Tuple[Optional[dict], List[dict]]:
return get_system_message(messages), remove_system_message(messages) return get_system_message(messages), remove_system_message(messages)
...@@ -87,23 +88,29 @@ def add_or_update_system_message(content: str, messages: List[dict]): ...@@ -87,23 +88,29 @@ def add_or_update_system_message(content: str, messages: List[dict]):
return messages return messages
def stream_message_template(model: str, message: str): def openai_chat_message_template(model: str):
return { return {
"id": f"{model}-{str(uuid.uuid4())}", "id": f"{model}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()), "created": int(time.time()),
"model": model, "model": model,
"choices": [ "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
{
"index": 0,
"delta": {"content": message},
"logprobs": None,
"finish_reason": None,
}
],
} }
def openai_chat_chunk_message_template(model: str, message: str):
template = openai_chat_message_template(model)
template["object"] = "chat.completion.chunk"
template["choices"][0]["delta"] = {"content": message}
return template
def openai_chat_completion_message_template(model: str, message: str):
template = openai_chat_message_template(model)
template["object"] = "chat.completion"
template["choices"][0]["message"] = {"content": message, "role": "assistant"}
template["choices"][0]["finish_reason"] = "stop"
def get_gravatar_url(email): def get_gravatar_url(email):
# Trim leading and trailing whitespace from # Trim leading and trailing whitespace from
# an email address and force all characters # an email address and force all characters
...@@ -174,7 +181,7 @@ def extract_folders_after_data_docs(path): ...@@ -174,7 +181,7 @@ def extract_folders_after_data_docs(path):
tags = [] tags = []
folders = parts[index_docs:-1] folders = parts[index_docs:-1]
for idx, part in enumerate(folders): for idx, _ in enumerate(folders):
tags.append("/".join(folders[: idx + 1])) tags.append("/".join(folders[: idx + 1]))
return tags return tags
...@@ -270,11 +277,11 @@ def parse_ollama_modelfile(model_text): ...@@ -270,11 +277,11 @@ def parse_ollama_modelfile(model_text):
value = param_match.group(1) value = param_match.group(1)
try: try:
if param_type == int: if param_type is int:
value = int(value) value = int(value)
elif param_type == float: elif param_type is float:
value = float(value) value = float(value)
elif param_type == bool: elif param_type is bool:
value = value.lower() == "true" value = value.lower() == "true"
except Exception as e: except Exception as e:
print(e) print(e)
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
"dayjs": "^1.11.10", "dayjs": "^1.11.10",
"eventsource-parser": "^1.1.2", "eventsource-parser": "^1.1.2",
"file-saver": "^2.0.5", "file-saver": "^2.0.5",
"fuse.js": "^7.0.0",
"highlight.js": "^11.9.0", "highlight.js": "^11.9.0",
"i18next": "^23.10.0", "i18next": "^23.10.0",
"i18next-browser-languagedetector": "^7.2.0", "i18next-browser-languagedetector": "^7.2.0",
...@@ -4820,6 +4821,14 @@ ...@@ -4820,6 +4821,14 @@
"url": "https://github.com/sponsors/ljharb" "url": "https://github.com/sponsors/ljharb"
} }
}, },
"node_modules/fuse.js": {
"version": "7.0.0",
"resolved": "https://registry.npmjs.org/fuse.js/-/fuse.js-7.0.0.tgz",
"integrity": "sha512-14F4hBIxqKvD4Zz/XjDc3y94mNZN6pRv3U13Udo0lNLCWRBUsrMv2xwcF/y/Z5sV6+FQW+/ow68cHpm4sunt8Q==",
"engines": {
"node": ">=10"
}
},
"node_modules/gc-hook": { "node_modules/gc-hook": {
"version": "0.3.1", "version": "0.3.1",
"resolved": "https://registry.npmjs.org/gc-hook/-/gc-hook-0.3.1.tgz", "resolved": "https://registry.npmjs.org/gc-hook/-/gc-hook-0.3.1.tgz",
......
...@@ -154,3 +154,7 @@ input[type='number'] { ...@@ -154,3 +154,7 @@ input[type='number'] {
.tippy-box[data-theme~='dark'] { .tippy-box[data-theme~='dark'] {
@apply rounded-lg bg-gray-950 text-xs border border-gray-900 shadow-xl; @apply rounded-lg bg-gray-950 text-xs border border-gray-900 shadow-xl;
} }
.password {
-webkit-text-security: disc;
}
...@@ -98,6 +98,7 @@ ...@@ -98,6 +98,7 @@
const uploadFileHandler = async (file) => { const uploadFileHandler = async (file) => {
console.log(file); console.log(file);
// Check if the file is an audio file and transcribe/convert it to text file // Check if the file is an audio file and transcribe/convert it to text file
if (['audio/mpeg', 'audio/wav'].includes(file['type'])) { if (['audio/mpeg', 'audio/wav'].includes(file['type'])) {
const res = await transcribeAudio(localStorage.token, file).catch((error) => { const res = await transcribeAudio(localStorage.token, file).catch((error) => {
...@@ -112,40 +113,49 @@ ...@@ -112,40 +113,49 @@
} }
} }
// Upload the file to the server const fileItem = {
const uploadedFile = await uploadFile(localStorage.token, file).catch((error) => { type: 'file',
toast.error(error); file: '',
return null; id: null,
}); url: '',
name: file.name,
if (uploadedFile) { collection_name: '',
const fileItem = { status: '',
type: 'file', size: file.size,
file: uploadedFile, error: ''
id: uploadedFile.id, };
url: `${WEBUI_API_BASE_URL}/files/${uploadedFile.id}`, files = [...files, fileItem];
name: file.name,
collection_name: '', try {
status: 'uploaded', const uploadedFile = await uploadFile(localStorage.token, file);
error: ''
}; if (uploadedFile) {
files = [...files, fileItem]; fileItem.status = 'uploaded';
fileItem.file = uploadedFile;
// TODO: Check if tools & functions have files support to skip this step to delegate file processing fileItem.id = uploadedFile.id;
// Default Upload to VectorDB fileItem.url = `${WEBUI_API_BASE_URL}/files/${uploadedFile.id}`;
if (
SUPPORTED_FILE_TYPE.includes(file['type']) || // TODO: Check if tools & functions have files support to skip this step to delegate file processing
SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1)) // Default Upload to VectorDB
) { if (
processFileItem(fileItem); SUPPORTED_FILE_TYPE.includes(file['type']) ||
SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1))
) {
processFileItem(fileItem);
} else {
toast.error(
$i18n.t(`Unknown file type '{{file_type}}'. Proceeding with the file upload anyway.`, {
file_type: file['type']
})
);
processFileItem(fileItem);
}
} else { } else {
toast.error( files = files.filter((item) => item.status !== null);
$i18n.t(`Unknown file type '{{file_type}}'. Proceeding with the file upload anyway.`, {
file_type: file['type']
})
);
processFileItem(fileItem);
} }
} catch (e) {
toast.error(e);
files = files.filter((item) => item.status !== null);
} }
}; };
...@@ -162,7 +172,6 @@ ...@@ -162,7 +172,6 @@
// Remove the failed doc from the files array // Remove the failed doc from the files array
// files = files.filter((f) => f.id !== fileItem.id); // files = files.filter((f) => f.id !== fileItem.id);
toast.error(e); toast.error(e);
fileItem.status = 'processed'; fileItem.status = 'processed';
files = files; files = files;
} }
......
<script lang="ts"> <script lang="ts">
import { DropdownMenu } from 'bits-ui'; import { DropdownMenu } from 'bits-ui';
import { marked } from 'marked'; import { marked } from 'marked';
import Fuse from 'fuse.js';
import { flyAndScale } from '$lib/utils/transitions'; import { flyAndScale } from '$lib/utils/transitions';
import { createEventDispatcher, onMount, getContext, tick } from 'svelte'; import { createEventDispatcher, onMount, getContext, tick } from 'svelte';
...@@ -45,17 +46,29 @@ ...@@ -45,17 +46,29 @@
let selectedModelIdx = 0; let selectedModelIdx = 0;
$: filteredItems = items.filter( const fuse = new Fuse(
(item) => items
(searchValue .filter((item) => !item.model?.info?.meta?.hidden)
? item.value.toLowerCase().includes(searchValue.toLowerCase()) || .map((item) => {
item.label.toLowerCase().includes(searchValue.toLowerCase()) || const _item = {
(item.model?.info?.meta?.tags ?? []).some((tag) => ...item,
tag.name.toLowerCase().includes(searchValue.toLowerCase()) modelName: item.model?.name,
) tags: item.model?.info?.meta?.tags?.map((tag) => tag.name).join(' '),
: true) && !(item.model?.info?.meta?.hidden ?? false) desc: item.model?.info?.meta?.description
};
return _item;
}),
{
keys: ['value', 'label', 'tags', 'desc', 'modelName']
}
); );
$: filteredItems = searchValue
? fuse.search(searchValue).map((e) => {
return e.item;
})
: items.filter((item) => !item.model?.info?.meta?.hidden);
const pullModelHandler = async () => { const pullModelHandler = async () => {
const sanitizedModelTag = searchValue.trim().replace(/^ollama\s+(run|pull)\s+/, ''); const sanitizedModelTag = searchValue.trim().replace(/^ollama\s+(run|pull)\s+/, '');
......
...@@ -13,13 +13,13 @@ ...@@ -13,13 +13,13 @@
<div class={outerClassName}> <div class={outerClassName}>
<input <input
class={inputClassName} class={`${inputClassName} ${show ? '' : 'password'}`}
{placeholder} {placeholder}
bind:value bind:value
required={required && !readOnly} required={required && !readOnly}
disabled={readOnly} disabled={readOnly}
autocomplete="off" autocomplete="off"
{...{ type: show ? 'text' : 'password' }} type="text"
/> />
<button <button
class={showButtonClassName} class={showButtonClassName}
......
...@@ -111,6 +111,10 @@ ...@@ -111,6 +111,10 @@
"code": "pt-PT", "code": "pt-PT",
"title": "Portuguese (Portugal)" "title": "Portuguese (Portugal)"
}, },
{
"code": "ro-RO",
"title": "Romanian (Romania)"
},
{ {
"code": "ru-RU", "code": "ru-RU",
"title": "Russian (Russia)" "title": "Russian (Russia)"
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment