Commit 3978efd7 authored by Michael Poluektov's avatar Michael Poluektov
Browse files

refac: Refactor functions

parent 9d58bb1c
...@@ -52,7 +52,6 @@ async def user_join(sid, data): ...@@ -52,7 +52,6 @@ async def user_join(sid, data):
user = Users.get_user_by_id(data["id"]) user = Users.get_user_by_id(data["id"])
if user: if user:
SESSION_POOL[sid] = user.id SESSION_POOL[sid] = user.id
if user.id in USER_POOL: if user.id in USER_POOL:
USER_POOL[user.id].append(sid) USER_POOL[user.id].append(sid)
...@@ -80,7 +79,6 @@ def get_models_in_use(): ...@@ -80,7 +79,6 @@ def get_models_in_use():
@sio.on("usage") @sio.on("usage")
async def usage(sid, data): async def usage(sid, data):
model_id = data["model"] model_id = data["model"]
# Cancel previous callback if there is one # Cancel previous callback if there is one
...@@ -139,7 +137,7 @@ async def disconnect(sid): ...@@ -139,7 +137,7 @@ async def disconnect(sid):
print(f"Unknown session ID {sid} disconnected") print(f"Unknown session ID {sid} disconnected")
async def get_event_emitter(request_info): def get_event_emitter(request_info):
async def __event_emitter__(event_data): async def __event_emitter__(event_data):
await sio.emit( await sio.emit(
"chat-events", "chat-events",
...@@ -154,7 +152,7 @@ async def get_event_emitter(request_info): ...@@ -154,7 +152,7 @@ async def get_event_emitter(request_info):
return __event_emitter__ return __event_emitter__
async def get_event_call(request_info): def get_event_call(request_info):
async def __event_call__(event_data): async def __event_call__(event_data):
response = await sio.call( response = await sio.call(
"chat-events", "chat-events",
......
This diff is collapsed.
import json
import logging import logging
from typing import Optional from typing import Optional, List
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger, Text from sqlalchemy import Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db from apps.webui.internal.db import Base, JSONField, get_db
from typing import List, Union, Optional
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
import time import time
...@@ -113,7 +111,6 @@ class ModelForm(BaseModel): ...@@ -113,7 +111,6 @@ class ModelForm(BaseModel):
class ModelsTable: class ModelsTable:
def insert_new_model( def insert_new_model(
self, form_data: ModelForm, user_id: str self, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]: ) -> Optional[ModelModel]:
...@@ -126,9 +123,7 @@ class ModelsTable: ...@@ -126,9 +123,7 @@ class ModelsTable:
} }
) )
try: try:
with get_db() as db: with get_db() as db:
result = Model(**model.model_dump()) result = Model(**model.model_dump())
db.add(result) db.add(result)
db.commit() db.commit()
...@@ -144,13 +139,11 @@ class ModelsTable: ...@@ -144,13 +139,11 @@ class ModelsTable:
def get_all_models(self) -> List[ModelModel]: def get_all_models(self) -> List[ModelModel]:
with get_db() as db: with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()] return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
with get_db() as db: with get_db() as db:
model = db.get(Model, id) model = db.get(Model, id)
return ModelModel.model_validate(model) return ModelModel.model_validate(model)
except: except:
...@@ -178,7 +171,6 @@ class ModelsTable: ...@@ -178,7 +171,6 @@ class ModelsTable:
def delete_model_by_id(self, id: str) -> bool: def delete_model_by_id(self, id: str) -> bool:
try: try:
with get_db() as db: with get_db() as db:
db.query(Model).filter_by(id=id).delete() db.query(Model).filter_by(id=id).delete()
db.commit() db.commit()
......
...@@ -13,8 +13,6 @@ import aiohttp ...@@ -13,8 +13,6 @@ import aiohttp
import requests import requests
import mimetypes import mimetypes
import shutil import shutil
import os
import uuid
import inspect import inspect
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
...@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware ...@@ -29,7 +27,7 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import sio, app as socket_app, get_event_emitter, get_event_call from apps.socket.main import app as socket_app, get_event_emitter, get_event_call
from apps.ollama.main import ( from apps.ollama.main import (
app as ollama_app, app as ollama_app,
get_all_models as get_ollama_models, get_all_models as get_ollama_models,
...@@ -639,10 +637,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -639,10 +637,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
message_id = body["id"] message_id = body["id"]
del body["id"] del body["id"]
__event_emitter__ = await get_event_emitter( __event_emitter__ = get_event_emitter(
{"chat_id": chat_id, "message_id": message_id, "session_id": session_id} {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
) )
__event_call__ = await get_event_call( __event_call__ = get_event_call(
{"chat_id": chat_id, "message_id": message_id, "session_id": session_id} {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
) )
...@@ -1191,13 +1189,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -1191,13 +1189,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
status_code=r.status_code, status_code=r.status_code,
content=res, content=res,
) )
except: except Exception:
pass pass
else: else:
pass pass
__event_emitter__ = await get_event_emitter( __event_emitter__ = get_event_emitter(
{ {
"chat_id": data["chat_id"], "chat_id": data["chat_id"],
"message_id": data["id"], "message_id": data["id"],
...@@ -1205,7 +1203,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)): ...@@ -1205,7 +1203,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
} }
) )
__event_call__ = await get_event_call( __event_call__ = get_event_call(
{ {
"chat_id": data["chat_id"], "chat_id": data["chat_id"],
"message_id": data["id"], "message_id": data["id"],
...@@ -1334,14 +1332,14 @@ async def chat_completed( ...@@ -1334,14 +1332,14 @@ async def chat_completed(
) )
model = app.state.MODELS[model_id] model = app.state.MODELS[model_id]
__event_emitter__ = await get_event_emitter( __event_emitter__ = get_event_emitter(
{ {
"chat_id": data["chat_id"], "chat_id": data["chat_id"],
"message_id": data["id"], "message_id": data["id"],
"session_id": data["session_id"], "session_id": data["session_id"],
} }
) )
__event_call__ = await get_event_call( __event_call__ = get_event_call(
{ {
"chat_id": data["chat_id"], "chat_id": data["chat_id"],
"message_id": data["id"], "message_id": data["id"],
...@@ -1770,7 +1768,6 @@ class AddPipelineForm(BaseModel): ...@@ -1770,7 +1768,6 @@ class AddPipelineForm(BaseModel):
@app.post("/api/pipelines/add") @app.post("/api/pipelines/add")
async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)): async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user)):
r = None r = None
try: try:
urlIdx = form_data.urlIdx urlIdx = form_data.urlIdx
...@@ -1813,7 +1810,6 @@ class DeletePipelineForm(BaseModel): ...@@ -1813,7 +1810,6 @@ class DeletePipelineForm(BaseModel):
@app.delete("/api/pipelines/delete") @app.delete("/api/pipelines/delete")
async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)): async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_user)):
r = None r = None
try: try:
urlIdx = form_data.urlIdx urlIdx = form_data.urlIdx
...@@ -1891,7 +1887,6 @@ async def get_pipeline_valves( ...@@ -1891,7 +1887,6 @@ async def get_pipeline_valves(
models = await get_all_models() models = await get_all_models()
r = None r = None
try: try:
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment