Unverified Commit a9a6ed8b authored by Timothy Jaeryang Baek, committed by GitHub

Merge pull request #4237 from michaelpoluektov/refactor-webui-main

refactor: Simplify functions
parents 64b41655 e6c64282
@@ -52,7 +52,6 @@ async def user_join(sid, data):
user = Users.get_user_by_id(data["id"])
if user:
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
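
The hunk above registers a socket session against its user. A minimal sketch of that bookkeeping, assuming USER_POOL maps user ids to lists of session ids; the else branch that creates a fresh list is an assumption, since it falls outside the visible lines:

    SESSION_POOL = {}   # sid -> user id
    USER_POOL = {}      # user id -> list of sids for that user

    def register_session(sid, user_id):
        SESSION_POOL[sid] = user_id
        if user_id in USER_POOL:
            USER_POOL[user_id].append(sid)
        else:
            USER_POOL[user_id] = [sid]
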
@@ -80,7 +79,6 @@ def get_models_in_use():
@sio.on("usage")
async def usage(sid, data):
model_id = data["model"]
# Cancel previous callback if there is one
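
Only the comment survives in the visible hunk. A hypothetical sketch of the cancel-and-reschedule idea it describes (the names, task structure and timeout are assumptions, not code from this PR):

    import asyncio

    USAGE_POOL = {}   # model_id -> pending cleanup task

    async def track_usage(model_id, idle_timeout=120):
        previous = USAGE_POOL.get(model_id)
        if previous is not None:
            previous.cancel()                      # reset the idle timer for this model

        async def expire():
            await asyncio.sleep(idle_timeout)
            USAGE_POOL.pop(model_id, None)         # mark the model as no longer in use

        USAGE_POOL[model_id] = asyncio.create_task(expire())
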
@@ -139,7 +137,7 @@ async def disconnect(sid):
print(f"Unknown session ID {sid} disconnected")
async def get_event_emitter(request_info):
def get_event_emitter(request_info):
async def __event_emitter__(event_data):
await sio.emit(
"chat-events",
@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
return __event_emitter__
async def get_event_call(request_info):
def get_event_call(request_info):
async def __event_call__(event_data):
response = await sio.call(
"chat-events",
......
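
The core of this file's change: get_event_emitter and get_event_call drop their async keyword because they never await anything themselves; they only build and return async closures. A minimal sketch of the new shape, assuming a python-socketio AsyncServer named sio and a payload shape inferred from the request_info keys used elsewhere in this PR:

    import socketio

    sio = socketio.AsyncServer(async_mode="asgi")   # stand-in for the app's shared server

    def get_event_emitter(request_info):
        async def __event_emitter__(event_data):
            # Payload shape is an assumption inferred from the metadata keys in this PR.
            await sio.emit(
                "chat-events",
                {
                    "chat_id": request_info["chat_id"],
                    "message_id": request_info["message_id"],
                    "data": event_data,
                },
                to=request_info["session_id"],
            )

        return __event_emitter__

Callers change accordingly: get_event_emitter(info) instead of await get_event_emitter(info); the awaiting happens when the returned closure is invoked.
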
This diff is collapsed.
import json
import logging
from typing import Optional
from typing import Optional, List
from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger, Text
from sqlalchemy import Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db
from typing import List, Union, Optional
from config import SRC_LOG_LEVELS
import time
@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
class ModelsTable:
def insert_new_model(
self, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]:
@@ -126,9 +123,7 @@ class ModelsTable:
}
)
try:
with get_db() as db:
result = Model(**model.model_dump())
db.add(result)
db.commit()
@@ -144,13 +139,11 @@ class ModelsTable:
def get_all_models(self) -> List[ModelModel]:
with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
with get_db() as db:
model = db.get(Model, id)
return ModelModel.model_validate(model)
except:
@@ -178,7 +171,6 @@ class ModelsTable:
def delete_model_by_id(self, id: str) -> bool:
try:
with get_db() as db:
db.query(Model).filter_by(id=id).delete()
db.commit()
......
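
Every ModelsTable method above leans on get_db from apps.webui.internal.db, which this diff does not show. A common shape for such a helper is a context manager that yields a SQLAlchemy session and closes it afterwards; the sketch below is an assumption about that shape, not the project's actual implementation:

    from contextlib import contextmanager
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    engine = create_engine("sqlite:///webui.db")          # placeholder URL
    SessionLocal = sessionmaker(bind=engine, autoflush=False)

    @contextmanager
    def get_db():
        db = SessionLocal()
        try:
            yield db       # used as `with get_db() as db:` in the methods above
        finally:
            db.close()
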
@@ -13,8 +13,6 @@ import aiohttp
import requests
import mimetypes
import shutil
import os
import uuid
import inspect
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call
from apps.socket.main import app as socket_app, get_event_emitter, get_event_call
from apps.ollama.main import (
app as ollama_app,
get_all_models as get_ollama_models,
@@ -619,32 +617,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
content={"detail": str(e)},
)
# Extract valves from the request body
valves = None
if "valves" in body:
valves = body["valves"]
del body["valves"]
# Extract session_id, chat_id and message_id from the request body
session_id = None
if "session_id" in body:
session_id = body["session_id"]
del body["session_id"]
chat_id = None
if "chat_id" in body:
chat_id = body["chat_id"]
del body["chat_id"]
message_id = None
if "id" in body:
message_id = body["id"]
del body["id"]
__event_emitter__ = await get_event_emitter(
{"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
)
__event_call__ = await get_event_call(
{"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
)
metadata = {
"chat_id": body.pop("chat_id", None),
"message_id": body.pop("id", None),
"session_id": body.pop("session_id", None),
"valves": body.pop("valves", None),
}
__event_emitter__ = get_event_emitter(metadata)
__event_call__ = get_event_call(metadata)
# Initialize data_items to store additional data to be sent to the client
data_items = []
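
The refactor replaces four `if key in body: ... del body[key]` blocks with dict.pop calls collected into a single metadata dict. pop both reads and removes the key, and the None default means a missing key no longer needs its own branch:

    body = {"model": "llama3", "chat_id": "c1", "id": "m1", "session_id": "s1"}

    metadata = {
        "chat_id": body.pop("chat_id", None),
        "message_id": body.pop("id", None),
        "session_id": body.pop("session_id", None),
        "valves": body.pop("valves", None),   # absent key -> None, no KeyError
    }

    assert body == {"model": "llama3"}        # UI-only fields are stripped from the body
    assert metadata["valves"] is None
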
@@ -709,13 +690,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if len(citations) > 0:
data_items.append({"citations": citations})
body["metadata"] = {
"session_id": session_id,
"chat_id": chat_id,
"message_id": message_id,
"valves": valves,
}
body["metadata"] = metadata
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
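
For context, a hedged sketch of the surrounding body-rewrite step in a Starlette BaseHTTPMiddleware. Overwriting the private request._body attribute works because Request.body() caches its result there; keeping Content-Length consistent is assumed to be handled outside this hunk:

    import json
    from starlette.middleware.base import BaseHTTPMiddleware
    from starlette.requests import Request

    class RewriteBodyMiddleware(BaseHTTPMiddleware):
        async def dispatch(self, request: Request, call_next):
            body = json.loads(await request.body() or b"{}")
            body["metadata"] = {"example": True}      # illustrative only
            modified_body_bytes = json.dumps(body).encode("utf-8")
            request._body = modified_body_bytes       # downstream .body()/.json() now see this
            return await call_next(request)
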
@@ -1191,13 +1166,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
status_code=r.status_code,
content=res,
)
except:
except Exception:
pass
else:
pass
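
The switch from a bare except to except Exception is deliberate: the bare form also catches BaseException subclasses such as KeyboardInterrupt, SystemExit and asyncio.CancelledError, which should normally propagate. A quick illustration:

    import asyncio

    try:
        raise asyncio.CancelledError()
    except Exception:
        print("not reached")       # CancelledError derives from BaseException, not Exception
    except BaseException:
        print("propagates past `except Exception:` as intended")
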
__event_emitter__ = await get_event_emitter(
__event_emitter__ = get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
@@ -1205,7 +1180,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
}
)
__event_call__ = await get_event_call(
__event_call__ = get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
@@ -1334,14 +1309,14 @@ async def chat_completed(
)
model = app.state.MODELS[model_id]
__event_emitter__ = await get_event_emitter(
__event_emitter__ = get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
__event_call__ = await get_event_call(
__event_call__ = get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
......
from pathlib import Path
import hashlib
import json
import re
from datetime import timedelta
from typing import Optional, List, Tuple
@@ -8,37 +7,39 @@ import uuid
import time
def get_last_user_message_item(messages: List[dict]) -> str:
def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
for message in reversed(messages):
if message["role"] == "user":
return message
return None
def get_last_user_message(messages: List[dict]) -> str:
message = get_last_user_message_item(messages)
if message is not None:
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
def get_content_from_message(message: dict) -> Optional[str]:
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
else:
return message["content"]
return None
def get_last_assistant_message(messages: List[dict]) -> str:
def get_last_user_message(messages: List[dict]) -> Optional[str]:
message = get_last_user_message_item(messages)
if message is None:
return None
return get_content_from_message(message)
def get_last_assistant_message(messages: List[dict]) -> Optional[str]:
for message in reversed(messages):
if message["role"] == "assistant":
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
return message["content"]
return get_content_from_message(message)
return None
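
A usage sketch for the refactored helpers, using the functions defined above: content may be a plain string or a list of typed parts (as in vision-style OpenAI messages), and both shapes resolve to their text through get_content_from_message:

    messages = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": [
            {"type": "text", "text": "Describe this"},
            {"type": "image_url", "image_url": {"url": "..."}},
        ]},
        {"role": "assistant", "content": "It is a cat."},
    ]

    assert get_last_user_message(messages) == "Describe this"
    assert get_last_assistant_message(messages) == "It is a cat."
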
def get_system_message(messages: List[dict]) -> dict:
def get_system_message(messages: List[dict]) -> Optional[dict]:
for message in messages:
if message["role"] == "system":
return message
@@ -49,7 +50,7 @@ def remove_system_message(messages: List[dict]) -> List[dict]:
return [message for message in messages if message["role"] != "system"]
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
def pop_system_message(messages: List[dict]) -> Tuple[Optional[dict], List[dict]]:
return get_system_message(messages), remove_system_message(messages)
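
pop_system_message stays a thin composition of the two helpers above; the Optional in its return type covers message lists that have no system prompt. Continuing the sketch from the previous example:

    system, rest = pop_system_message(messages)
    assert system == {"role": "system", "content": "You are helpful."}
    assert all(m["role"] != "system" for m in rest)
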
@@ -87,23 +88,29 @@ def add_or_update_system_message(content: str, messages: List[dict]):
return messages
def stream_message_template(model: str, message: str):
def openai_chat_message_template(model: str):
return {
"id": f"{model}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": message},
"logprobs": None,
"finish_reason": None,
}
],
"choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
}
def openai_chat_chunk_message_template(model: str, message: str):
template = openai_chat_message_template(model)
template["object"] = "chat.completion.chunk"
template["choices"][0]["delta"] = {"content": message}
return template
def openai_chat_completion_message_template(model: str, message: str):
template = openai_chat_message_template(model)
template["object"] = "chat.completion"
template["choices"][0]["message"] = {"content": message, "role": "assistant"}
template["choices"][0]["finish_reason"] = "stop"
return template
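
The shared base template plus these two thin specializations replace the old stream_message_template. A sketch of how each is typically consumed; the SSE "data:" framing is an assumption about the caller, not part of this diff:

    import json

    def format_sse_chunk(model, token):
        # Streaming path: one delta per token/fragment.
        chunk = openai_chat_chunk_message_template(model, token)
        return f"data: {json.dumps(chunk)}\n\n"

    def format_full_response(model, text):
        # Non-streaming path: a single assistant message with finish_reason "stop".
        return openai_chat_completion_message_template(model, text)
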
def get_gravatar_url(email):
# Trim leading and trailing whitespace from
# an email address and force all characters
@@ -174,7 +181,7 @@ def extract_folders_after_data_docs(path):
tags = []
folders = parts[index_docs:-1]
for idx, part in enumerate(folders):
for idx, _ in enumerate(folders):
tags.append("/".join(folders[: idx + 1]))
return tags
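
The loop builds cumulative folder tags, so only the index matters, which is why part becomes the throwaway _. For example:

    folders = ["a", "b", "c"]
    tags = ["/".join(folders[: idx + 1]) for idx, _ in enumerate(folders)]
    assert tags == ["a", "a/b", "a/b/c"]
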
@@ -270,11 +277,11 @@ def parse_ollama_modelfile(model_text):
value = param_match.group(1)
try:
if param_type == int:
if param_type is int:
value = int(value)
elif param_type == float:
elif param_type is float:
value = float(value)
elif param_type == bool:
elif param_type is bool:
value = value.lower() == "true"
except Exception as e:
print(e)
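
Comparing types with `is` checks identity against the type object itself, which is what linters recommend (E721) over `==`: equality can be overridden by a metaclass, identity cannot. A contrived illustration:

    param_type = int
    assert (param_type is int) and (param_type == int)

    class Weird(type):
        def __eq__(cls, other):          # metaclass that makes == meaningless
            return True
        __hash__ = type.__hash__

    class Fake(metaclass=Weird):
        pass

    assert Fake == int                   # equality says yes
    assert not (Fake is int)             # identity says no
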
......