Commit 1c4e7f03 authored by Timothy J. Baek's avatar Timothy J. Baek
Browse files

refac

parent 74a4f642
from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from apps.webui.routers import (
auths,
......@@ -17,6 +18,7 @@ from apps.webui.routers import (
)
from apps.webui.models.functions import Functions
from apps.webui.utils import load_function_module_by_id
from utils.misc import stream_message_template
from config import (
WEBUI_BUILD_HASH,
......@@ -37,6 +39,14 @@ from config import (
AppConfig,
)
import inspect
import uuid
import time
import json
from typing import Iterator, Generator
from pydantic import BaseModel
app = FastAPI()
origins = ["*"]
......@@ -166,3 +176,152 @@ async def get_pipe_models():
)
return pipe_models
async def generate_function_chat_completion(form_data, user):
async def job():
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id)
# Check if function is already loaded
if pipe_id not in app.state.FUNCTIONS:
function_module, function_type, frontmatter = load_function_module_by_id(
pipe_id
)
app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
pipe = function_module.pipe
# Get the signature of the function
sig = inspect.signature(pipe)
params = {"body": form_data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if form_data["stream"]:
async def stream_content():
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
except Exception as e:
print(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str):
message = stream_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
except Exception as e:
print(f"Error: {e}")
return {"error": {"detail": str(e)}}
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
if isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
if isinstance(res, str):
message = res
if isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await job()
......@@ -43,7 +43,11 @@ from apps.openai.main import (
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.webui.main import app as webui_app, get_pipe_models
from apps.webui.main import (
app as webui_app,
get_pipe_models,
generate_function_chat_completion,
)
from pydantic import BaseModel
......@@ -228,10 +232,7 @@ async def get_function_call_response(
response = None
try:
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(payload, user=user)
else:
response = await generate_openai_chat_completion(payload, user=user)
response = await generate_chat_completions(form_data=payload, user=user)
content = None
......@@ -900,159 +901,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
pipe = model.get("pipe")
if pipe:
async def job():
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id)
# Check if function is already loaded
if pipe_id not in webui_app.state.FUNCTIONS:
function_module, function_type, frontmatter = (
load_function_module_by_id(pipe_id)
)
webui_app.state.FUNCTIONS[pipe_id] = function_module
else:
function_module = webui_app.state.FUNCTIONS[pipe_id]
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(pipe_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
pipe = function_module.pipe
# Get the signature of the function
sig = inspect.signature(pipe)
params = {"body": form_data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
pipe_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if form_data["stream"]:
async def stream_content():
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
except Exception as e:
print(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
return
if isinstance(res, str):
message = stream_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(
stream_content(), media_type="text/event-stream"
)
else:
try:
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
except Exception as e:
print(f"Error: {e}")
return {"error": {"detail": str(e)}}
if inspect.iscoroutinefunction(pipe):
res = await pipe(**params)
else:
res = pipe(**params)
if isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
if isinstance(res, str):
message = res
if isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await job()
return await generate_function_chat_completion(form_data, user=user)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)
else:
......@@ -1334,10 +1183,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(payload, user=user)
else:
return await generate_openai_chat_completion(payload, user=user)
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/query/completions")
......@@ -1397,10 +1243,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(payload, user=user)
else:
return await generate_openai_chat_completion(payload, user=user)
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/emoji/completions")
......@@ -1464,10 +1307,7 @@ Message: """{{prompt}}"""
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(payload, user=user)
else:
return await generate_openai_chat_completion(payload, user=user)
return await generate_chat_completions(form_data=payload, user=user)
@app.post("/api/task/tools/completions")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment