Commit f5487628 authored by Morgan Blangeois

Resolve merge conflicts in French translations

parents 2fedd91e 2c061777
@@ -10,7 +10,8 @@ node_modules
 vite.config.js.timestamp-*
 vite.config.ts.timestamp-*
 __pycache__
-.env
+.idea
+venv
 _old
 uploads
 .ipynb_checkpoints
...
@@ -306,3 +306,4 @@ dist
 # cypress artifacts
 cypress/videos
 cypress/screenshots
+.vscode/settings.json
@@ -14,7 +14,6 @@ from fastapi import (
 from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
-from faster_whisper import WhisperModel
 from pydantic import BaseModel
 import uuid
@@ -277,6 +276,8 @@ def transcribe(
     f.close()
     if app.state.config.STT_ENGINE == "":
+        from faster_whisper import WhisperModel
+
         whisper_kwargs = {
             "model_size_or_path": WHISPER_MODEL,
             "device": whisper_device_type,
...
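The two audio hunks above defer the `faster_whisper` import from module load to the moment the built-in STT engine is actually selected, so backend startup no longer pays for the model library when an external engine is configured. A minimal sketch of the same lazy-import pattern (the function and its arguments are illustrative, not part of the diff):

```python
def transcribe_locally(audio_path: str, model_name: str = "base") -> str:
    # Deferred import: loading faster_whisper (and CTranslate2 behind it)
    # is expensive, so only pay for it on the code path that needs it.
    from faster_whisper import WhisperModel

    model = WhisperModel(model_name, device="cpu", compute_type="int8")
    segments, _info = model.transcribe(audio_path)
    return " ".join(segment.text.strip() for segment in segments)
```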
@@ -12,7 +12,6 @@ from fastapi import (
     Form,
 )
 from fastapi.middleware.cors import CORSMiddleware
-from faster_whisper import WhisperModel
 from constants import ERROR_MESSAGES
 from utils.utils import (
...
@@ -25,6 +25,7 @@ from utils.task import prompt_template
 from config import (
     SRC_LOG_LEVELS,
     ENABLE_OPENAI_API,
+    AIOHTTP_CLIENT_TIMEOUT,
     OPENAI_API_BASE_URLS,
     OPENAI_API_KEYS,
     CACHE_DIR,
@@ -463,7 +464,9 @@ async def generate_chat_completion(
     streaming = False
     try:
-        session = aiohttp.ClientSession(trust_env=True)
+        session = aiohttp.ClientSession(
+            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
+        )
         r = await session.request(
             method="POST",
             url=f"{url}/chat/completions",
...
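The OpenAI proxy change above threads a configurable total timeout into the upstream `aiohttp` session instead of relying on the library default. A self-contained sketch of the pattern, with a placeholder value standing in for the `AIOHTTP_CLIENT_TIMEOUT` config:

```python
import aiohttp

AIOHTTP_CLIENT_TIMEOUT = 300  # placeholder; the real value comes from config.py


async def post_chat_completion(url: str, payload: dict) -> dict:
    session = aiohttp.ClientSession(
        trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
    )
    try:
        # Mirrors the proxied request in the hunk above.
        async with session.request(
            method="POST", url=f"{url}/chat/completions", json=payload
        ) as r:
            return await r.json()
    finally:
        await session.close()
```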
@@ -48,8 +48,6 @@ import mimetypes
 import uuid
 import json
-import sentence_transformers
-
 from apps.webui.models.documents import (
     Documents,
     DocumentForm,
@@ -93,6 +91,8 @@ from config import (
     SRC_LOG_LEVELS,
     UPLOAD_DIR,
     DOCS_DIR,
+    CONTENT_EXTRACTION_ENGINE,
+    TIKA_SERVER_URL,
     RAG_TOP_K,
     RAG_RELEVANCE_THRESHOLD,
     RAG_EMBEDDING_ENGINE,
@@ -148,6 +148,9 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
     ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
 )
+app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
+app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
+
 app.state.config.CHUNK_SIZE = CHUNK_SIZE
 app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
@@ -190,6 +193,8 @@ def update_embedding_model(
     update_model: bool = False,
 ):
     if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
+        import sentence_transformers
+
         app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
             get_model_path(embedding_model, update_model),
             device=DEVICE_TYPE,
@@ -204,6 +209,8 @@ def update_reranking_model(
     update_model: bool = False,
 ):
     if reranking_model:
+        import sentence_transformers
+
         app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
             get_model_path(reranking_model, update_model),
             device=DEVICE_TYPE,
@@ -388,6 +395,10 @@ async def get_rag_config(user=Depends(get_admin_user)):
     return {
         "status": True,
         "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
+        "content_extraction": {
+            "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
+            "tika_server_url": app.state.config.TIKA_SERVER_URL,
+        },
         "chunk": {
             "chunk_size": app.state.config.CHUNK_SIZE,
             "chunk_overlap": app.state.config.CHUNK_OVERLAP,
@@ -417,6 +428,11 @@ async def get_rag_config(user=Depends(get_admin_user)):
     }
+class ContentExtractionConfig(BaseModel):
+    engine: str = ""
+    tika_server_url: Optional[str] = None
+
+
 class ChunkParamUpdateForm(BaseModel):
     chunk_size: int
     chunk_overlap: int
@@ -450,6 +466,7 @@ class WebConfig(BaseModel):
 class ConfigUpdateForm(BaseModel):
     pdf_extract_images: Optional[bool] = None
+    content_extraction: Optional[ContentExtractionConfig] = None
     chunk: Optional[ChunkParamUpdateForm] = None
     youtube: Optional[YoutubeLoaderConfig] = None
     web: Optional[WebConfig] = None
@@ -463,6 +480,11 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
         else app.state.config.PDF_EXTRACT_IMAGES
     )
+    if form_data.content_extraction is not None:
+        log.info(f"Updating text settings: {form_data.content_extraction}")
+        app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine
+        app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
+
     if form_data.chunk is not None:
         app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
         app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
@@ -499,6 +521,10 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
     return {
         "status": True,
         "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
+        "content_extraction": {
+            "engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
+            "tika_server_url": app.state.config.TIKA_SERVER_URL,
+        },
         "chunk": {
             "chunk_size": app.state.config.CHUNK_SIZE,
             "chunk_overlap": app.state.config.CHUNK_OVERLAP,
@@ -985,6 +1011,41 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
         return False
+class TikaLoader:
+    def __init__(self, file_path, mime_type=None):
+        self.file_path = file_path
+        self.mime_type = mime_type
+
+    def load(self) -> List[Document]:
+        with open(self.file_path, "rb") as f:
+            data = f.read()
+
+        if self.mime_type is not None:
+            headers = {"Content-Type": self.mime_type}
+        else:
+            headers = {}
+
+        endpoint = app.state.config.TIKA_SERVER_URL
+        if not endpoint.endswith("/"):
+            endpoint += "/"
+        endpoint += "tika/text"
+
+        r = requests.put(endpoint, data=data, headers=headers)
+
+        if r.ok:
+            raw_metadata = r.json()
+            text = raw_metadata.get("X-TIKA:content", "<No text content found>")
+
+            if "Content-Type" in raw_metadata:
+                headers["Content-Type"] = raw_metadata["Content-Type"]
+
+            log.info("Tika extracted text: %s", text)
+
+            return [Document(page_content=text, metadata=headers)]
+        else:
+            raise Exception(f"Error calling Tika: {r.reason}")
+
+
 def get_loader(filename: str, file_content_type: str, file_path: str):
     file_ext = filename.split(".")[-1].lower()
     known_type = True
@@ -1035,6 +1096,17 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
         "msg",
     ]
+    if (
+        app.state.config.CONTENT_EXTRACTION_ENGINE == "tika"
+        and app.state.config.TIKA_SERVER_URL
+    ):
+        if file_ext in known_source_ext or (
+            file_content_type and file_content_type.find("text/") >= 0
+        ):
+            loader = TextLoader(file_path, autodetect_encoding=True)
+        else:
+            loader = TikaLoader(file_path, file_content_type)
+    else:
         if file_ext == "pdf":
             loader = PyPDFLoader(
                 file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
...
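The new `TikaLoader` routes any file type the built-in loaders don't special-case through an Apache Tika server's REST API and wraps the result in a langchain `Document`. A standalone sketch of the same HTTP call (the `/tika/text` endpoint and the `X-TIKA:content` field come straight from the loader above; the helper itself is illustrative):

```python
from typing import Optional

import requests


def tika_extract(
    file_path: str,
    tika_url: str = "http://tika:9998",
    mime_type: Optional[str] = None,
) -> str:
    headers = {"Content-Type": mime_type} if mime_type else {}
    endpoint = tika_url.rstrip("/") + "/tika/text"  # same endpoint the loader builds
    with open(file_path, "rb") as f:
        r = requests.put(endpoint, data=f.read(), headers=headers)
    r.raise_for_status()
    # Tika returns JSON metadata; the extracted body sits under "X-TIKA:content".
    return r.json().get("X-TIKA:content", "")
```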
@@ -294,15 +294,17 @@ def get_rag_context(
         extracted_collections.extend(collection_names)
-    context_string = ""
+    contexts = []
     citations = []
     for context in relevant_contexts:
         try:
             if "documents" in context:
-                context_string += "\n\n".join(
+                contexts.append(
+                    "\n\n".join(
                         [text for text in context["documents"][0] if text is not None]
                     )
+                )
                 if "metadatas" in context:
                     citations.append(
@@ -315,9 +317,7 @@ def get_rag_context(
         except Exception as e:
             log.exception(e)
-    context_string = context_string.strip()
-
-    return context_string, citations
+    return contexts, citations
 def get_model_path(model: str, update_model: bool = False):
@@ -442,8 +442,6 @@ from langchain_core.documents import BaseDocumentCompressor, Document
 from langchain_core.callbacks import Callbacks
 from langchain_core.pydantic_v1 import Extra
-from sentence_transformers import util
-
 class RerankCompressor(BaseDocumentCompressor):
     embedding_function: Any
@@ -468,6 +466,8 @@ class RerankCompressor(BaseDocumentCompressor):
                 [(query, doc.page_content) for doc in documents]
             )
         else:
+            from sentence_transformers import util
+
             query_embedding = self.embedding_function(query)
             document_embedding = self.embedding_function(
                 [doc.page_content for doc in documents]
...
@@ -259,6 +259,9 @@ async def generate_function_chat_completion(form_data, user):
                 if isinstance(line, BaseModel):
                     line = line.model_dump_json()
                     line = f"data: {line}"
+                if isinstance(line, dict):
+                    line = f"data: {json.dumps(line)}"
+
                 try:
                     line = line.decode("utf-8")
                 except:
...
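The added `dict` branch lets pipe functions yield plain dictionaries and still come out as well-formed Server-Sent-Events `data:` lines, alongside the existing `BaseModel` and bytes cases. A minimal sketch of the normalization order (the helper name is hypothetical):

```python
import json

from pydantic import BaseModel


def to_sse_line(line) -> str:
    # Same precedence as the streaming loop: BaseModel first, then dict,
    # then raw bytes; plain strings pass through untouched.
    if isinstance(line, BaseModel):
        line = f"data: {line.model_dump_json()}"
    if isinstance(line, dict):
        line = f"data: {json.dumps(line)}"
    if isinstance(line, bytes):
        line = line.decode("utf-8")
    return line
```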
@@ -214,8 +214,7 @@ class FunctionsTable:
             user_settings["functions"]["valves"][id] = valves
             # Update the user settings in the database
-            query = Users.update_user_by_id(user_id, {"settings": user_settings})
-            query.execute()
+            Users.update_user_by_id(user_id, {"settings": user_settings})
             return user_settings["functions"]["valves"][id]
         except Exception as e:
...
@@ -170,8 +170,7 @@ class ToolsTable:
             user_settings["tools"]["valves"][id] = valves
             # Update the user settings in the database
-            query = Users.update_user_by_id(user_id, {"settings": user_settings})
-            query.execute()
+            Users.update_user_by_id(user_id, {"settings": user_settings})
             return user_settings["tools"]["valves"][id]
         except Exception as e:
...
@@ -5,9 +5,8 @@ import importlib.metadata
 import pkgutil
 import chromadb
 from chromadb import Settings
-from base64 import b64encode
 from bs4 import BeautifulSoup
-from typing import TypeVar, Generic, Union
+from typing import TypeVar, Generic
 from pydantic import BaseModel
 from typing import Optional
@@ -19,7 +18,6 @@ import markdown
 import requests
 import shutil
-from secrets import token_bytes
 from constants import ERROR_MESSAGES
 ####################################
@@ -768,12 +766,14 @@ class BannerModel(BaseModel):
     dismissible: bool
     timestamp: int
+try:
+    banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
+    banners = [BannerModel(**banner) for banner in banners]
+except Exception as e:
+    print(f"Error loading WEBUI_BANNERS: {e}")
+    banners = []
+
-WEBUI_BANNERS = PersistentConfig(
-    "WEBUI_BANNERS",
-    "ui.banners",
-    [BannerModel(**banner) for banner in json.loads("[]")],
-)
+WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
 SHOW_ADMIN_DETAILS = PersistentConfig(
@@ -885,6 +885,22 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
 if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
     raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
+####################################
+# RAG document content extraction
+####################################
+
+CONTENT_EXTRACTION_ENGINE = PersistentConfig(
+    "CONTENT_EXTRACTION_ENGINE",
+    "rag.CONTENT_EXTRACTION_ENGINE",
+    os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
+)
+
+TIKA_SERVER_URL = PersistentConfig(
+    "TIKA_SERVER_URL",
+    "rag.tika_server_url",
+    os.getenv("TIKA_SERVER_URL", "http://tika:9998"),  # Default for sidecar deployment
+)
+
 ####################################
 # RAG
 ####################################
...
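Both new settings are ordinary environment variables wrapped in `PersistentConfig`, so enabling Tika extraction only requires two variables before the backend starts. A sketch of the assumed deployment environment (the values are the defaults from the hunk above; how they are set is illustrative):

```python
import os

# Illustrative only: in practice these are set in the container environment,
# e.g. next to a Tika sidecar reachable on the same Docker network.
os.environ["CONTENT_EXTRACTION_ENGINE"] = "tika"
os.environ["TIKA_SERVER_URL"] = "http://tika:9998"
```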
@@ -33,7 +33,7 @@ from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import StreamingResponse, Response, RedirectResponse
-from apps.socket.main import app as socket_app
+from apps.socket.main import sio, app as socket_app
 from apps.ollama.main import (
     app as ollama_app,
     OpenAIChatCompletionForm,
@@ -212,8 +212,79 @@ origins = ["*"]
 ##################################
+async def get_body_and_model_and_user(request):
+    # Read the original request body
+    body = await request.body()
+    body_str = body.decode("utf-8")
+    body = json.loads(body_str) if body_str else {}
+
+    model_id = body["model"]
+    if model_id not in app.state.MODELS:
+        raise "Model not found"
+    model = app.state.MODELS[model_id]
+
+    user = get_current_user(
+        request,
+        get_http_authorization_cred(request.headers.get("Authorization")),
+    )
+
+    return body, model, user
+
+
+def get_task_model_id(default_model_id):
+    # Set the task model
+    task_model_id = default_model_id
+    # Check if the user has a custom task model and use that model
+    if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
+        if (
+            app.state.config.TASK_MODEL
+            and app.state.config.TASK_MODEL in app.state.MODELS
+        ):
+            task_model_id = app.state.config.TASK_MODEL
+    else:
+        if (
+            app.state.config.TASK_MODEL_EXTERNAL
+            and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
+        ):
+            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+
+    return task_model_id
+
+
+def get_filter_function_ids(model):
+    def get_priority(function_id):
+        function = Functions.get_function_by_id(function_id)
+        if function is not None and hasattr(function, "valves"):
+            return (function.valves if function.valves else {}).get("priority", 0)
+        return 0
+
+    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
+    if "info" in model and "meta" in model["info"]:
+        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+        filter_ids = list(set(filter_ids))
+
+    enabled_filter_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("filter", active_only=True)
+    ]
+    filter_ids = [
+        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+    ]
+
+    filter_ids.sort(key=get_priority)
+    return filter_ids
+
+
 async def get_function_call_response(
-    messages, files, tool_id, template, task_model_id, user
+    messages,
+    files,
+    tool_id,
+    template,
+    task_model_id,
+    user,
+    model,
+    __event_emitter__=None,
 ):
     tool = Tools.get_tool_by_id(tool_id)
     tools_specs = json.dumps(tool.specs, indent=2)
@@ -350,6 +421,13 @@ async def get_function_call_response(
                     "__id__": tool_id,
                 }
+            if "__event_emitter__" in sig.parameters:
+                # Call the function with the '__event_emitter__' parameter included
+                params = {
+                    **params,
+                    "__event_emitter__": __event_emitter__,
+                }
+
             if inspect.iscoroutinefunction(function):
                 function_result = await function(**params)
             else:
@@ -373,68 +451,10 @@ async def get_function_call_response(
     return None, None, False
-class ChatCompletionMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request: Request, call_next):
-        data_items = []
-        show_citations = False
-        citations = []
-
-        if request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
-
-            # Read the original request body
-            body = await request.body()
-            body_str = body.decode("utf-8")
-            data = json.loads(body_str) if body_str else {}
-
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
-            )
-
-            # Flag to skip RAG completions if file_handler is present in tools/functions
-            skip_files = False
-
-            if data.get("citations"):
-                show_citations = True
-                del data["citations"]
-
-            model_id = data["model"]
-            if model_id not in app.state.MODELS:
-                raise HTTPException(
-                    status_code=status.HTTP_404_NOT_FOUND,
-                    detail="Model not found",
-                )
-            model = app.state.MODELS[model_id]
-
-            def get_priority(function_id):
-                function = Functions.get_function_by_id(function_id)
-                if function is not None and hasattr(function, "valves"):
-                    return (function.valves if function.valves else {}).get(
-                        "priority", 0
-                    )
-                return 0
-
-            filter_ids = [
-                function.id for function in Functions.get_global_filter_functions()
-            ]
-            if "info" in model and "meta" in model["info"]:
-                filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-                filter_ids = list(set(filter_ids))
-
-            enabled_filter_ids = [
-                function.id
-                for function in Functions.get_functions_by_type(
-                    "filter", active_only=True
-                )
-            ]
-            filter_ids = [
-                filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-            ]
-
-            filter_ids.sort(key=get_priority)
+async def chat_completion_functions_handler(body, model, user, __event_emitter__):
+    skip_files = None
+
+    filter_ids = get_filter_function_ids(model)
     for filter_id in filter_ids:
         filter = Functions.get_function_by_id(filter_id)
         if filter:
@@ -464,7 +484,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 # Get the signature of the function
                 sig = inspect.signature(inlet)
-                params = {"body": data}
+                params = {"body": body}
                 if "__user__" in sig.parameters:
                     __user__ = {
@@ -492,108 +512,206 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                         "__id__": filter_id,
                     }
+                if "__model__" in sig.parameters:
+                    params = {
+                        **params,
+                        "__model__": model,
+                    }
+
+                if "__event_emitter__" in sig.parameters:
+                    params = {
+                        **params,
+                        "__event_emitter__": __event_emitter__,
+                    }
+
                 if inspect.iscoroutinefunction(inlet):
-                    data = await inlet(**params)
+                    body = await inlet(**params)
                 else:
-                    data = inlet(**params)
         except Exception as e:
             print(f"Error: {e}")
-            return JSONResponse(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                content={"detail": str(e)},
-            )
+                    body = inlet(**params)
+            raise e
+
+    if skip_files:
+        if "files" in body:
+            del body["files"]
+
+    return body, {}
-            # Set the task model
-            task_model_id = data["model"]
-            # Check if the user has a custom task model and use that model
-            if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
-                if (
-                    app.state.config.TASK_MODEL
-                    and app.state.config.TASK_MODEL in app.state.MODELS
-                ):
-                    task_model_id = app.state.config.TASK_MODEL
-            else:
-                if (
-                    app.state.config.TASK_MODEL_EXTERNAL
-                    and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
-                ):
-                    task_model_id = app.state.config.TASK_MODEL_EXTERNAL
-
-            prompt = get_last_user_message(data["messages"])
-            context = ""
+
+
+async def chat_completion_tools_handler(body, model, user, __event_emitter__):
+    skip_files = None
+
+    contexts = []
+    citations = None
+
+    task_model_id = get_task_model_id(body["model"])
     # If tool_ids field is present, call the functions
-    if "tool_ids" in data:
-        print(data["tool_ids"])
-        for tool_id in data["tool_ids"]:
+    if "tool_ids" in body:
+        print(body["tool_ids"])
+        for tool_id in body["tool_ids"]:
             print(tool_id)
             try:
-                response, citation, file_handler = (
-                    await get_function_call_response(
-                        messages=data["messages"],
-                        files=data.get("files", []),
+                response, citation, file_handler = await get_function_call_response(
+                    messages=body["messages"],
+                    files=body.get("files", []),
                     tool_id=tool_id,
                     template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
                     task_model_id=task_model_id,
                     user=user,
-                    )
+                    model=model,
+                    __event_emitter__=__event_emitter__,
                 )
                 print(file_handler)
                 if isinstance(response, str):
-                    context += ("\n" if context != "" else "") + response
+                    contexts.append(response)
                 if citation:
-                    citations.append(citation)
-                    show_citations = True
+                    if citations is None:
+                        citations = [citation]
+                    else:
+                        citations.append(citation)
                 if file_handler:
                     skip_files = True
             except Exception as e:
                 print(f"Error: {e}")
-        del data["tool_ids"]
-        print(f"tool_context: {context}")
+        del body["tool_ids"]
+        print(f"tool_contexts: {contexts}")
-            # If files field is present, generate RAG completions
-            # If skip_files is True, skip the RAG completions
-            if "files" in data:
-                if not skip_files:
-                    data = {**data}
-                    rag_context, rag_citations = get_rag_context(
-                        files=data["files"],
-                        messages=data["messages"],
+    if skip_files:
+        if "files" in body:
+            del body["files"]
+
+    return body, {
+        **({"contexts": contexts} if contexts is not None else {}),
+        **({"citations": citations} if citations is not None else {}),
+    }
+
+
+async def chat_completion_files_handler(body):
+    contexts = []
+    citations = None
+
+    if "files" in body:
+        files = body["files"]
+        del body["files"]
+
+        contexts, citations = get_rag_context(
+            files=files,
+            messages=body["messages"],
             embedding_function=rag_app.state.EMBEDDING_FUNCTION,
             k=rag_app.state.config.TOP_K,
             reranking_function=rag_app.state.sentence_transformer_rf,
             r=rag_app.state.config.RELEVANCE_THRESHOLD,
             hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
         )
-                    if rag_context:
-                        context += ("\n" if context != "" else "") + rag_context
-                    log.debug(f"rag_context: {rag_context}, citations: {citations}")
-                    if rag_citations:
-                        citations.extend(rag_citations)
-
-                del data["files"]
+        log.debug(f"rag_contexts: {contexts}, citations: {citations}")
+
+    return body, {
+        **({"contexts": contexts} if contexts is not None else {}),
+        **({"citations": citations} if citations is not None else {}),
+    }
-            if show_citations and len(citations) > 0:
-                data_items.append({"citations": citations})
+
+
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        if request.method == "POST" and any(
+            endpoint in request.url.path
+            for endpoint in ["/ollama/api/chat", "/chat/completions"]
+        ):
+            log.debug(f"request.url.path: {request.url.path}")
+
+            try:
+                body, model, user = await get_body_and_model_and_user(request)
+            except Exception as e:
+                return JSONResponse(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    content={"detail": str(e)},
+                )
+
+            # Extract session_id, chat_id and message_id from the request body
+            session_id = None
+            if "session_id" in body:
+                session_id = body["session_id"]
+                del body["session_id"]
+            chat_id = None
+            if "chat_id" in body:
+                chat_id = body["chat_id"]
+                del body["chat_id"]
+            message_id = None
+            if "id" in body:
+                message_id = body["id"]
+                del body["id"]
+
+            async def __event_emitter__(data):
+                await sio.emit(
+                    "chat-events",
+                    {
+                        "chat_id": chat_id,
+                        "message_id": message_id,
+                        "data": data,
+                    },
+                    to=session_id,
+                )
+
+            # Initialize data_items to store additional data to be sent to the client
+            data_items = []
+
+            # Initialize context, and citations
+            contexts = []
+            citations = []
+
-            if context != "":
-                system_prompt = rag_template(
-                    rag_app.state.config.RAG_TEMPLATE, context, prompt
-                )
-                print(system_prompt)
-                data["messages"] = add_or_update_system_message(
-                    system_prompt, data["messages"]
-                )
+            try:
+                body, flags = await chat_completion_functions_handler(
+                    body, model, user, __event_emitter__
+                )
+            except Exception as e:
+                return JSONResponse(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    content={"detail": str(e)},
+                )
+
+            try:
+                body, flags = await chat_completion_tools_handler(
+                    body, model, user, __event_emitter__
+                )
+                contexts.extend(flags.get("contexts", []))
+                citations.extend(flags.get("citations", []))
+            except Exception as e:
+                print(e)
+                pass
+
+            try:
+                body, flags = await chat_completion_files_handler(body)
+                contexts.extend(flags.get("contexts", []))
+                citations.extend(flags.get("citations", []))
+            except Exception as e:
+                print(e)
+                pass
+
+            # If context is not empty, insert it into the messages
+            if len(contexts) > 0:
+                context_string = "/n".join(contexts).strip()
+                prompt = get_last_user_message(body["messages"])
+                body["messages"] = add_or_update_system_message(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
+
+            # If there are citations, add them to the data_items
+            if len(citations) > 0:
+                data_items.append({"citations": citations})
+
-            modified_body_bytes = json.dumps(data).encode("utf-8")
+            modified_body_bytes = json.dumps(body).encode("utf-8")
             # Replace the request body with the modified one
             request._body = modified_body_bytes
             # Set custom header to ensure content-length matches new body length
@@ -715,9 +833,6 @@ def filter_pipeline(payload, user):
             pass
     if "pipeline" not in app.state.MODELS[model_id]:
-        if "chat_id" in payload:
-            del payload["chat_id"]
-
         if "title" in payload:
             del payload["title"]
@@ -1008,6 +1123,17 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     else:
         pass
+    async def __event_emitter__(data):
+        await sio.emit(
+            "chat-events",
+            {
+                "chat_id": data["chat_id"],
+                "message_id": data["id"],
+                "data": data,
+            },
+            to=data["session_id"],
+        )
+
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
         if function is not None and hasattr(function, "valves"):
@@ -1083,6 +1209,18 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
                 "__id__": filter_id,
             }
+        if "__model__" in sig.parameters:
+            params = {
+                **params,
+                "__model__": model,
+            }
+
+        if "__event_emitter__" in sig.parameters:
+            params = {
+                **params,
+                "__event_emitter__": __event_emitter__,
+            }
+
         if inspect.iscoroutinefunction(outlet):
             data = await outlet(**params)
         else:
@@ -1213,6 +1351,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
             content={"detail": e.args[1]},
         )
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
@@ -1273,6 +1414,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
             content={"detail": e.args[1]},
        )
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
@@ -1337,6 +1481,9 @@ Message: """{{prompt}}"""
             content={"detail": e.args[1]},
        )
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
...
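The rewritten dispatch path collects `contexts` and `citations` from each handler's `flags` dict and performs a single system-message injection at the end. One detail worth noting: the new code joins contexts with the literal string `"/n"`, which looks like a typo for the newline `"\n"`. A minimal sketch of the merge step under that reading (the helper is illustrative, not in the diff):

```python
def merge_handler_flags(*flag_dicts) -> tuple:
    # Each handler returns (body, flags); flags may carry "contexts"
    # and "citations", which the middleware extends into running lists.
    contexts, citations = [], []
    for flags in flag_dicts:
        contexts.extend(flags.get("contexts", []))
        citations.extend(flags.get("citations", []))
    context_string = "\n".join(contexts).strip()  # "\n", not the "/n" in the hunk
    return context_string, citations
```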
@@ -10,7 +10,7 @@ python-socketio==5.11.3
 python-jose==3.3.0
 passlib[bcrypt]==1.7.4
-requests==2.32.2
+requests==2.32.3
 aiohttp==3.9.5
 peewee==3.17.5
 peewee-migrate==1.12.2
@@ -30,21 +30,21 @@ openai
 anthropic
 google-generativeai==0.5.4
-langchain==0.2.0
-langchain-community==0.2.0
+langchain==0.2.6
+langchain-community==0.2.6
 langchain-chroma==0.1.2
 fake-useragent==1.5.1
 chromadb==0.5.3
-sentence-transformers==2.7.0
+sentence-transformers==3.0.1
 pypdf==4.2.0
 docx2txt==0.8
 python-pptx==0.6.23
-unstructured==0.14.0
+unstructured==0.14.9
 Markdown==3.6
 pypandoc==1.13
 pandas==2.2.2
-openpyxl==3.1.2
+openpyxl==3.1.5
 pyxlsb==1.0.10
 xlrd==2.0.1
 validators==0.28.1
@@ -61,7 +61,7 @@ PyJWT[crypto]==2.8.0
 authlib==1.3.1
 black==24.4.2
-langfuse==2.33.0
+langfuse==2.36.2
 youtube-transcript-api==0.6.2
 pytube==15.0.0
...
@@ -8,9 +8,17 @@ import uuid
 import time
-def get_last_user_message(messages: List[dict]) -> str:
+def get_last_user_message_item(messages: List[dict]) -> str:
     for message in reversed(messages):
         if message["role"] == "user":
+            return message
+    return None
+
+
+def get_last_user_message(messages: List[dict]) -> str:
+    message = get_last_user_message_item(messages)
+    if message is not None:
         if isinstance(message["content"], list):
             for item in message["content"]:
                 if item["type"] == "text":
...
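Splitting `get_last_user_message_item` out of `get_last_user_message` lets callers fetch the whole message dict, while the original helper keeps returning just the text, including the text part of multi-part (vision-style) content. A small usage sketch:

```python
messages = [
    {"role": "assistant", "content": "Hi!"},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ],
    },
]

# get_last_user_message_item(messages) -> the full dict of the last user turn
# get_last_user_message(messages)      -> "What is in this image?"
```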
 {
     "name": "open-webui",
-    "version": "0.3.7",
+    "version": "0.3.8",
     "lockfileVersion": 3,
     "requires": true,
     "packages": {
         "": {
             "name": "open-webui",
-            "version": "0.3.7",
+            "version": "0.3.8",
             "dependencies": {
                 "@codemirror/lang-javascript": "^6.2.2",
                 "@codemirror/lang-python": "^6.1.6",
...
 {
     "name": "open-webui",
-    "version": "0.3.7",
+    "version": "0.3.8",
     "private": true,
     "scripts": {
         "dev": "npm run pyodide:fetch && vite dev --host",
...
@@ -32,6 +32,11 @@ type ChunkConfigForm = {
     chunk_overlap: number;
 };
+type ContentExtractConfigForm = {
+    engine: string;
+    tika_server_url: string | null;
+};
+
 type YoutubeConfigForm = {
     language: string[];
     translation?: string | null;
@@ -40,6 +45,7 @@ type YoutubeConfigForm = {
 type RAGConfigForm = {
     pdf_extract_images?: boolean;
     chunk?: ChunkConfigForm;
+    content_extraction?: ContentExtractConfigForm;
     web_loader_ssl_verification?: boolean;
     youtube?: YoutubeConfigForm;
 };
...
@@ -37,6 +37,10 @@
     let embeddingModel = '';
     let rerankingModel = '';
+    let contentExtractionEngine = 'default';
+    let tikaServerUrl = '';
+    let showTikaServerUrl = false;
+
     let chunkSize = 0;
     let chunkOverlap = 0;
     let pdfExtractImages = true;
@@ -163,11 +167,20 @@
         rerankingModelUpdateHandler();
     }
+    if (contentExtractionEngine === 'tika' && tikaServerUrl === '') {
+        toast.error($i18n.t('Tika Server URL required.'));
+        return;
+    }
+
     const res = await updateRAGConfig(localStorage.token, {
         pdf_extract_images: pdfExtractImages,
         chunk: {
             chunk_overlap: chunkOverlap,
             chunk_size: chunkSize
+        },
+        content_extraction: {
+            engine: contentExtractionEngine,
+            tika_server_url: tikaServerUrl
         }
     });
@@ -213,6 +226,10 @@
         chunkSize = res.chunk.chunk_size;
         chunkOverlap = res.chunk.chunk_overlap;
+        contentExtractionEngine = res.content_extraction.engine;
+        tikaServerUrl = res.content_extraction.tika_server_url;
+        showTikaServerUrl = contentExtractionEngine === 'tika';
+
     }
 });
 </script>
@@ -388,7 +405,7 @@
         </div>
     </div>
-    <hr class=" dark:border-gray-850 my-1" />
+    <hr class="dark:border-gray-850" />
     <div class="space-y-2" />
     <div>
@@ -562,6 +579,39 @@
     <hr class=" dark:border-gray-850" />
+    <div class="">
+        <div class="text-sm font-medium">{$i18n.t('Content Extraction')}</div>
+
+        <div class="flex w-full justify-between mt-2">
+            <div class="self-center text-xs font-medium">{$i18n.t('Engine')}</div>
+            <div class="flex items-center relative">
+                <select
+                    class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
+                    bind:value={contentExtractionEngine}
+                    on:change={(e) => {
+                        showTikaServerUrl = e.target.value === 'tika';
+                    }}
+                >
+                    <option value="">{$i18n.t('Default')}</option>
+                    <option value="tika">{$i18n.t('Tika')}</option>
+                </select>
+            </div>
+        </div>
+
+        {#if showTikaServerUrl}
+            <div class="flex w-full mt-2">
+                <div class="flex-1 mr-2">
+                    <input
+                        class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+                        placeholder={$i18n.t('Enter Tika Server URL')}
+                        bind:value={tikaServerUrl}
+                    />
+                </div>
+            </div>
+        {/if}
+    </div>
+
+    <hr class=" dark:border-gray-850" />
+
     <div class=" ">
         <div class=" text-sm font-medium">{$i18n.t('Query Params')}</div>
...
@@ -19,6 +19,7 @@
     } from '$lib/apis';
     import Spinner from '$lib/components/common/Spinner.svelte';
+    import Switch from '$lib/components/common/Switch.svelte';
     const i18n: Writable<i18nType> = getContext('i18n');
@@ -476,15 +477,40 @@
                     </div>
                     {#if (valves[property] ?? null) !== null}
-                        <div class="flex mt-0.5 space-x-2">
+                        <!-- {valves[property]} -->
+                        <div class="flex mt-0.5 mb-1.5 space-x-2">
                             <div class=" flex-1">
+                                {#if valves_spec.properties[property]?.enum ?? null}
+                                    <select
+                                        class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
+                                        bind:value={valves[property]}
+                                    >
+                                        {#each valves_spec.properties[property].enum as option}
+                                            <option value={option} selected={option === valves[property]}>
+                                                {option}
+                                            </option>
+                                        {/each}
+                                    </select>
+                                {:else if (valves_spec.properties[property]?.type ?? null) === 'boolean'}
+                                    <div class="flex justify-between items-center">
+                                        <div class="text-xs text-gray-500">
+                                            {valves[property] ? 'Enabled' : 'Disabled'}
+                                        </div>
+                                        <div class=" pr-2">
+                                            <Switch bind:state={valves[property]} />
+                                        </div>
+                                    </div>
+                                {:else}
                                     <input
                                         class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
                                         type="text"
                                         placeholder={valves_spec.properties[property].title}
                                         bind:value={valves[property]}
                                         autocomplete="off"
+                                        required
                                     />
+                                {/if}
                             </div>
                         </div>
                     {/if}
...
@@ -126,6 +126,27 @@
         })();
     }
+    const chatEventHandler = async (data) => {
+        if (data.chat_id === $chatId) {
+            await tick();
+            console.log(data);
+            let message = history.messages[data.message_id];
+
+            const status = {
+                done: data?.data?.done ?? null,
+                description: data?.data?.status ?? null
+            };
+
+            if (message.statusHistory) {
+                message.statusHistory.push(status);
+            } else {
+                message.statusHistory = [status];
+            }
+
+            messages = messages;
+        }
+    };
+
     onMount(async () => {
         const onMessageHandler = async (event) => {
             if (event.origin === window.origin) {
@@ -163,6 +184,8 @@
         };
         window.addEventListener('message', onMessageHandler);
+        $socket.on('chat-events', chatEventHandler);
+
         if (!$chatId) {
             chatId.subscribe(async (value) => {
                 if (!value) {
@@ -177,6 +200,8 @@
         return () => {
             window.removeEventListener('message', onMessageHandler);
+            $socket.off('chat-events');
+
         };
     });
@@ -302,7 +327,7 @@
         }
     };
-    const chatCompletedHandler = async (modelId, messages) => {
+    const chatCompletedHandler = async (modelId, responseMessageId, messages) => {
         await mermaid.run({
             querySelector: '.mermaid'
         });
@@ -316,7 +341,9 @@
                 info: m.info ? m.info : undefined,
                 timestamp: m.timestamp
             })),
-            chat_id: $chatId
+            chat_id: $chatId,
+            session_id: $socket?.id,
+            id: responseMessageId
         }).catch((error) => {
             toast.error(error);
             messages.at(-1).error = { content: error };
@@ -665,6 +692,7 @@
         await tick();
         const [res, controller] = await generateChatCompletion(localStorage.token, {
+            stream: true,
             model: model.id,
             messages: messagesBody,
             options: {
@@ -682,8 +710,9 @@
             keep_alive: $settings.keepAlive ?? undefined,
             tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
             files: files.length > 0 ? files : undefined,
-            citations: files.length > 0 ? true : undefined,
-            chat_id: $chatId
+            session_id: $socket?.id,
+            chat_id: $chatId,
+            id: responseMessageId
         });
         if (res && res.ok) {
@@ -704,7 +733,7 @@
                 controller.abort('User: Stop Response');
             } else {
                 const messages = createMessagesList(responseMessageId);
-                await chatCompletedHandler(model.id, messages);
+                await chatCompletedHandler(model.id, responseMessageId, messages);
             }
             _response = responseMessage.content;
@@ -912,8 +941,8 @@
         const [res, controller] = await generateOpenAIChatCompletion(
             localStorage.token,
             {
-                model: model.id,
                 stream: true,
+                model: model.id,
                 stream_options:
                     model.info?.meta?.capabilities?.usage ?? false
                         ? {
@@ -983,9 +1012,9 @@
                 max_tokens: $settings?.params?.max_tokens ?? undefined,
                 tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
                 files: files.length > 0 ? files : undefined,
-                citations: files.length > 0 ? true : undefined,
-                chat_id: $chatId
+                session_id: $socket?.id,
+                chat_id: $chatId,
+                id: responseMessageId
             },
             `${WEBUI_BASE_URL}/api`
         );
@@ -1014,7 +1043,7 @@
             } else {
                 const messages = createMessagesList(responseMessageId);
-                await chatCompletedHandler(model.id, messages);
+                await chatCompletedHandler(model.id, responseMessageId, messages);
             }
             _response = responseMessage.content;
...