Commit f5487628 authored by Morgan Blangeois

Resolve merge conflicts in French translations

parents 2fedd91e 2c061777
......@@ -10,7 +10,8 @@ node_modules
vite.config.js.timestamp-*
vite.config.ts.timestamp-*
__pycache__
.env
.idea
venv
_old
uploads
.ipynb_checkpoints
......
......@@ -306,3 +306,4 @@ dist
# cypress artifacts
cypress/videos
cypress/screenshots
.vscode/settings.json
......@@ -14,7 +14,6 @@ from fastapi import (
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel
import uuid
......@@ -277,6 +276,8 @@ def transcribe(
f.close()
if app.state.config.STT_ENGINE == "":
from faster_whisper import WhisperModel
whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type,
......
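The two audio hunks above defer the faster_whisper import into the transcribe path, so the server no longer pays the import cost when an external STT engine is configured. A minimal sketch of the same lazy-init pattern (model name and device here are placeholders, not values from this commit):

    def get_whisper_model(model_size_or_path: str = "base", device: str = "cpu"):
        # Imported only when local STT is actually exercised, mirroring the change above
        from faster_whisper import WhisperModel
        return WhisperModel(model_size_or_path, device=device)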
......@@ -12,7 +12,6 @@ from fastapi import (
Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (
......
......@@ -25,6 +25,7 @@ from utils.task import prompt_template
from config import (
SRC_LOG_LEVELS,
ENABLE_OPENAI_API,
AIOHTTP_CLIENT_TIMEOUT,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
CACHE_DIR,
......@@ -463,7 +464,9 @@ async def generate_chat_completion(
streaming = False
try:
session = aiohttp.ClientSession(trust_env=True)
session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
r = await session.request(
method="POST",
url=f"{url}/chat/completions",
......
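The hunk above threads AIOHTTP_CLIENT_TIMEOUT into the upstream session so a stalled OpenAI-compatible backend cannot hold the request open indefinitely. A minimal sketch of the pattern, assuming the configured timeout is a number of seconds:

    import aiohttp

    async def post_with_timeout(url: str, payload: dict, timeout_seconds: float = 300):
        # total= bounds the entire request (connect + transfer), as in the change above
        timeout = aiohttp.ClientTimeout(total=timeout_seconds)
        async with aiohttp.ClientSession(trust_env=True, timeout=timeout) as session:
            async with session.post(url, json=payload) as r:
                return await r.json()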
......@@ -48,8 +48,6 @@ import mimetypes
import uuid
import json
import sentence_transformers
from apps.webui.models.documents import (
Documents,
DocumentForm,
......@@ -93,6 +91,8 @@ from config import (
SRC_LOG_LEVELS,
UPLOAD_DIR,
DOCS_DIR,
CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL,
RAG_TOP_K,
RAG_RELEVANCE_THRESHOLD,
RAG_EMBEDDING_ENGINE,
......@@ -148,6 +148,9 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
......@@ -190,6 +193,8 @@ def update_embedding_model(
update_model: bool = False,
):
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
import sentence_transformers
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model),
device=DEVICE_TYPE,
......@@ -204,6 +209,8 @@ def update_reranking_model(
update_model: bool = False,
):
if reranking_model:
import sentence_transformers
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
get_model_path(reranking_model, update_model),
device=DEVICE_TYPE,
......@@ -388,6 +395,10 @@ async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"content_extraction": {
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": app.state.config.TIKA_SERVER_URL,
},
"chunk": {
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
......@@ -417,6 +428,11 @@ async def get_rag_config(user=Depends(get_admin_user)):
}
class ContentExtractionConfig(BaseModel):
engine: str = ""
tika_server_url: Optional[str] = None
class ChunkParamUpdateForm(BaseModel):
chunk_size: int
chunk_overlap: int
......@@ -450,6 +466,7 @@ class WebConfig(BaseModel):
class ConfigUpdateForm(BaseModel):
pdf_extract_images: Optional[bool] = None
content_extraction: Optional[ContentExtractionConfig] = None
chunk: Optional[ChunkParamUpdateForm] = None
youtube: Optional[YoutubeLoaderConfig] = None
web: Optional[WebConfig] = None
......@@ -463,6 +480,11 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
else app.state.config.PDF_EXTRACT_IMAGES
)
if form_data.content_extraction is not None:
log.info(f"Updating text settings: {form_data.content_extraction}")
app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine
app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
if form_data.chunk is not None:
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
......@@ -499,6 +521,10 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
return {
"status": True,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"content_extraction": {
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": app.state.config.TIKA_SERVER_URL,
},
"chunk": {
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
......@@ -985,6 +1011,41 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
return False
class TikaLoader:
def __init__(self, file_path, mime_type=None):
self.file_path = file_path
self.mime_type = mime_type
def load(self) -> List[Document]:
with open(self.file_path, "rb") as f:
data = f.read()
if self.mime_type is not None:
headers = {"Content-Type": self.mime_type}
else:
headers = {}
endpoint = app.state.config.TIKA_SERVER_URL
if not endpoint.endswith("/"):
endpoint += "/"
endpoint += "tika/text"
r = requests.put(endpoint, data=data, headers=headers)
if r.ok:
raw_metadata = r.json()
text = raw_metadata.get("X-TIKA:content", "<No text content found>")
if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"]
log.info("Tika extracted text: %s", text)
return [Document(page_content=text, metadata=headers)]
else:
raise Exception(f"Error calling Tika: {r.reason}")
def get_loader(filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
known_type = True
......@@ -1035,6 +1096,17 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"msg",
]
if (
app.state.config.CONTENT_EXTRACTION_ENGINE == "tika"
and app.state.config.TIKA_SERVER_URL
):
if file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = TikaLoader(file_path, file_content_type)
else:
if file_ext == "pdf":
loader = PyPDFLoader(
file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
......
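For context, Tika Server's PUT /tika/text endpoint returns a JSON document whose extracted body lives under the "X-TIKA:content" key, which is exactly what TikaLoader reads above; get_loader then routes plain-text types to TextLoader and everything else to Tika. A quick standalone smoke test of a sidecar before enabling the engine might look like this (server URL and file path are placeholders):

    import requests

    def tika_smoke_test(server_url: str = "http://tika:9998", path: str = "sample.pdf") -> str:
        with open(path, "rb") as f:
            r = requests.put(f"{server_url.rstrip('/')}/tika/text", data=f.read())
        r.raise_for_status()
        # Same key the loader above uses for the extracted text
        return r.json().get("X-TIKA:content", "")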
......@@ -294,15 +294,17 @@ def get_rag_context(
extracted_collections.extend(collection_names)
context_string = ""
contexts = []
citations = []
for context in relevant_contexts:
try:
if "documents" in context:
context_string += "\n\n".join(
contexts.append(
"\n\n".join(
[text for text in context["documents"][0] if text is not None]
)
)
if "metadatas" in context:
citations.append(
......@@ -315,9 +317,7 @@ def get_rag_context(
except Exception as e:
log.exception(e)
context_string = context_string.strip()
return context_string, citations
return contexts, citations
def get_model_path(model: str, update_model: bool = False):
......@@ -442,8 +442,6 @@ from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import Extra
from sentence_transformers import util
class RerankCompressor(BaseDocumentCompressor):
embedding_function: Any
......@@ -468,6 +466,8 @@ class RerankCompressor(BaseDocumentCompressor):
[(query, doc.page_content) for doc in documents]
)
else:
from sentence_transformers import util
query_embedding = self.embedding_function(query)
document_embedding = self.embedding_function(
[doc.page_content for doc in documents]
......
......@@ -259,6 +259,9 @@ async def generate_function_chat_completion(form_data, user):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except:
......
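The added branch above lets function pipes yield plain dicts alongside strings and BaseModel instances; each dict is JSON-encoded and framed as a server-sent-events "data:" line. A hedged sketch of what that normalization amounts to (the helper name is illustrative, not from this commit):

    import json
    from pydantic import BaseModel

    def to_sse_line(line) -> str:
        # Mirrors the branches above: BaseModel -> JSON, dict -> JSON, str passes through
        if isinstance(line, BaseModel):
            line = line.model_dump_json()
        elif isinstance(line, dict):
            line = json.dumps(line)
        return f"data: {line}"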
......@@ -214,8 +214,7 @@ class FunctionsTable:
user_settings["functions"]["valves"][id] = valves
# Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
Users.update_user_by_id(user_id, {"settings": user_settings})
return user_settings["functions"]["valves"][id]
except Exception as e:
......
......@@ -170,8 +170,7 @@ class ToolsTable:
user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
Users.update_user_by_id(user_id, {"settings": user_settings})
return user_settings["tools"]["valves"][id]
except Exception as e:
......
......@@ -5,9 +5,8 @@ import importlib.metadata
import pkgutil
import chromadb
from chromadb import Settings
from base64 import b64encode
from bs4 import BeautifulSoup
from typing import TypeVar, Generic, Union
from typing import TypeVar, Generic
from pydantic import BaseModel
from typing import Optional
......@@ -19,7 +18,6 @@ import markdown
import requests
import shutil
from secrets import token_bytes
from constants import ERROR_MESSAGES
####################################
......@@ -768,12 +766,14 @@ class BannerModel(BaseModel):
dismissible: bool
timestamp: int
try:
banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
banners = [BannerModel(**banner) for banner in banners]
except Exception as e:
print(f"Error loading WEBUI_BANNERS: {e}")
banners = []
WEBUI_BANNERS = PersistentConfig(
"WEBUI_BANNERS",
"ui.banners",
[BannerModel(**banner) for banner in json.loads("[]")],
)
WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
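The rewrite above tolerates a malformed WEBUI_BANNERS value and falls back to an empty list instead of crashing at import time. For illustration only — field names beyond dismissible and timestamp are truncated out of this hunk and are assumptions — a well-formed value might look like:

    import os

    # Hypothetical banner payload; only "dismissible" and "timestamp" are visible above
    os.environ["WEBUI_BANNERS"] = (
        '[{"id": "1", "type": "info", "title": "Maintenance", '
        '"content": "Upgrade at 02:00 UTC", "dismissible": true, "timestamp": 1719000000}]'
    )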
SHOW_ADMIN_DETAILS = PersistentConfig(
......@@ -885,6 +885,22 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
####################################
# RAG document content extraction
####################################
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
"rag.CONTENT_EXTRACTION_ENGINE",
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)
TIKA_SERVER_URL = PersistentConfig(
"TIKA_SERVER_URL",
"rag.tika_server_url",
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
)
####################################
# RAG
####################################
......
......@@ -33,7 +33,7 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import app as socket_app
from apps.socket.main import sio, app as socket_app
from apps.ollama.main import (
app as ollama_app,
OpenAIChatCompletionForm,
......@@ -212,8 +212,79 @@ origins = ["*"]
##################################
async def get_body_and_model_and_user(request):
# Read the original request body
body = await request.body()
body_str = body.decode("utf-8")
body = json.loads(body_str) if body_str else {}
model_id = body["model"]
if model_id not in app.state.MODELS:
raise "Model not found"
model = app.state.MODELS[model_id]
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
return body, model, user
def get_task_model_id(default_model_id):
# Set the task model
task_model_id = default_model_id
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
return task_model_id
def get_filter_function_ids(model):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
filter_ids.sort(key=get_priority)
return filter_ids
async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user
messages,
files,
tool_id,
template,
task_model_id,
user,
model,
__event_emitter__=None,
):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
......@@ -350,6 +421,13 @@ async def get_function_call_response(
"__id__": tool_id,
}
if "__event_emitter__" in sig.parameters:
# Call the function with the '__event_emitter__' parameter included
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
......@@ -373,68 +451,10 @@ async def get_function_call_response(
return None, None, False
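With the new __event_emitter__ parameter, a tool whose signature declares it receives a callback wired to the "chat-events" socket channel and can surface progress while it runs. A hypothetical tool method using it (the method name, payload, and placeholder work are illustrative, not from this commit):

    class Tools:
        async def lookup(self, query: str, __event_emitter__=None) -> str:
            if __event_emitter__:
                # Payload shape matches what the frontend's chatEventHandler reads below
                await __event_emitter__({"status": f"Searching for {query}...", "done": False})
            result = f"results for {query}"  # placeholder work
            if __event_emitter__:
                await __event_emitter__({"status": "Done", "done": True})
            return result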
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
data_items = []
show_citations = False
citations = []
if request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
):
log.debug(f"request.url.path: {request.url.path}")
# Read the original request body
body = await request.body()
body_str = body.decode("utf-8")
data = json.loads(body_str) if body_str else {}
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False
if data.get("citations"):
show_citations = True
del data["citations"]
model_id = data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get(
"priority", 0
)
return 0
filter_ids = [
function.id for function in Functions.get_global_filter_functions()
]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type(
"filter", active_only=True
)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
async def chat_completion_functions_handler(body, model, user, __event_emitter__):
skip_files = None
filter_ids.sort(key=get_priority)
filter_ids = get_filter_function_ids(model)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
......@@ -464,7 +484,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# Get the signature of the function
sig = inspect.signature(inlet)
params = {"body": data}
params = {"body": body}
if "__user__" in sig.parameters:
__user__ = {
......@@ -492,108 +512,206 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
"__id__": filter_id,
}
if "__model__" in sig.parameters:
params = {
**params,
"__model__": model,
}
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if inspect.iscoroutinefunction(inlet):
data = await inlet(**params)
body = await inlet(**params)
else:
data = inlet(**params)
body = inlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
raise e
if skip_files:
if "files" in body:
del body["files"]
return body, {}
# Set the task model
task_model_id = data["model"]
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
prompt = get_last_user_message(data["messages"])
context = ""
async def chat_completion_tools_handler(body, model, user, __event_emitter__):
skip_files = None
contexts = []
citations = None
task_model_id = get_task_model_id(body["model"])
# If tool_ids field is present, call the functions
if "tool_ids" in data:
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
if "tool_ids" in body:
print(body["tool_ids"])
for tool_id in body["tool_ids"]:
print(tool_id)
try:
response, citation, file_handler = (
await get_function_call_response(
messages=data["messages"],
files=data.get("files", []),
response, citation, file_handler = await get_function_call_response(
messages=body["messages"],
files=body.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
model=model,
__event_emitter__=__event_emitter__,
)
print(file_handler)
if isinstance(response, str):
context += ("\n" if context != "" else "") + response
contexts.append(response)
if citation:
if citations is None:
citations = [citation]
else:
citations.append(citation)
show_citations = True
if file_handler:
skip_files = True
except Exception as e:
print(f"Error: {e}")
del data["tool_ids"]
print(f"tool_context: {context}")
# If files field is present, generate RAG completions
# If skip_files is True, skip the RAG completions
if "files" in data:
if not skip_files:
data = {**data}
rag_context, rag_citations = get_rag_context(
files=data["files"],
messages=data["messages"],
del body["tool_ids"]
print(f"tool_contexts: {contexts}")
if skip_files:
if "files" in body:
del body["files"]
return body, {
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
async def chat_completion_files_handler(body):
contexts = []
citations = None
if "files" in body:
files = body["files"]
del body["files"]
contexts, citations = get_rag_context(
files=files,
messages=body["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
log.debug(f"rag_context: {rag_context}, citations: {citations}")
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
if rag_citations:
citations.extend(rag_citations)
return body, {
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
del data["files"]
if show_citations and len(citations) > 0:
data_items.append({"citations": citations})
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
):
log.debug(f"request.url.path: {request.url.path}")
try:
body, model, user = await get_body_and_model_and_user(request)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Extract session_id, chat_id and message_id from the request body
session_id = None
if "session_id" in body:
session_id = body["session_id"]
del body["session_id"]
chat_id = None
if "chat_id" in body:
chat_id = body["chat_id"]
del body["chat_id"]
message_id = None
if "id" in body:
message_id = body["id"]
del body["id"]
async def __event_emitter__(data):
await sio.emit(
"chat-events",
{
"chat_id": chat_id,
"message_id": message_id,
"data": data,
},
to=session_id,
)
# Initialize data_items to store additional data to be sent to the client
data_items = []
# Initialize context, and citations
contexts = []
citations = []
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
try:
body, flags = await chat_completion_functions_handler(
body, model, user, __event_emitter__
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
system_prompt, data["messages"]
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
modified_body_bytes = json.dumps(data).encode("utf-8")
try:
body, flags = await chat_completion_tools_handler(
body, model, user, __event_emitter__
)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
try:
body, flags = await chat_completion_files_handler(body)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
# If context is not empty, insert it into the messages
if len(contexts) > 0:
context_string = "/n".join(contexts).strip()
prompt = get_last_user_message(body["messages"])
body["messages"] = add_or_update_system_message(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
# If there are citations, add them to the data_items
if len(citations) > 0:
data_items.append({"citations": citations})
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
......@@ -715,9 +833,6 @@ def filter_pipeline(payload, user):
pass
if "pipeline" not in app.state.MODELS[model_id]:
if "chat_id" in payload:
del payload["chat_id"]
if "title" in payload:
del payload["title"]
......@@ -1008,6 +1123,17 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
else:
pass
async def __event_emitter__(data):
await sio.emit(
"chat-events",
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"data": data,
},
to=data["session_id"],
)
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
......@@ -1083,6 +1209,18 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
"__id__": filter_id,
}
if "__model__" in sig.parameters:
params = {
**params,
"__model__": model,
}
if "__event_emitter__" in sig.parameters:
params = {
**params,
"__event_emitter__": __event_emitter__,
}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
......@@ -1213,6 +1351,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
content={"detail": e.args[1]},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
......@@ -1273,6 +1414,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
content={"detail": e.args[1]},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
......@@ -1337,6 +1481,9 @@ Message: """{{prompt}}"""
content={"detail": e.args[1]},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
......
......@@ -10,7 +10,7 @@ python-socketio==5.11.3
python-jose==3.3.0
passlib[bcrypt]==1.7.4
requests==2.32.2
requests==2.32.3
aiohttp==3.9.5
peewee==3.17.5
peewee-migrate==1.12.2
......@@ -30,21 +30,21 @@ openai
anthropic
google-generativeai==0.5.4
langchain==0.2.0
langchain-community==0.2.0
langchain==0.2.6
langchain-community==0.2.6
langchain-chroma==0.1.2
fake-useragent==1.5.1
chromadb==0.5.3
sentence-transformers==2.7.0
sentence-transformers==3.0.1
pypdf==4.2.0
docx2txt==0.8
python-pptx==0.6.23
unstructured==0.14.0
unstructured==0.14.9
Markdown==3.6
pypandoc==1.13
pandas==2.2.2
openpyxl==3.1.2
openpyxl==3.1.5
pyxlsb==1.0.10
xlrd==2.0.1
validators==0.28.1
......@@ -61,7 +61,7 @@ PyJWT[crypto]==2.8.0
authlib==1.3.1
black==24.4.2
langfuse==2.33.0
langfuse==2.36.2
youtube-transcript-api==0.6.2
pytube==15.0.0
......
......@@ -8,9 +8,17 @@ import uuid
import time
def get_last_user_message(messages: List[dict]) -> str:
def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
for message in reversed(messages):
if message["role"] == "user":
return message
return None
def get_last_user_message(messages: List[dict]) -> str:
message = get_last_user_message_item(messages)
if message is not None:
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
......
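The refactor above splits message lookup from text extraction because message["content"] may be a plain string or, for multimodal requests, a list of typed parts. A sketch of the list form the new branch handles (OpenAI-style; the loop above keeps only the "text" part):

    message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ],
    }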
{
"name": "open-webui",
"version": "0.3.7",
"version": "0.3.8",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "open-webui",
"version": "0.3.7",
"version": "0.3.8",
"dependencies": {
"@codemirror/lang-javascript": "^6.2.2",
"@codemirror/lang-python": "^6.1.6",
......
{
"name": "open-webui",
"version": "0.3.7",
"version": "0.3.8",
"private": true,
"scripts": {
"dev": "npm run pyodide:fetch && vite dev --host",
......
......@@ -32,6 +32,11 @@ type ChunkConfigForm = {
chunk_overlap: number;
};
type ContentExtractConfigForm = {
engine: string;
tika_server_url: string | null;
};
type YoutubeConfigForm = {
language: string[];
translation?: string | null;
......@@ -40,6 +45,7 @@ type YoutubeConfigForm = {
type RAGConfigForm = {
pdf_extract_images?: boolean;
chunk?: ChunkConfigForm;
content_extraction?: ContentExtractConfigForm;
web_loader_ssl_verification?: boolean;
youtube?: YoutubeConfigForm;
};
......
......@@ -37,6 +37,10 @@
let embeddingModel = '';
let rerankingModel = '';
let contentExtractionEngine = 'default';
let tikaServerUrl = '';
let showTikaServerUrl = false;
let chunkSize = 0;
let chunkOverlap = 0;
let pdfExtractImages = true;
......@@ -163,11 +167,20 @@
rerankingModelUpdateHandler();
}
if (contentExtractionEngine === 'tika' && tikaServerUrl === '') {
toast.error($i18n.t('Tika Server URL required.'));
return;
}
const res = await updateRAGConfig(localStorage.token, {
pdf_extract_images: pdfExtractImages,
chunk: {
chunk_overlap: chunkOverlap,
chunk_size: chunkSize
},
content_extraction: {
engine: contentExtractionEngine,
tika_server_url: tikaServerUrl
}
});
......@@ -213,6 +226,10 @@
chunkSize = res.chunk.chunk_size;
chunkOverlap = res.chunk.chunk_overlap;
contentExtractionEngine = res.content_extraction.engine;
tikaServerUrl = res.content_extraction.tika_server_url;
showTikaServerUrl = contentExtractionEngine === 'tika';
}
});
</script>
......@@ -388,7 +405,7 @@
</div>
</div>
<hr class=" dark:border-gray-850 my-1" />
<hr class="dark:border-gray-850" />
<div class="space-y-2" />
<div>
......@@ -562,6 +579,39 @@
<hr class=" dark:border-gray-850" />
<div class="">
<div class="text-sm font-medium">{$i18n.t('Content Extraction')}</div>
<div class="flex w-full justify-between mt-2">
<div class="self-center text-xs font-medium">{$i18n.t('Engine')}</div>
<div class="flex items-center relative">
<select
class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
bind:value={contentExtractionEngine}
on:change={(e) => {
showTikaServerUrl = e.target.value === 'tika';
}}
>
<option value="">{$i18n.t('Default')} </option>
<option value="tika">{$i18n.t('Tika')}</option>
</select>
</div>
</div>
{#if showTikaServerUrl}
<div class="flex w-full mt-2">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder={$i18n.t('Enter Tika Server URL')}
bind:value={tikaServerUrl}
/>
</div>
</div>
{/if}
</div>
<hr class=" dark:border-gray-850" />
<div class=" ">
<div class=" text-sm font-medium">{$i18n.t('Query Params')}</div>
......
......@@ -19,6 +19,7 @@
} from '$lib/apis';
import Spinner from '$lib/components/common/Spinner.svelte';
import Switch from '$lib/components/common/Switch.svelte';
const i18n: Writable<i18nType> = getContext('i18n');
......@@ -476,15 +477,40 @@
</div>
{#if (valves[property] ?? null) !== null}
<div class="flex mt-0.5 space-x-2">
<!-- {valves[property]} -->
<div class="flex mt-0.5 mb-1.5 space-x-2">
<div class=" flex-1">
{#if valves_spec.properties[property]?.enum ?? null}
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={valves[property]}
>
{#each valves_spec.properties[property].enum as option}
<option value={option} selected={option === valves[property]}>
{option}
</option>
{/each}
</select>
{:else if (valves_spec.properties[property]?.type ?? null) === 'boolean'}
<div class="flex justify-between items-center">
<div class="text-xs text-gray-500">
{valves[property] ? 'Enabled' : 'Disabled'}
</div>
<div class=" pr-2">
<Switch bind:state={valves[property]} />
</div>
</div>
{:else}
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
type="text"
placeholder={valves_spec.properties[property].title}
bind:value={valves[property]}
autocomplete="off"
required
/>
{/if}
</div>
</div>
{/if}
......
......@@ -126,6 +126,27 @@
})();
}
const chatEventHandler = async (data) => {
if (data.chat_id === $chatId) {
await tick();
console.log(data);
let message = history.messages[data.message_id];
const status = {
done: data?.data?.done ?? null,
description: data?.data?.status ?? null
};
if (message.statusHistory) {
message.statusHistory.push(status);
} else {
message.statusHistory = [status];
}
messages = messages;
}
};
onMount(async () => {
const onMessageHandler = async (event) => {
if (event.origin === window.origin) {
......@@ -163,6 +184,8 @@
};
window.addEventListener('message', onMessageHandler);
$socket.on('chat-events', chatEventHandler);
if (!$chatId) {
chatId.subscribe(async (value) => {
if (!value) {
......@@ -177,6 +200,8 @@
return () => {
window.removeEventListener('message', onMessageHandler);
$socket.off('chat-events');
};
});
......@@ -302,7 +327,7 @@
}
};
const chatCompletedHandler = async (modelId, messages) => {
const chatCompletedHandler = async (modelId, responseMessageId, messages) => {
await mermaid.run({
querySelector: '.mermaid'
});
......@@ -316,7 +341,9 @@
info: m.info ? m.info : undefined,
timestamp: m.timestamp
})),
chat_id: $chatId
chat_id: $chatId,
session_id: $socket?.id,
id: responseMessageId
}).catch((error) => {
toast.error(error);
messages.at(-1).error = { content: error };
......@@ -665,6 +692,7 @@
await tick();
const [res, controller] = await generateChatCompletion(localStorage.token, {
stream: true,
model: model.id,
messages: messagesBody,
options: {
......@@ -682,8 +710,9 @@
keep_alive: $settings.keepAlive ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined,
citations: files.length > 0 ? true : undefined,
chat_id: $chatId
session_id: $socket?.id,
chat_id: $chatId,
id: responseMessageId
});
if (res && res.ok) {
......@@ -704,7 +733,7 @@
controller.abort('User: Stop Response');
} else {
const messages = createMessagesList(responseMessageId);
await chatCompletedHandler(model.id, messages);
await chatCompletedHandler(model.id, responseMessageId, messages);
}
_response = responseMessage.content;
......@@ -912,8 +941,8 @@
const [res, controller] = await generateOpenAIChatCompletion(
localStorage.token,
{
model: model.id,
stream: true,
model: model.id,
stream_options:
model.info?.meta?.capabilities?.usage ?? false
? {
......@@ -983,9 +1012,9 @@
max_tokens: $settings?.params?.max_tokens ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined,
citations: files.length > 0 ? true : undefined,
chat_id: $chatId
session_id: $socket?.id,
chat_id: $chatId,
id: responseMessageId
},
`${WEBUI_BASE_URL}/api`
);
......@@ -1014,7 +1043,7 @@
} else {
const messages = createMessagesList(responseMessageId);
await chatCompletedHandler(model.id, messages);
await chatCompletedHandler(model.id, responseMessageId, messages);
}
_response = responseMessage.content;
......