Unverified commit 6e843ab5, authored by Timothy Jaeryang Baek and committed by GitHub

Merge pull request #3882 from open-webui/dev

0.3.9
parents eff736ac b3a0d47a
@@ -5,6 +5,34 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.3.9] - 2024-07-17
+
+### Added
+
+- **📁 Files Chat Controls**: We've reverted to the old file handling behavior where uploaded files are always included. You can now manage files directly within the chat controls section, giving you the ability to remove files as needed.
+- **🔧 "Action" Function Support**: Introducing a new "Action" function for adding custom buttons to the message toolbar (see the sketch after this file's diff). This feature enables more interactive messaging, with documentation coming soon.
+- **📜 Citations Handling**: For newly uploaded files in the documents workspace, citations will now display the actual filename. Additionally, you can click on these filenames to open the file in a new tab for easier access.
+- **🛠️ Event Emitter and Call Updates**: Enhanced 'event_emitter' to allow message replacement and 'event_call' to support text input for Tools and Functions. Detailed documentation will be provided shortly.
+- **🎨 Styling Refactor**: Various styling updates for a cleaner and more cohesive user interface.
+- **🌐 Enhanced Translations**: Improved translations for Catalan, Ukrainian, and Brazilian Portuguese.
+
+### Fixed
+
+- **🔧 Chat Controls Priority**: Resolved an issue where Chat Controls values were being overridden by model information parameters. The priority is now Chat Controls, followed by Global Settings, then Model Settings.
+- **🪲 Debug Logs**: Fixed an issue where debug logs were not being logged properly.
+- **🔑 Automatic1111 Auth Key**: The auth key for Automatic1111 is no longer required.
+- **📝 Title Generation**: Ensured that title generation runs only once, even when multiple models are in a chat.
+- **✅ Boolean Values in Params**: Added support for boolean values in parameters.
+- **🖼️ Files Overlay Styling**: Fixed the styling issue with the files overlay.
+
+### Changed
+
+- **⬆️ Dependency Updates**
+  - Upgraded 'pydantic' from version 2.7.1 to 2.8.2.
+  - Upgraded 'sqlalchemy' from version 2.0.30 to 2.0.31.
+  - Upgraded 'unstructured' from version 0.14.9 to 0.14.10.
+  - Upgraded 'chromadb' from version 0.5.3 to 0.5.4.
+
 ## [0.3.8] - 2024-07-09
 ### Added
......
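The changelog's "Action" entry above introduces a new function type. Based only on what this commit shows — the loader returning `module.Action()` with type `"action"` and the new `/api/chat/actions/{action_id}` endpoint further down — a minimal Action might look like the sketch below. The valve field and the `type`/`data` shapes of the event payloads are illustrative assumptions, not a documented API.

```python
# Hedged sketch of an "Action" function, inferred from this commit's loader
# and endpoint changes. Event payload fields are assumptions.
from pydantic import BaseModel


class Action:
    class Valves(BaseModel):
        # hypothetical admin-configurable setting
        enabled: bool = True

    def __init__(self):
        self.valves = self.Valves()

    async def action(
        self, body, __user__=None, __event_emitter__=None, __event_call__=None
    ):
        reply = None
        # __event_call__ round-trips to the client and awaits its response,
        # which is what enables the new text-input interaction.
        if __event_call__ is not None:
            reply = await __event_call__(
                {"type": "input", "data": {"title": "Enter a value"}}
            )
        # __event_emitter__ pushes one-way events back to the chat message.
        if __event_emitter__ is not None:
            await __event_emitter__(
                {"type": "status", "data": {"description": f"Got: {reply}", "done": True}}
            )
        return body
```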
@@ -421,7 +421,7 @@ def save_url_image(url):
 @app.post("/generations")
-def generate_image(
+async def image_generations(
     form_data: GenerateImageForm,
     user=Depends(get_verified_user),
 ):
......
@@ -728,8 +728,10 @@ async def generate_chat_completion(
     )
 
     payload = {
-        **form_data.model_dump(exclude_none=True),
+        **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
     }
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = form_data.model
     model_info = Models.get_model_by_id(model_id)
@@ -741,52 +743,85 @@ async def generate_chat_completion(
         model_info.params = model_info.params.model_dump()
 
         if model_info.params:
-            payload["options"] = {}
+            if payload.get("options") is None:
+                payload["options"] = {}
 
-            if model_info.params.get("mirostat", None):
+            if (
+                model_info.params.get("mirostat", None)
+                and payload["options"].get("mirostat") is None
+            ):
                 payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
 
-            if model_info.params.get("mirostat_eta", None):
+            if (
+                model_info.params.get("mirostat_eta", None)
+                and payload["options"].get("mirostat_eta") is None
+            ):
                 payload["options"]["mirostat_eta"] = model_info.params.get(
                     "mirostat_eta", None
                 )
 
-            if model_info.params.get("mirostat_tau", None):
+            if (
+                model_info.params.get("mirostat_tau", None)
+                and payload["options"].get("mirostat_tau") is None
+            ):
                 payload["options"]["mirostat_tau"] = model_info.params.get(
                     "mirostat_tau", None
                 )
 
-            if model_info.params.get("num_ctx", None):
+            if (
+                model_info.params.get("num_ctx", None)
+                and payload["options"].get("num_ctx") is None
+            ):
                 payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
 
-            if model_info.params.get("num_batch", None):
+            if (
+                model_info.params.get("num_batch", None)
+                and payload["options"].get("num_batch") is None
+            ):
                 payload["options"]["num_batch"] = model_info.params.get(
                     "num_batch", None
                 )
 
-            if model_info.params.get("num_keep", None):
+            if (
+                model_info.params.get("num_keep", None)
+                and payload["options"].get("num_keep") is None
+            ):
                 payload["options"]["num_keep"] = model_info.params.get("num_keep", None)
 
-            if model_info.params.get("repeat_last_n", None):
+            if (
+                model_info.params.get("repeat_last_n", None)
+                and payload["options"].get("repeat_last_n") is None
+            ):
                 payload["options"]["repeat_last_n"] = model_info.params.get(
                     "repeat_last_n", None
                 )
 
-            if model_info.params.get("frequency_penalty", None):
+            if (
+                model_info.params.get("frequency_penalty", None)
+                and payload["options"].get("frequency_penalty") is None
+            ):
                 payload["options"]["repeat_penalty"] = model_info.params.get(
                     "frequency_penalty", None
                 )
 
-            if model_info.params.get("temperature", None) is not None:
+            if (
+                model_info.params.get("temperature", None)
+                and payload["options"].get("temperature") is None
+            ):
                 payload["options"]["temperature"] = model_info.params.get(
                     "temperature", None
                 )
 
-            if model_info.params.get("seed", None):
+            if (
+                model_info.params.get("seed", None)
+                and payload["options"].get("seed") is None
+            ):
                 payload["options"]["seed"] = model_info.params.get("seed", None)
 
-            if model_info.params.get("stop", None):
+            if (
+                model_info.params.get("stop", None)
+                and payload["options"].get("stop") is None
+            ):
                 payload["options"]["stop"] = (
                     [
                         bytes(stop, "utf-8").decode("unicode_escape")

@@ -796,37 +831,56 @@ async def generate_chat_completion(
                     else None
                 )
 
-            if model_info.params.get("tfs_z", None):
+            if (
+                model_info.params.get("tfs_z", None)
+                and payload["options"].get("tfs_z") is None
+            ):
                 payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
 
-            if model_info.params.get("max_tokens", None):
+            if (
+                model_info.params.get("max_tokens", None)
+                and payload["options"].get("max_tokens") is None
+            ):
                 payload["options"]["num_predict"] = model_info.params.get(
                     "max_tokens", None
                 )
 
-            if model_info.params.get("top_k", None):
+            if (
+                model_info.params.get("top_k", None)
+                and payload["options"].get("top_k") is None
+            ):
                 payload["options"]["top_k"] = model_info.params.get("top_k", None)
 
-            if model_info.params.get("top_p", None):
+            if (
+                model_info.params.get("top_p", None)
+                and payload["options"].get("top_p") is None
+            ):
                 payload["options"]["top_p"] = model_info.params.get("top_p", None)
 
-            if model_info.params.get("use_mmap", None):
+            if (
+                model_info.params.get("use_mmap", None)
+                and payload["options"].get("use_mmap") is None
+            ):
                 payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
 
-            if model_info.params.get("use_mlock", None):
+            if (
+                model_info.params.get("use_mlock", None)
+                and payload["options"].get("use_mlock") is None
+            ):
                 payload["options"]["use_mlock"] = model_info.params.get(
                     "use_mlock", None
                 )
 
-            if model_info.params.get("num_thread", None):
+            if (
+                model_info.params.get("num_thread", None)
+                and payload["options"].get("num_thread") is None
+            ):
                 payload["options"]["num_thread"] = model_info.params.get(
                     "num_thread", None
                 )
 
             system = model_info.params.get("system", None)
 
             if system:
-                # Check if the payload already has a system message
-                # If not, add a system message to the payload
                 system = prompt_template(
                     system,
                     **(
@@ -893,10 +947,10 @@ async def generate_openai_chat_completion(
     user=Depends(get_verified_user),
 ):
     form_data = OpenAIChatCompletionForm(**form_data)
+    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
 
-    payload = {
-        **form_data.model_dump(exclude_none=True),
-    }
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = form_data.model
     model_info = Models.get_model_by_id(model_id)
......
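Every block in the hunk above applies the same rule, which is what implements the changelog's new Chat Controls priority: a model-level default is copied into `payload["options"]` only when the request did not already set that key. A condensed, illustrative sketch of the rule (the helper name is ours, not the codebase's):

```python
# Model defaults only fill gaps; client-supplied values (Chat Controls) win.
def merge_model_defaults(payload: dict, model_params: dict, keys: list) -> dict:
    options = payload.setdefault("options", {})
    for key in keys:
        value = model_params.get(key)
        if value and options.get(key) is None:
            options[key] = value
    return payload


payload = merge_model_defaults(
    {"options": {"temperature": 0.2}},   # from the request / Chat Controls
    {"temperature": 0.9, "top_k": 40},   # model-level defaults
    ["temperature", "top_k"],
)
assert payload["options"] == {"temperature": 0.2, "top_k": 40}
```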
@@ -21,6 +21,7 @@ from utils.utils import (
     get_admin_user,
 )
 from utils.task import prompt_template
+from utils.misc import add_or_update_system_message
 
 from config import (
     SRC_LOG_LEVELS,
@@ -357,6 +358,8 @@ async def generate_chat_completion(
 ):
     idx = 0
     payload = {**form_data}
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
@@ -368,24 +371,33 @@ async def generate_chat_completion(
         model_info.params = model_info.params.model_dump()
 
         if model_info.params:
-            if model_info.params.get("temperature", None) is not None:
+            if (
+                model_info.params.get("temperature", None)
+                and payload.get("temperature") is None
+            ):
                 payload["temperature"] = float(model_info.params.get("temperature"))
 
-            if model_info.params.get("top_p", None):
+            if model_info.params.get("top_p", None) and payload.get("top_p") is None:
                 payload["top_p"] = int(model_info.params.get("top_p", None))
 
-            if model_info.params.get("max_tokens", None):
+            if (
+                model_info.params.get("max_tokens", None)
+                and payload.get("max_tokens") is None
+            ):
                 payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
 
-            if model_info.params.get("frequency_penalty", None):
+            if (
+                model_info.params.get("frequency_penalty", None)
+                and payload.get("frequency_penalty") is None
+            ):
                 payload["frequency_penalty"] = int(
                     model_info.params.get("frequency_penalty", None)
                 )
 
-            if model_info.params.get("seed", None):
+            if model_info.params.get("seed", None) and payload.get("seed") is None:
                 payload["seed"] = model_info.params.get("seed", None)
 
-            if model_info.params.get("stop", None):
+            if model_info.params.get("stop", None) and payload.get("stop") is None:
                 payload["stop"] = (
                     [
                         bytes(stop, "utf-8").decode("unicode_escape")
@@ -410,21 +422,10 @@ async def generate_chat_completion(
                     else {}
                 ),
             )
-            # Check if the payload already has a system message
-            # If not, add a system message to the payload
             if payload.get("messages"):
-                for message in payload["messages"]:
-                    if message.get("role") == "system":
-                        message["content"] = system + message["content"]
-                        break
-                else:
-                    payload["messages"].insert(
-                        0,
-                        {
-                            "role": "system",
-                            "content": system,
-                        },
-                    )
+                payload["messages"] = add_or_update_system_message(
+                    system, payload["messages"]
+                )
 
     else:
         pass
......
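The hunk above swaps an inlined loop for `add_or_update_system_message` from `utils.misc`. A sketch with the same behavior as the removed loop (the real helper may differ in details):

```python
def add_or_update_system_message(content: str, messages: list) -> list:
    # Prepend `content` to the first existing system message, if any...
    for message in messages:
        if message.get("role") == "system":
            message["content"] = content + message["content"]
            break
    else:  # ...otherwise insert a fresh system message at the front.
        messages.insert(0, {"role": "system", "content": content})
    return messages


messages = add_or_update_system_message(
    "Answer briefly. ", [{"role": "user", "content": "hi"}]
)
# -> [{'role': 'system', 'content': 'Answer briefly. '},
#     {'role': 'user', 'content': 'hi'}]
```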
@@ -930,7 +930,9 @@ def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
     )
 
-def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
+def store_data_in_vector_db(
+    data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
+) -> bool:
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=app.state.config.CHUNK_SIZE,
@@ -942,7 +944,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
     if len(docs) > 0:
         log.info(f"store_data_in_vector_db {docs}")
-        return store_docs_in_vector_db(docs, collection_name, overwrite), None
+        return store_docs_in_vector_db(docs, collection_name, metadata, overwrite), None
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -956,14 +958,16 @@ def store_text_in_vector_db(
         add_start_index=True,
     )
     docs = text_splitter.create_documents([text], metadatas=[metadata])
-    return store_docs_in_vector_db(docs, collection_name, overwrite)
+    return store_docs_in_vector_db(docs, collection_name, overwrite=overwrite)
 
-def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
+def store_docs_in_vector_db(
+    docs, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
+) -> bool:
     log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
-    metadatas = [doc.metadata for doc in docs]
+    metadatas = [{**doc.metadata, **(metadata if metadata else {})} for doc in docs]
 
     # ChromaDB does not like datetime formats
     # for meta-data so convert them to string.
@@ -1237,13 +1241,21 @@ def process_doc(
         data = loader.load()
 
         try:
-            result = store_data_in_vector_db(data, collection_name)
+            result = store_data_in_vector_db(
+                data,
+                collection_name,
+                {
+                    "file_id": form_data.file_id,
+                    "name": file.meta.get("name", file.filename),
+                },
+            )
 
             if result:
                 return {
                     "status": True,
                     "collection_name": collection_name,
                     "known_type": known_type,
+                    "filename": file.meta.get("name", file.filename),
                 }
         except Exception as e:
             raise HTTPException(
......
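Threading `metadata` through `store_data_in_vector_db` and `store_docs_in_vector_db` is what backs the changelog's citation fix: every stored chunk now carries the file id and the original (pre-uuid) filename. The merge itself is a plain dict spread, with the per-file fields overriding any clashing per-chunk keys:

```python
# Illustrative values; the real fields come from process_doc above.
doc_metadata = {"source": "doc.pdf", "page": 3}            # per-chunk metadata
file_metadata = {"file_id": "abc123", "name": "doc.pdf"}   # passed by process_doc

merged = {**doc_metadata, **(file_metadata if file_metadata else {})}
# -> {'source': 'doc.pdf', 'page': 3, 'file_id': 'abc123', 'name': 'doc.pdf'}
```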
@@ -137,3 +137,34 @@ async def disconnect(sid):
         await sio.emit("user-count", {"count": len(USER_POOL)})
     else:
         print(f"Unknown session ID {sid} disconnected")
+
+
+async def get_event_emitter(request_info):
+    async def __event_emitter__(event_data):
+        await sio.emit(
+            "chat-events",
+            {
+                "chat_id": request_info["chat_id"],
+                "message_id": request_info["message_id"],
+                "data": event_data,
+            },
+            to=request_info["session_id"],
+        )
+
+    return __event_emitter__
+
+
+async def get_event_call(request_info):
+    async def __event_call__(event_data):
+        response = await sio.call(
+            "chat-events",
+            {
+                "chat_id": request_info["chat_id"],
+                "message_id": request_info["message_id"],
+                "data": event_data,
+            },
+            to=request_info["session_id"],
+        )
+        return response
+
+    return __event_call__
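A hedged usage sketch for the two factories above. Only their signatures and the `chat-events` envelope come from this diff; the inner `type`/`data` fields of the example events are illustrative:

```python
request_info = {
    "chat_id": "chat-1",
    "message_id": "msg-1",
    "session_id": "socket-session-id",  # the Socket.IO sid to target
}

async def demo():
    emitter = await get_event_emitter(request_info)
    caller = await get_event_call(request_info)

    # Fire-and-forget: sio.emit pushes the event to the owning session.
    await emitter({"type": "status", "data": {"description": "working"}})

    # Round-trip: sio.call waits for the client's acknowledgement value,
    # which is how text input can be collected from the UI.
    answer = await caller({"type": "input", "data": {"title": "Need input"}})
    return answer
```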
...
@@ -20,7 +20,6 @@ from apps.webui.routers import (
 )
 from apps.webui.models.functions import Functions
 from apps.webui.models.models import Models
-
 from apps.webui.utils import load_function_module_by_id
 from utils.misc import stream_message_template

@@ -48,12 +47,14 @@ from config import (
     OAUTH_PICTURE_CLAIM,
 )
 
+from apps.socket.main import get_event_call, get_event_emitter
+
 import inspect
 import uuid
 import time
 import json
 
-from typing import Iterator, Generator
+from typing import Iterator, Generator, Optional
 from pydantic import BaseModel
 
 app = FastAPI()
@@ -164,6 +165,10 @@ async def get_pipe_models():
                     f"{function_module.name}{manifold_pipe_name}"
                 )
 
+                pipe_flag = {"type": pipe.type}
+                if hasattr(function_module, "ChatValves"):
+                    pipe_flag["valves_spec"] = function_module.ChatValves.schema()
+
                 pipe_models.append(
                     {
                         "id": manifold_pipe_id,

@@ -171,10 +176,14 @@ async def get_pipe_models():
                         "object": "model",
                         "created": pipe.created_at,
                         "owned_by": "openai",
-                        "pipe": {"type": pipe.type},
+                        "pipe": pipe_flag,
                     }
                 )
         else:
+            pipe_flag = {"type": "pipe"}
+            if hasattr(function_module, "ChatValves"):
+                pipe_flag["valves_spec"] = function_module.ChatValves.schema()
+
             pipe_models.append(
                 {
                     "id": pipe.id,

@@ -182,7 +191,7 @@ async def get_pipe_models():
                     "object": "model",
                     "created": pipe.created_at,
                     "owned_by": "openai",
-                    "pipe": {"type": "pipe"},
+                    "pipe": pipe_flag,
                 }
             )
@@ -193,6 +202,27 @@ async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
 
+    metadata = None
+    if "metadata" in form_data:
+        metadata = form_data["metadata"]
+        del form_data["metadata"]
+
+    __event_emitter__ = None
+    __event_call__ = None
+    __task__ = None
+
+    if metadata:
+        if (
+            metadata.get("session_id")
+            and metadata.get("chat_id")
+            and metadata.get("message_id")
+        ):
+            __event_emitter__ = await get_event_emitter(metadata)
+            __event_call__ = await get_event_call(metadata)
+
+        if metadata.get("task"):
+            __task__ = metadata.get("task")
+
     if model_info:
         if model_info.base_model_id:
             form_data["model"] = model_info.base_model_id
@@ -307,6 +337,15 @@ async def generate_function_chat_completion(form_data, user):
 
             params = {**params, "__user__": __user__}
 
+        if "__event_emitter__" in sig.parameters:
+            params = {**params, "__event_emitter__": __event_emitter__}
+
+        if "__event_call__" in sig.parameters:
+            params = {**params, "__event_call__": __event_call__}
+
+        if "__task__" in sig.parameters:
+            params = {**params, "__task__": __task__}
+
         if form_data["stream"]:
 
             async def stream_content():
......
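The `sig.parameters` checks above make the new context arguments opt-in: a pipe only receives `__event_emitter__`, `__event_call__`, or `__task__` if its signature declares them. The pattern in isolation (the pipe function and the injected values are illustrative):

```python
import inspect

async def pipe(body: dict, __event_emitter__=None, __task__=None):
    return body

params = {"body": {"messages": []}}
sig = inspect.signature(pipe)
for key, value in {
    "__event_emitter__": None,  # would be the emitter built from metadata
    "__event_call__": None,
    "__task__": "title_generation",
}.items():
    # Only pass arguments the target function actually declares.
    if key in sig.parameters:
        params[key] = value
# pipe() receives __event_emitter__ and __task__, but not __event_call__.
```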
@@ -167,6 +167,15 @@ class FunctionsTable:
             .all()
         ]
 
+    def get_global_action_functions(self) -> List[FunctionModel]:
+        with get_db() as db:
+            return [
+                FunctionModel.model_validate(function)
+                for function in db.query(Function)
+                .filter_by(type="action", is_active=True, is_global=True)
+                .all()
+            ]
+
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
         with get_db() as db:
......
@@ -58,6 +58,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
     # replace filename with uuid
     id = str(uuid.uuid4())
+    name = filename
     filename = f"{id}_{filename}"
 
     file_path = f"{UPLOAD_DIR}/{filename}"

@@ -73,6 +74,7 @@ def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
                 "id": id,
                 "filename": filename,
                 "meta": {
+                    "name": name,
                     "content_type": file.content_type,
                     "size": len(contents),
                     "path": file_path,
......
@@ -79,6 +79,8 @@ def load_function_module_by_id(function_id):
             return module.Pipe(), "pipe", frontmatter
         elif hasattr(module, "Filter"):
             return module.Filter(), "filter", frontmatter
+        elif hasattr(module, "Action"):
+            return module.Action(), "action", frontmatter
         else:
             raise Exception("No Function class found")
     except Exception as e:
......
@@ -95,8 +95,8 @@ class TASKS(str, Enum):
     def __str__(self) -> str:
         return super().__str__()
 
-    DEFAULT = lambda task="": f"{task if task else 'default'}"
-    TITLE_GENERATION = "Title Generation"
-    EMOJI_GENERATION = "Emoji Generation"
-    QUERY_GENERATION = "Query Generation"
-    FUNCTION_CALLING = "Function Calling"
+    DEFAULT = lambda task="": f"{task if task else 'generation'}"
+    TITLE_GENERATION = "title_generation"
+    EMOJI_GENERATION = "emoji_generation"
+    QUERY_GENERATION = "query_generation"
+    FUNCTION_CALLING = "function_calling"
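Because `TASKS` mixes in `str` and overrides `__str__` to defer to `str.__str__`, `str(member)` yields the bare snake_case value, which is what the `"task": str(TASKS.TITLE_GENERATION)` call sites later in this diff rely on. A minimal check of that behavior:

```python
from enum import Enum


class TASKS(str, Enum):
    def __str__(self) -> str:
        return super().__str__()  # str.__str__ -> the raw member value

    TITLE_GENERATION = "title_generation"


# Without the override, str() would give "TASKS.TITLE_GENERATION".
assert str(TASKS.TITLE_GENERATION) == "title_generation"
```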
 {
     "version": 0,
     "ui": {
-        "default_locale": "en-US",
+        "default_locale": "",
         "prompt_suggestions": [
             {
                 "title": ["Help me study", "vocabulary for a college entrance exam"],
......
@@ -29,7 +29,7 @@ from starlette.middleware.sessions import SessionMiddleware
 from starlette.responses import StreamingResponse, Response, RedirectResponse
 
-from apps.socket.main import sio, app as socket_app
+from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call
 from apps.ollama.main import (
     app as ollama_app,
     get_all_models as get_ollama_models,
@@ -317,7 +317,7 @@ async def get_function_call_response(
             {"role": "user", "content": f"Query: {prompt}"},
         ],
         "stream": False,
-        "task": TASKS.FUNCTION_CALLING,
+        "task": str(TASKS.FUNCTION_CALLING),
     }
 
     try:
@@ -618,6 +618,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 content={"detail": str(e)},
             )
 
+        # Extract valves from the request body
+        valves = None
+        if "valves" in body:
+            valves = body["valves"]
+            del body["valves"]
+
         # Extract session_id, chat_id and message_id from the request body
         session_id = None
         if "session_id" in body:
@@ -632,24 +638,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             message_id = body["id"]
             del body["id"]
 
-        async def __event_emitter__(data):
-            await sio.emit(
-                "chat-events",
-                {
-                    "chat_id": chat_id,
-                    "message_id": message_id,
-                    "data": data,
-                },
-                to=session_id,
-            )
-
-        async def __event_call__(data):
-            response = await sio.call(
-                "chat-events",
-                {"chat_id": chat_id, "message_id": message_id, "data": data},
-                to=session_id,
-            )
-            return response
+        __event_emitter__ = await get_event_emitter(
+            {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
+        )
+        __event_call__ = await get_event_call(
+            {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
+        )
 
         # Initialize data_items to store additional data to be sent to the client
         data_items = []
@@ -703,6 +697,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         if len(citations) > 0:
             data_items.append({"citations": citations})
 
+        body["metadata"] = {
+            "session_id": session_id,
+            "chat_id": chat_id,
+            "message_id": message_id,
+            "valves": valves,
+        }
+
         modified_body_bytes = json.dumps(body).encode("utf-8")
         # Replace the request body with the modified one
         request._body = modified_body_bytes
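After this middleware runs, the body forwarded downstream carries a single `metadata` object assembled from the top-level fields it stripped (`session_id`, `id`, `chat_id`, `valves`), which the Ollama, OpenAI, and pipe backends shown earlier all pop back off before building their provider payloads. Illustrative shape of the rewritten body (values are placeholders):

```python
body = {
    "model": "llama3",
    "messages": [{"role": "user", "content": "hi"}],
    "metadata": {
        "session_id": "socket-session-id",
        "chat_id": "chat-1",
        "message_id": "msg-1",
        "valves": None,  # per-request function valves, if the client sent any
    },
}
```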
@@ -823,9 +824,6 @@ def filter_pipeline(payload, user):
                 if "detail" in res:
                     raise Exception(r.status_code, res["detail"])
 
-    if "pipeline" not in app.state.MODELS[model_id] and "task" in payload:
-        del payload["task"]
-
     return payload
@@ -935,6 +933,7 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 async def get_all_models():
+    # TODO: Optimize this function
     pipe_models = []
     openai_models = []
     ollama_models = []
@@ -961,6 +960,14 @@ async def get_all_models():
     models = pipe_models + openai_models + ollama_models
 
+    global_action_ids = [
+        function.id for function in Functions.get_global_action_functions()
+    ]
+    enabled_action_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("action", active_only=True)
+    ]
+
     custom_models = Models.get_all_models()
     for custom_model in custom_models:
         if custom_model.base_model_id == None:
@@ -971,9 +978,33 @@ async def get_all_models():
                 ):
                     model["name"] = custom_model.name
                     model["info"] = custom_model.model_dump()
+
+                    action_ids = [] + global_action_ids
+                    if "info" in model and "meta" in model["info"]:
+                        action_ids.extend(model["info"]["meta"].get("actionIds", []))
+                    action_ids = list(set(action_ids))
+
+                    action_ids = [
+                        action_id
+                        for action_id in action_ids
+                        if action_id in enabled_action_ids
+                    ]
+
+                    model["actions"] = []
+                    for action_id in action_ids:
+                        action = Functions.get_function_by_id(action_id)
+                        model["actions"].append(
+                            {
+                                "id": action_id,
+                                "name": action.name,
+                                "description": action.meta.description,
+                                "icon_url": action.meta.manifest.get("icon_url", None),
+                            }
+                        )
         else:
             owned_by = "openai"
             pipe = None
+            actions = []
 
             for model in models:
                 if (
@@ -983,6 +1014,27 @@ async def get_all_models():
                     owned_by = model["owned_by"]
                     if "pipe" in model:
                         pipe = model["pipe"]
+
+                    action_ids = [] + global_action_ids
+                    if "info" in model and "meta" in model["info"]:
+                        action_ids.extend(model["info"]["meta"].get("actionIds", []))
+                    action_ids = list(set(action_ids))
+
+                    action_ids = [
+                        action_id
+                        for action_id in action_ids
+                        if action_id in enabled_action_ids
+                    ]
+
+                    actions = [
+                        {
+                            "id": action_id,
+                            "name": Functions.get_function_by_id(action_id).name,
+                            "description": Functions.get_function_by_id(
+                                action_id
+                            ).meta.description,
+                        }
+                        for action_id in action_ids
+                    ]
+
                     break
 
             models.append(
@@ -995,6 +1047,7 @@ async def get_all_models():
                     "info": custom_model.model_dump(),
                     "preset": True,
                     **({"pipe": pipe} if pipe is not None else {}),
+                    "actions": actions,
                 }
             )
@@ -1036,13 +1089,24 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
         )
     model = app.state.MODELS[model_id]
 
-    pipe = model.get("pipe")
-    if pipe:
+    # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
+    task = None
+    if "task" in form_data:
+        task = form_data["task"]
+        del form_data["task"]
+
+    if task:
+        if "metadata" in form_data:
+            form_data["metadata"]["task"] = task
+        else:
+            form_data["metadata"] = {"task": task}
+
+    if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
+        print("generate_ollama_chat_completion")
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
         return await generate_openai_chat_completion(form_data, user=user)
@@ -1107,24 +1171,21 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     else:
         pass
 
-    async def __event_emitter__(event_data):
-        await sio.emit(
-            "chat-events",
-            {
-                "chat_id": data["chat_id"],
-                "message_id": data["id"],
-                "data": event_data,
-            },
-            to=data["session_id"],
-        )
-
-    async def __event_call__(event_data):
-        response = await sio.call(
-            "chat-events",
-            {"chat_id": data["chat_id"], "message_id": data["id"], "data": event_data},
-            to=data["session_id"],
-        )
-        return response
+    __event_emitter__ = await get_event_emitter(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
+
+    __event_call__ = await get_event_call(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
 
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
@@ -1222,6 +1283,107 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
     return data
 
+
+@app.post("/api/chat/actions/{action_id}")
+async def chat_completed(
+    action_id: str, form_data: dict, user=Depends(get_verified_user)
+):
+    action = Functions.get_function_by_id(action_id)
+    if not action:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Action not found",
+        )
+
+    data = form_data
+    model_id = data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+    model = app.state.MODELS[model_id]
+
+    __event_emitter__ = await get_event_emitter(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
+    __event_call__ = await get_event_call(
+        {
+            "chat_id": data["chat_id"],
+            "message_id": data["id"],
+            "session_id": data["session_id"],
+        }
+    )
+
+    if action_id in webui_app.state.FUNCTIONS:
+        function_module = webui_app.state.FUNCTIONS[action_id]
+    else:
+        function_module, _, _ = load_function_module_by_id(action_id)
+        webui_app.state.FUNCTIONS[action_id] = function_module
+
+    if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
+        valves = Functions.get_function_valves_by_id(action_id)
+        function_module.valves = function_module.Valves(**(valves if valves else {}))
+
+    if hasattr(function_module, "action"):
+        try:
+            action = function_module.action
+
+            # Get the signature of the function
+            sig = inspect.signature(action)
+
+            params = {"body": data}
+
+            # Extra parameters to be passed to the function
+            extra_params = {
+                "__model__": model,
+                "__id__": action_id,
+                "__event_emitter__": __event_emitter__,
+                "__event_call__": __event_call__,
+            }
+
+            # Add extra params that are contained in the function signature
+            for key, value in extra_params.items():
+                if key in sig.parameters:
+                    params[key] = value
+
+            if "__user__" in sig.parameters:
+                __user__ = {
+                    "id": user.id,
+                    "email": user.email,
+                    "name": user.name,
+                    "role": user.role,
+                }
+
+                try:
+                    if hasattr(function_module, "UserValves"):
+                        __user__["valves"] = function_module.UserValves(
+                            **Functions.get_user_valves_by_id_and_user_id(
+                                action_id, user.id
+                            )
+                        )
+                except Exception as e:
+                    print(e)
+
+                params = {**params, "__user__": __user__}
+
+            if inspect.iscoroutinefunction(action):
+                data = await action(**params)
+            else:
+                data = action(**params)
+
+        except Exception as e:
+            print(f"Error: {e}")
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
+
+    return data
+
+
 ##################################
 #
 # Task Endpoints
@@ -1314,7 +1476,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "max_tokens": 50,
         "chat_id": form_data.get("chat_id", None),
-        "task": TASKS.TITLE_GENERATION,
+        "task": str(TASKS.TITLE_GENERATION),
     }
 
     log.debug(payload)
@@ -1367,7 +1529,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "max_tokens": 30,
-        "task": TASKS.QUERY_GENERATION,
+        "task": str(TASKS.QUERY_GENERATION),
     }
 
     print(payload)
@@ -1424,7 +1586,7 @@ Message: """{{prompt}}"""
         "stream": False,
         "max_tokens": 4,
         "chat_id": form_data.get("chat_id", None),
-        "task": TASKS.EMOJI_GENERATION,
+        "task": str(TASKS.EMOJI_GENERATION),
     }
 
     log.debug(payload)
......
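The router change in `generate_chat_completions` above is the hop that connects the new `str(TASKS.*)` call sites to `__task__` in pipes: `task` is popped from the top level and tucked into `metadata`, so provider payloads stay clean while backends that understand metadata can still read it. The move in isolation:

```python
# Illustrative values; mirrors the task-to-metadata hop in the router above.
form_data = {"model": "llama3", "messages": [], "task": "title_generation"}

task = form_data.pop("task", None)
if task:
    form_data.setdefault("metadata", {})["task"] = task

assert form_data == {
    "model": "llama3",
    "messages": [],
    "metadata": {"task": "title_generation"},
}
```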
@@ -27,7 +27,7 @@ config = context.config
 # Interpret the config file for Python logging.
 # This line sets up loggers basically.
 if config.config_file_name is not None:
-    fileConfig(config.config_file_name)
+    fileConfig(config.config_file_name, disable_existing_loggers=False)
 
 # add your model's MetaData object here
 # for 'autogenerate' support
......
 fastapi==0.111.0
 uvicorn[standard]==0.22.0
-pydantic==2.7.1
+pydantic==2.8.2
 python-multipart==0.0.9
 
 Flask==3.0.3

@@ -12,7 +12,7 @@ passlib[bcrypt]==1.7.4
 requests==2.32.3
 aiohttp==3.9.5
 
-sqlalchemy==2.0.30
+sqlalchemy==2.0.31
 alembic==1.13.2
 peewee==3.17.6
 peewee-migrate==1.12.2

@@ -38,12 +38,12 @@ langchain-community==0.2.6
 langchain-chroma==0.1.2
 
 fake-useragent==1.5.1
-chromadb==0.5.3
+chromadb==0.5.4
 sentence-transformers==3.0.1
 pypdf==4.2.0
 docx2txt==0.8
 python-pptx==0.6.23
-unstructured==0.14.9
+unstructured==0.14.10
 Markdown==3.6
 pypandoc==1.13
 pandas==2.2.2

@@ -71,7 +71,7 @@ pytube==15.0.0
 extract_msg
 pydub
-duckduckgo-search~=6.1.7
+duckduckgo-search~=6.1.12
 
 ## Tests
 docker~=7.1.0
......
 {
 	"name": "open-webui",
-	"version": "0.3.8",
+	"version": "0.3.9",
 	"lockfileVersion": 3,
 	"requires": true,
 	"packages": {
 		"": {
 			"name": "open-webui",
-			"version": "0.3.8",
+			"version": "0.3.9",
 			"dependencies": {
 				"@codemirror/lang-javascript": "^6.2.2",
 				"@codemirror/lang-python": "^6.1.6",
......
 {
 	"name": "open-webui",
-	"version": "0.3.8",
+	"version": "0.3.9",
 	"private": true,
 	"scripts": {
 		"dev": "npm run pyodide:fetch && vite dev --host",
......
@@ -104,6 +104,45 @@ export const chatCompleted = async (token: string, body: ChatCompletedForm) => {
 	return res;
 };
 
+type ChatActionForm = {
+	model: string;
+	messages: string[];
+	chat_id: string;
+};
+
+export const chatAction = async (token: string, action_id: string, body: ChatActionForm) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_BASE_URL}/api/chat/actions/${action_id}`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify(body)
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = err;
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
 export const getTaskConfig = async (token: string = '') => {
 	let error = null;
......
@@ -425,7 +425,7 @@ export const resetUploadDir = async (token: string) => {
 export const resetVectorDB = async (token: string) => {
 	let error = null;
 
-	const res = await fetch(`${RAG_API_BASE_URL}/reset`, {
+	const res = await fetch(`${RAG_API_BASE_URL}/reset/db`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
......
@@ -282,6 +282,7 @@
 				<SensitiveInput
 					placeholder={$i18n.t('Enter api auth string (e.g. username:password)')}
 					bind:value={AUTOMATIC1111_API_AUTH}
+					required={false}
 				/>
 				<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">
......