Unverified Commit 9bcd4ce5 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #3559 from open-webui/dev

0.3.8
parents 824966ad b38abf23
......@@ -2,13 +2,10 @@ import json
import logging
from typing import Optional
import peewee as pw
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import DB, JSONField
from apps.webui.internal.db import Base, JSONField, get_db
from typing import List, Union, Optional
from config import SRC_LOG_LEVELS
......@@ -32,7 +29,7 @@ class ModelParams(BaseModel):
# ModelMeta is a model for the data stored in the meta field of the Model table
class ModelMeta(BaseModel):
profile_image_url: Optional[str] = "/favicon.png"
profile_image_url: Optional[str] = "/static/favicon.png"
description: Optional[str] = None
"""
......@@ -46,38 +43,37 @@ class ModelMeta(BaseModel):
pass
class Model(pw.Model):
id = pw.TextField(unique=True)
class Model(Base):
__tablename__ = "model"
id = Column(Text, primary_key=True)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
user_id = pw.TextField()
user_id = Column(Text)
base_model_id = pw.TextField(null=True)
base_model_id = Column(Text, nullable=True)
"""
An optional pointer to the actual model that should be used when proxying requests.
"""
name = pw.TextField()
name = Column(Text)
"""
The human-readable display name of the model.
"""
params = JSONField()
params = Column(JSONField)
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
meta = JSONField()
meta = Column(JSONField)
"""
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class ModelModel(BaseModel):
......@@ -92,6 +88,8 @@ class ModelModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
......@@ -115,12 +113,6 @@ class ModelForm(BaseModel):
class ModelsTable:
def __init__(
self,
db: pw.SqliteDatabase | pw.PostgresqlDatabase,
):
self.db = db
self.db.create_tables([Model])
def insert_new_model(
self, form_data: ModelForm, user_id: str
......@@ -134,34 +126,50 @@ class ModelsTable:
}
)
try:
result = Model.create(**model.model_dump())
if result:
return model
else:
return None
with get_db() as db:
result = Model(**model.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ModelModel.model_validate(result)
else:
return None
except Exception as e:
print(e)
return None
def get_all_models(self) -> List[ModelModel]:
return [ModelModel(**model_to_dict(model)) for model in Model.select()]
with get_db() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try:
model = Model.get(Model.id == id)
return ModelModel(**model_to_dict(model))
with get_db() as db:
model = db.get(Model, id)
return ModelModel.model_validate(model)
except:
return None
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try:
# update only the fields that are present in the model
query = Model.update(**model.model_dump()).where(Model.id == id)
query.execute()
model = Model.get(Model.id == id)
return ModelModel(**model_to_dict(model))
with get_db() as db:
# update only the fields that are present in the model
result = (
db.query(Model)
.filter_by(id=id)
.update(model.model_dump(exclude={"id"}, exclude_none=True))
)
db.commit()
model = db.get(Model, id)
db.refresh(model)
return ModelModel.model_validate(model)
except Exception as e:
print(e)
......@@ -169,11 +177,14 @@ class ModelsTable:
def delete_model_by_id(self, id: str) -> bool:
try:
query = Model.delete().where(Model.id == id)
query.execute()
return True
with get_db() as db:
db.query(Model).filter_by(id=id).delete()
db.commit()
return True
except:
return False
Models = ModelsTable(DB)
Models = ModelsTable()
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import DB
from apps.webui.internal.db import Base, get_db
import json
......@@ -16,15 +13,14 @@ import json
####################
class Prompt(Model):
command = CharField(unique=True)
user_id = CharField()
title = TextField()
content = TextField()
timestamp = BigIntegerField()
class Prompt(Base):
__tablename__ = "prompt"
class Meta:
database = DB
command = Column(String, primary_key=True)
user_id = Column(String)
title = Column(Text)
content = Column(Text)
timestamp = Column(BigInteger)
class PromptModel(BaseModel):
......@@ -34,6 +30,8 @@ class PromptModel(BaseModel):
content: str
timestamp: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
......@@ -48,10 +46,6 @@ class PromptForm(BaseModel):
class PromptsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt(
self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]:
......@@ -66,53 +60,60 @@ class PromptsTable:
)
try:
result = Prompt.create(**prompt.model_dump())
if result:
return prompt
else:
return None
except:
with get_db() as db:
result = Prompt(**prompt.dict())
db.add(result)
db.commit()
db.refresh(result)
if result:
return PromptModel.model_validate(result)
else:
return None
except Exception as e:
return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try:
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first()
return PromptModel.model_validate(prompt)
except:
return None
def get_prompts(self) -> List[PromptModel]:
return [
PromptModel(**model_to_dict(prompt))
for prompt in Prompt.select()
# .limit(limit).offset(skip)
]
with get_db() as db:
return [
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm
) -> Optional[PromptModel]:
try:
query = Prompt.update(
title=form_data.title,
content=form_data.content,
timestamp=int(time.time()),
).where(Prompt.command == command)
query.execute()
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
with get_db() as db:
prompt = db.query(Prompt).filter_by(command=command).first()
prompt.title = form_data.title
prompt.content = form_data.content
prompt.timestamp = int(time.time())
db.commit()
return PromptModel.model_validate(prompt)
except:
return None
def delete_prompt_by_command(self, command: str) -> bool:
try:
query = Prompt.delete().where((Prompt.command == command))
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
db.query(Prompt).filter_by(command=command).delete()
db.commit()
return True
return True
except:
return False
Prompts = PromptsTable(DB)
Prompts = PromptsTable()
from pydantic import BaseModel
from typing import List, Union, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
import json
import uuid
import time
import logging
from apps.webui.internal.db import DB
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS
......@@ -20,25 +20,23 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Tag(Model):
id = CharField(unique=True)
name = CharField()
user_id = CharField()
data = TextField(null=True)
class Tag(Base):
__tablename__ = "tag"
class Meta:
database = DB
id = Column(String, primary_key=True)
name = Column(String)
user_id = Column(String)
data = Column(Text, nullable=True)
class ChatIdTag(Model):
id = CharField(unique=True)
tag_name = CharField()
chat_id = CharField()
user_id = CharField()
timestamp = BigIntegerField()
class ChatIdTag(Base):
__tablename__ = "chatidtag"
class Meta:
database = DB
id = Column(String, primary_key=True)
tag_name = Column(String)
chat_id = Column(String)
user_id = Column(String)
timestamp = Column(BigInteger)
class TagModel(BaseModel):
......@@ -47,6 +45,8 @@ class TagModel(BaseModel):
user_id: str
data: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
class ChatIdTagModel(BaseModel):
id: str
......@@ -55,6 +55,8 @@ class ChatIdTagModel(BaseModel):
user_id: str
timestamp: int
model_config = ConfigDict(from_attributes=True)
####################
# Forms
......@@ -75,28 +77,31 @@ class ChatTagsResponse(BaseModel):
class TagTable:
def __init__(self, db):
self.db = db
db.create_tables([Tag, ChatIdTag])
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
result = Tag.create(**tag.model_dump())
if result:
return tag
else:
with get_db() as db:
id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
result = Tag(**tag.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return TagModel.model_validate(result)
else:
return None
except Exception as e:
return None
except Exception as e:
return None
def get_tag_by_name_and_user_id(
self, name: str, user_id: str
) -> Optional[TagModel]:
try:
tag = Tag.get(Tag.name == name, Tag.user_id == user_id)
return TagModel(**model_to_dict(tag))
with get_db() as db:
tag = db.query(Tag).filter(name=name, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception as e:
return None
......@@ -118,82 +123,110 @@ class TagTable:
}
)
try:
result = ChatIdTag.create(**chatIdTag.model_dump())
if result:
return chatIdTag
else:
return None
with get_db() as db:
result = ChatIdTag(**chatIdTag.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ChatIdTagModel.model_validate(result)
else:
return None
except:
return None
def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select()
.where(ChatIdTag.user_id == user_id)
.order_by(ChatIdTag.timestamp.desc())
]
return [
TagModel(**model_to_dict(tag))
for tag in Tag.select()
.where(Tag.user_id == user_id)
.where(Tag.name.in_(tag_names))
]
with get_db() as db:
tag_names = [
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [
TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
]
def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str
) -> List[TagModel]:
tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select()
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id))
.order_by(ChatIdTag.timestamp.desc())
]
return [
TagModel(**model_to_dict(tag))
for tag in Tag.select()
.where(Tag.user_id == user_id)
.where(Tag.name.in_(tag_names))
]
with get_db() as db:
tag_names = [
chat_id_tag.tag_name
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, chat_id=chat_id)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
return [
TagModel.model_validate(tag)
for tag in (
db.query(Tag)
.filter_by(user_id=user_id)
.filter(Tag.name.in_(tag_names))
.all()
)
]
def get_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> Optional[ChatIdTagModel]:
return [
ChatIdTagModel(**model_to_dict(chat_id_tag))
for chat_id_tag in ChatIdTag.select()
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.tag_name == tag_name))
.order_by(ChatIdTag.timestamp.desc())
]
) -> List[ChatIdTagModel]:
with get_db() as db:
return [
ChatIdTagModel.model_validate(chat_id_tag)
for chat_id_tag in (
db.query(ChatIdTag)
.filter_by(user_id=user_id, tag_name=tag_name)
.order_by(ChatIdTag.timestamp.desc())
.all()
)
]
def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> int:
return (
ChatIdTag.select()
.where((ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id))
.count()
)
with get_db() as db:
return (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.count()
)
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try:
query = ChatIdTag.delete().where(
(ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
)
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0:
# Remove tag item from Tag col as well
query = Tag.delete().where(
(Tag.name == tag_name) & (Tag.user_id == user_id)
with get_db() as db:
res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.delete()
)
query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}")
db.commit()
return True
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
......@@ -202,23 +235,25 @@ class TagTable:
self, tag_name: str, chat_id: str, user_id: str
) -> bool:
try:
query = ChatIdTag.delete().where(
(ChatIdTag.tag_name == tag_name)
& (ChatIdTag.chat_id == chat_id)
& (ChatIdTag.user_id == user_id)
)
res = query.execute() # Remove the rows, return number of rows removed.
log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0:
# Remove tag item from Tag col as well
query = Tag.delete().where(
(Tag.name == tag_name) & (Tag.user_id == user_id)
with get_db() as db:
res = (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
.delete()
)
log.debug(f"res: {res}")
db.commit()
tag_count = self.count_chat_ids_by_tag_name_and_user_id(
tag_name, user_id
)
query.execute() # Remove the rows, return number of rows removed.
if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
db.commit()
return True
return True
except Exception as e:
log.error(f"delete_tag: {e}")
return False
......@@ -234,4 +269,4 @@ class TagTable:
return True
Tags = TagTable(DB)
Tags = TagTable()
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import Base, JSONField, get_db
from apps.webui.models.users import Users
import json
......@@ -21,19 +21,18 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Tool(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
valves = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Tool(Base):
__tablename__ = "tool"
class Meta:
database = DB
id = Column(String, primary_key=True)
user_id = Column(String)
name = Column(Text)
content = Column(Text)
specs = Column(JSONField)
meta = Column(JSONField)
valves = Column(JSONField)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class ToolMeta(BaseModel):
......@@ -51,6 +50,8 @@ class ToolModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
......@@ -78,61 +79,68 @@ class ToolValves(BaseModel):
class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool.create(**tool.model_dump())
if result:
return tool
else:
with get_db() as db:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool(**tool.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return ToolModel.model_validate(result)
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try:
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
with get_db() as db:
tool = db.get(Tool, id)
return ToolModel.model_validate(tool)
except:
return None
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
with get_db() as db:
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try:
tool = Tool.get(Tool.id == id)
return tool.valves if tool.valves else {}
with get_db() as db:
tool = db.get(Tool, id)
return tool.valves if tool.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try:
query = Tool.update(
**{"valves": valves},
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolValves(**model_to_dict(tool))
with get_db() as db:
db.query(Tool).filter_by(id=id).update(
{"valves": valves, "updated_at": int(time.time())}
)
db.commit()
return self.get_tool_by_id(id)
except:
return None
......@@ -141,7 +149,7 @@ class ToolsTable:
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings
if "tools" not in user_settings:
......@@ -159,7 +167,7 @@ class ToolsTable:
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings
if "tools" not in user_settings:
......@@ -170,8 +178,7 @@ class ToolsTable:
user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
Users.update_user_by_id(user_id, {"settings": user_settings})
return user_settings["tools"]["valves"][id]
except Exception as e:
......@@ -180,25 +187,27 @@ class ToolsTable:
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
**updated,
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
with get_db() as db:
db.query(Tool).filter_by(id=id).update(
{**updated, "updated_at": int(time.time())}
)
db.commit()
tool = db.query(Tool).get(id)
db.refresh(tool)
return ToolModel.model_validate(tool)
except:
return None
def delete_tool_by_id(self, id: str) -> bool:
try:
query = Tool.delete().where((Tool.id == id))
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
db.query(Tool).filter_by(id=id).delete()
db.commit()
return True
return True
except:
return False
Tools = ToolsTable(DB)
Tools = ToolsTable()
from pydantic import BaseModel, ConfigDict
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict, parse_obj_as
from typing import List, Union, Optional
import time
from sqlalchemy import String, Column, BigInteger, Text
from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB, JSONField
from apps.webui.internal.db import Base, JSONField, Session, get_db
from apps.webui.models.chats import Chats
####################
......@@ -13,25 +14,24 @@ from apps.webui.models.chats import Chats
####################
class User(Model):
id = CharField(unique=True)
name = CharField()
email = CharField()
role = CharField()
profile_image_url = TextField()
class User(Base):
__tablename__ = "user"
last_active_at = BigIntegerField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
id = Column(String, primary_key=True)
name = Column(String)
email = Column(String)
role = Column(String)
profile_image_url = Column(Text)
api_key = CharField(null=True, unique=True)
settings = JSONField(null=True)
info = JSONField(null=True)
last_active_at = Column(BigInteger)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
oauth_sub = TextField(null=True, unique=True)
api_key = Column(String, nullable=True, unique=True)
settings = Column(JSONField, nullable=True)
info = Column(JSONField, nullable=True)
class Meta:
database = DB
oauth_sub = Column(Text, unique=True)
class UserSettings(BaseModel):
......@@ -57,6 +57,8 @@ class UserModel(BaseModel):
oauth_sub: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
####################
# Forms
......@@ -76,9 +78,6 @@ class UserUpdateForm(BaseModel):
class UsersTable:
def __init__(self, db):
self.db = db
self.db.create_tables([User])
def insert_new_user(
self,
......@@ -89,77 +88,92 @@ class UsersTable:
role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
user = UserModel(
**{
"id": id,
"name": name,
"email": email,
"role": role,
"profile_image_url": profile_image_url,
"last_active_at": int(time.time()),
"created_at": int(time.time()),
"updated_at": int(time.time()),
"oauth_sub": oauth_sub,
}
)
result = User.create(**user.model_dump())
if result:
return user
else:
return None
with get_db() as db:
user = UserModel(
**{
"id": id,
"name": name,
"email": email,
"role": role,
"profile_image_url": profile_image_url,
"last_active_at": int(time.time()),
"created_at": int(time.time()),
"updated_at": int(time.time()),
"oauth_sub": oauth_sub,
}
)
result = User(**user.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return user
else:
return None
def get_user_by_id(self, id: str) -> Optional[UserModel]:
try:
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
with get_db() as db:
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except Exception as e:
return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try:
user = User.get(User.api_key == api_key)
return UserModel(**model_to_dict(user))
with get_db() as db:
user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user)
except:
return None
def get_user_by_email(self, email: str) -> Optional[UserModel]:
try:
user = User.get(User.email == email)
return UserModel(**model_to_dict(user))
with get_db() as db:
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user)
except:
return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try:
user = User.get(User.oauth_sub == sub)
return UserModel(**model_to_dict(user))
with get_db() as db:
user = db.query(User).filter_by(oauth_sub=sub).first()
return UserModel.model_validate(user)
except:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [
UserModel(**model_to_dict(user))
for user in User.select()
# .limit(limit).offset(skip)
]
with get_db() as db:
users = (
db.query(User)
# .offset(skip).limit(limit)
.all()
)
return [UserModel.model_validate(user) for user in users]
def get_num_users(self) -> Optional[int]:
return User.select().count()
with get_db() as db:
return db.query(User).count()
def get_first_user(self) -> UserModel:
try:
user = User.select().order_by(User.created_at).first()
return UserModel(**model_to_dict(user))
with get_db() as db:
user = db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user)
except:
return None
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
try:
query = User.update(role=role).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
with get_db() as db:
db.query(User).filter_by(id=id).update({"role": role})
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except:
return None
......@@ -167,23 +181,28 @@ class UsersTable:
self, id: str, profile_image_url: str
) -> Optional[UserModel]:
try:
query = User.update(profile_image_url=profile_image_url).where(
User.id == id
)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
with get_db() as db:
db.query(User).filter_by(id=id).update(
{"profile_image_url": profile_image_url}
)
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except:
return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try:
query = User.update(last_active_at=int(time.time())).where(User.id == id)
query.execute()
with get_db() as db:
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())}
)
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except:
return None
......@@ -191,22 +210,25 @@ class UsersTable:
self, id: str, oauth_sub: str
) -> Optional[UserModel]:
try:
query = User.update(oauth_sub=oauth_sub).where(User.id == id)
query.execute()
with get_db() as db:
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
db.commit()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
query = User.update(**updated).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
with get_db() as db:
db.query(User).filter_by(id=id).update(updated)
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
# return UserModel(**user.dict())
except Exception as e:
return None
def delete_user_by_id(self, id: str) -> bool:
......@@ -215,9 +237,10 @@ class UsersTable:
result = Chats.delete_chats_by_user_id(id)
if result:
# Delete User
query = User.delete().where(User.id == id)
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
# Delete User
db.query(User).filter_by(id=id).delete()
db.commit()
return True
else:
......@@ -227,19 +250,20 @@ class UsersTable:
def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
try:
query = User.update(api_key=api_key).where(User.id == id)
result = query.execute()
return True if result == 1 else False
with get_db() as db:
result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit()
return True if result == 1 else False
except:
return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try:
user = User.get(User.id == id)
return user.api_key
except:
with get_db() as db:
user = db.query(User).filter_by(id=id).first()
return user.api_key
except Exception as e:
return None
Users = UsersTable(DB)
Users = UsersTable()
......@@ -76,7 +76,10 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user
@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id(
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
user_id: str,
user=Depends(get_admin_user),
skip: int = 0,
limit: int = 50,
):
return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, skip=skip, limit=limit
......@@ -119,7 +122,7 @@ async def get_user_chats(user=Depends(get_verified_user)):
@router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_verified_user)):
async def get_user_archived_chats(user=Depends(get_verified_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id)
......@@ -207,7 +210,6 @@ async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_verified_user)
):
print(form_data)
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
......
......@@ -130,7 +130,9 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_
@router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
name: str,
form_data: DocumentUpdateForm,
user=Depends(get_admin_user),
):
doc = Documents.update_doc_by_name(name, form_data)
if doc:
......
......@@ -50,10 +50,7 @@ router = APIRouter()
@router.post("/")
def upload_file(
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
log.info(f"file.content_type: {file.content_type}")
try:
unsanitized_filename = file.filename
......
......@@ -233,7 +233,10 @@ async def delete_function_by_id(
# delete the function file
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
os.remove(function_path)
try:
os.remove(function_path)
except:
pass
return result
......
......@@ -50,7 +50,9 @@ class MemoryUpdateModel(BaseModel):
@router.post("/add", response_model=Optional[MemoryModel])
async def add_memory(
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
request: Request,
form_data: AddMemoryForm,
user=Depends(get_verified_user),
):
memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
......
......@@ -5,6 +5,7 @@ from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user
......@@ -29,7 +30,9 @@ async def get_models(user=Depends(get_verified_user)):
@router.post("/add", response_model=Optional[ModelModel])
async def add_new_model(
request: Request, form_data: ModelForm, user=Depends(get_admin_user)
request: Request,
form_data: ModelForm,
user=Depends(get_admin_user),
):
if form_data.id in request.app.state.MODELS:
raise HTTPException(
......@@ -73,7 +76,10 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/update", response_model=Optional[ModelModel])
async def update_model_by_id(
request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
request: Request,
id: str,
form_data: ModelForm,
user=Depends(get_admin_user),
):
model = Models.get_model_by_id(id)
if model:
......
......@@ -71,7 +71,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
@router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_admin_user)
command: str,
form_data: PromptForm,
user=Depends(get_admin_user),
):
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt:
......
......@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
......@@ -57,7 +56,9 @@ async def get_toolkits(user=Depends(get_admin_user)):
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user)
request: Request,
form_data: ToolForm,
user=Depends(get_admin_user),
):
if not form_data.id.isidentifier():
raise HTTPException(
......@@ -131,7 +132,10 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
request: Request,
id: str,
form_data: ToolForm,
user=Depends(get_admin_user),
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
......
......@@ -138,7 +138,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/info/update", response_model=Optional[dict])
async def update_user_settings_by_session_user(
async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user)
):
user = Users.get_user_by_id(user.id)
......@@ -205,7 +205,9 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
@router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
user_id: str,
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
):
user = Users.get_user_by_id(user_id)
......
from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status
from peewee import SqliteDatabase
from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel
......@@ -10,7 +9,6 @@ import markdown
import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url
......@@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if not isinstance(DB, SqliteDatabase):
from apps.webui.internal.db import engine
if engine.name != "sqlite":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
)
return FileResponse(
DB.database,
engine.url.database,
media_type="application/octet-stream",
filename="webui.db",
)
......
......@@ -5,9 +5,8 @@ import importlib.metadata
import pkgutil
import chromadb
from chromadb import Settings
from base64 import b64encode
from bs4 import BeautifulSoup
from typing import TypeVar, Generic, Union
from typing import TypeVar, Generic
from pydantic import BaseModel
from typing import Optional
......@@ -19,7 +18,6 @@ import markdown
import requests
import shutil
from secrets import token_bytes
from constants import ERROR_MESSAGES
####################################
......@@ -395,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig(
os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
)
OAUTH_USERNAME_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.username_claim",
os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)
OAUTH_PICTURE_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.avatar_claim",
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
def load_oauth_providers():
OAUTH_PROVIDERS.clear()
......@@ -440,16 +450,27 @@ load_oauth_providers()
STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()
frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png"
frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"
if frontend_favicon.exists():
try:
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend favicon not found at {frontend_favicon}")
frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"
if frontend_splash.exists():
try:
shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend splash not found at {frontend_splash}")
####################################
# CUSTOM_NAME
####################################
......@@ -474,6 +495,19 @@ if CUSTOM_NAME:
r.raw.decode_content = True
shutil.copyfileobj(r.raw, f)
if "splash" in data:
url = (
f"https://api.openwebui.com{data['splash']}"
if data["splash"][0] == "/"
else data["splash"]
)
r = requests.get(url, stream=True)
if r.status_code == 200:
with open(f"{STATIC_DIR}/splash.png", "wb") as f:
r.raw.decode_content = True
shutil.copyfileobj(r.raw, f)
WEBUI_NAME = data["name"]
except Exception as e:
log.exception(e)
......@@ -769,11 +803,14 @@ class BannerModel(BaseModel):
timestamp: int
WEBUI_BANNERS = PersistentConfig(
"WEBUI_BANNERS",
"ui.banners",
[BannerModel(**banner) for banner in json.loads("[]")],
)
try:
banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
banners = [BannerModel(**banner) for banner in banners]
except Exception as e:
print(f"Error loading WEBUI_BANNERS: {e}")
banners = []
WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
SHOW_ADMIN_DETAILS = PersistentConfig(
......@@ -885,6 +922,22 @@ WEBUI_SESSION_COOKIE_SECURE = os.environ.get(
if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
####################################
# RAG document content extraction
####################################
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
"CONTENT_EXTRACTION_ENGINE",
"rag.CONTENT_EXTRACTION_ENGINE",
os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)
TIKA_SERVER_URL = PersistentConfig(
"TIKA_SERVER_URL",
"rag.tika_server_url",
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
)
####################################
# RAG
####################################
......@@ -1302,3 +1355,7 @@ AUDIO_TTS_VOICE = PersistentConfig(
####################################
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
# Replace the postgres:// with postgresql://
if "postgres://" in DATABASE_URL:
DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://")
......@@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum):
OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature."
)
class TASKS(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda task="": f"{task if task else 'default'}"
TITLE_GENERATION = "Title Generation"
EMOJI_GENERATION = "Emoji Generation"
QUERY_GENERATION = "Query Generation"
FUNCTION_CALLING = "Function Calling"
......@@ -4,9 +4,7 @@ from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from bs4 import BeautifulSoup
import json
import markdown
import time
import os
import sys
......@@ -18,25 +16,22 @@ import shutil
import os
import uuid
import inspect
import asyncio
from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import StreamingResponse, Response, RedirectResponse
from apps.socket.main import app as socket_app
from apps.socket.main import sio, app as socket_app
from apps.ollama.main import (
app as ollama_app,
OpenAIChatCompletionForm,
get_all_models as get_ollama_models,
generate_openai_chat_completion as generate_ollama_chat_completion,
)
......@@ -54,13 +49,14 @@ from apps.webui.main import (
get_pipe_models,
generate_function_chat_completion,
)
from apps.webui.internal.db import Session
from pydantic import BaseModel
from typing import List, Optional, Iterator, Generator, Union
from typing import List, Optional
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.models import Models
from apps.webui.models.tools import Tools
from apps.webui.models.functions import Functions
from apps.webui.models.users import Users
......@@ -83,14 +79,12 @@ from utils.task import (
from utils.misc import (
get_last_user_message,
add_or_update_system_message,
stream_message_template,
parse_duration,
)
from apps.rag.utils import get_rag_context, rag_template
from config import (
CONFIG_DATA,
WEBUI_NAME,
WEBUI_URL,
WEBUI_AUTH,
......@@ -98,7 +92,6 @@ from config import (
VERSION,
CHANGELOG,
FRONTEND_BUILD_DIR,
UPLOAD_DIR,
CACHE_DIR,
STATIC_DIR,
DEFAULT_LOCALE,
......@@ -126,7 +119,8 @@ from config import (
WEBUI_SESSION_COOKIE_SECURE,
AppConfig,
)
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
from utils.webhook import post_webhook
if SAFE_MODE:
......@@ -167,8 +161,20 @@ https://github.com/open-webui/open-webui
)
def run_migrations():
try:
from alembic.config import Config
from alembic import command
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")
except Exception as e:
print(f"Error: {e}")
@asynccontextmanager
async def lifespan(app: FastAPI):
run_migrations()
yield
......@@ -212,8 +218,79 @@ origins = ["*"]
##################################
async def get_body_and_model_and_user(request):
# Read the original request body
body = await request.body()
body_str = body.decode("utf-8")
body = json.loads(body_str) if body_str else {}
model_id = body["model"]
if model_id not in app.state.MODELS:
raise Exception("Model not found")
model = app.state.MODELS[model_id]
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
return body, model, user
def get_task_model_id(default_model_id):
# Set the task model
task_model_id = default_model_id
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
return task_model_id
def get_filter_function_ids(model):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
filter_ids.sort(key=get_priority)
return filter_ids
async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user
messages,
files,
tool_id,
template,
task_model_id,
user,
__event_emitter__=None,
__event_call__=None,
):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
......@@ -240,6 +317,7 @@ async def get_function_call_response(
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
"task": TASKS.FUNCTION_CALLING,
}
try:
......@@ -252,7 +330,6 @@ async def get_function_call_response(
response = None
try:
response = await generate_chat_completions(form_data=payload, user=user)
content = None
if hasattr(response, "body_iterator"):
......@@ -266,334 +343,367 @@ async def get_function_call_response(
else:
content = response["choices"][0]["message"]["content"]
if content is None:
return None, None, False
# Parse the function response
if content is not None:
print(f"content: {content}")
result = json.loads(content)
print(result)
citation = None
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module, frontmatter = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(
toolkit_module, "Valves"
):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(
**(valves if valves else {})
)
print(f"content: {content}")
result = json.loads(content)
print(result)
function = getattr(toolkit_module, result["name"])
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
params = result["parameters"]
citation = None
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(toolkit_module, "UserValves"):
__user__["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(
tool_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__messages__" in sig.parameters:
# Call the function with the '__messages__' parameter included
params = {
**params,
"__messages__": messages,
}
if "__files__" in sig.parameters:
# Call the function with the '__files__' parameter included
params = {
**params,
"__files__": files,
}
if "__model__" in sig.parameters:
# Call the function with the '__model__' parameter included
params = {
**params,
"__model__": model,
}
if "__id__" in sig.parameters:
# Call the function with the '__id__' parameter included
params = {
**params,
"__id__": tool_id,
}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
function_result = function(**params)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
if "name" not in result:
return None, None, False
# Call the function
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module, _ = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
if hasattr(toolkit_module, "valves") and hasattr(toolkit_module, "Valves"):
valves = Tools.get_tool_valves_by_id(tool_id)
toolkit_module.valves = toolkit_module.Valves(**(valves if valves else {}))
function = getattr(toolkit_module, result["name"])
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
params = result["parameters"]
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": tool_id,
"__messages__": messages,
"__files__": files,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(toolkit_module, "UserValves"):
__user__["valves"] = toolkit_module.UserValves(
**Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
)
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result is not None:
return function_result, citation, file_handler
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else:
function_result = function(**params)
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result is not None:
return function_result, citation, file_handler
except Exception as e:
print(f"Error: {e}")
return None, None, False
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
data_items = []
async def chat_completion_functions_handler(
body, model, user, __event_emitter__, __event_call__
):
skip_files = None
filter_ids = get_filter_function_ids(model)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if not filter:
continue
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if not hasattr(function_module, "inlet"):
continue
try:
inlet = function_module.inlet
# Get the signature of the function
sig = inspect.signature(inlet)
params = {"body": body}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
body = inlet(**params)
except Exception as e:
print(f"Error: {e}")
raise e
if skip_files:
if "files" in body:
del body["files"]
return body, {}
async def chat_completion_tools_handler(body, user, __event_emitter__, __event_call__):
skip_files = None
contexts = []
citations = None
task_model_id = get_task_model_id(body["model"])
# If tool_ids field is present, call the functions
if "tool_ids" in body:
print(body["tool_ids"])
for tool_id in body["tool_ids"]:
print(tool_id)
try:
response, citation, file_handler = await get_function_call_response(
messages=body["messages"],
files=body.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
__event_emitter__=__event_emitter__,
__event_call__=__event_call__,
)
show_citations = False
citations = []
print(file_handler)
if isinstance(response, str):
contexts.append(response)
if citation:
if citations is None:
citations = [citation]
else:
citations.append(citation)
if file_handler:
skip_files = True
except Exception as e:
print(f"Error: {e}")
del body["tool_ids"]
print(f"tool_contexts: {contexts}")
if skip_files:
if "files" in body:
del body["files"]
return body, {
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
async def chat_completion_files_handler(body):
contexts = []
citations = None
if "files" in body:
files = body["files"]
del body["files"]
contexts, citations = get_rag_context(
files=files,
messages=body["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
return body, {
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
):
log.debug(f"request.url.path: {request.url.path}")
# Read the original request body
body = await request.body()
body_str = body.decode("utf-8")
data = json.loads(body_str) if body_str else {}
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False
if data.get("citations"):
show_citations = True
del data["citations"]
model_id = data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
try:
body, model, user = await get_body_and_model_and_user(request)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
model = app.state.MODELS[model_id]
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get(
"priority", 0
)
return 0
filter_ids = [
function.id for function in Functions.get_global_filter_functions()
]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type(
"filter", active_only=True
# Extract session_id, chat_id and message_id from the request body
session_id = None
if "session_id" in body:
session_id = body["session_id"]
del body["session_id"]
chat_id = None
if "chat_id" in body:
chat_id = body["chat_id"]
del body["chat_id"]
message_id = None
if "id" in body:
message_id = body["id"]
del body["id"]
async def __event_emitter__(data):
await sio.emit(
"chat-events",
{
"chat_id": chat_id,
"message_id": message_id,
"data": data,
},
to=session_id,
)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
async def __event_call__(data):
response = await sio.call(
"chat-events",
{"chat_id": chat_id, "message_id": message_id, "data": data},
to=session_id,
)
return response
try:
if hasattr(function_module, "inlet"):
inlet = function_module.inlet
# Get the signature of the function
sig = inspect.signature(inlet)
params = {"body": data}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if inspect.iscoroutinefunction(inlet):
data = await inlet(**params)
else:
data = inlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Initialize data_items to store additional data to be sent to the client
data_items = []
# Set the task model
task_model_id = data["model"]
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
prompt = get_last_user_message(data["messages"])
context = ""
# If tool_ids field is present, call the functions
if "tool_ids" in data:
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
print(tool_id)
try:
response, citation, file_handler = (
await get_function_call_response(
messages=data["messages"],
files=data.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
)
# Initialize context, and citations
contexts = []
citations = []
print(file_handler)
if isinstance(response, str):
context += ("\n" if context != "" else "") + response
if citation:
citations.append(citation)
show_citations = True
if file_handler:
skip_files = True
except Exception as e:
print(f"Error: {e}")
del data["tool_ids"]
print(f"tool_context: {context}")
# If files field is present, generate RAG completions
# If skip_files is True, skip the RAG completions
if "files" in data:
if not skip_files:
data = {**data}
rag_context, rag_citations = get_rag_context(
files=data["files"],
messages=data["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
try:
body, flags = await chat_completion_functions_handler(
body, model, user, __event_emitter__, __event_call__
)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
log.debug(f"rag_context: {rag_context}, citations: {citations}")
try:
body, flags = await chat_completion_tools_handler(
body, user, __event_emitter__, __event_call__
)
if rag_citations:
citations.extend(rag_citations)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
del data["files"]
try:
body, flags = await chat_completion_files_handler(body)
if show_citations and len(citations) > 0:
data_items.append({"citations": citations})
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
system_prompt, data["messages"]
# If context is not empty, insert it into the messages
if len(contexts) > 0:
context_string = "/n".join(contexts).strip()
prompt = get_last_user_message(body["messages"])
body["messages"] = add_or_update_system_message(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# If there are citations, add them to the data_items
if len(citations) > 0:
data_items.append({"citations": citations})
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
......@@ -654,9 +764,7 @@ app.add_middleware(ChatCompletionMiddleware)
##################################
def filter_pipeline(payload, user):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
def get_sorted_filters(model_id):
filters = [
model
for model in app.state.MODELS.values()
......@@ -672,6 +780,13 @@ def filter_pipeline(payload, user):
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
return sorted_filters
def filter_pipeline(payload, user):
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
model_id = payload["model"]
sorted_filters = get_sorted_filters(model_id)
model = app.state.MODELS[model_id]
......@@ -704,25 +819,12 @@ def filter_pipeline(payload, user):
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
except:
pass
res = r.json()
if "detail" in res:
raise Exception(r.status_code, res["detail"])
else:
pass
if "pipeline" not in app.state.MODELS[model_id]:
if "chat_id" in payload:
del payload["chat_id"]
if "title" in payload:
del payload["title"]
if "task" in payload:
del payload["task"]
if "pipeline" not in app.state.MODELS[model_id] and "task" in payload:
del payload["task"]
return payload
......@@ -787,6 +889,14 @@ app.add_middleware(
)
@app.middleware("http")
async def commit_session_after_request(request: Request, call_next):
response = await call_next(request)
log.debug("Commit session after request")
Session.commit()
return response
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
......@@ -863,12 +973,16 @@ async def get_all_models():
model["info"] = custom_model.model_dump()
else:
owned_by = "openai"
pipe = None
for model in models:
if (
custom_model.base_model_id == model["id"]
or custom_model.base_model_id == model["id"].split(":")[0]
):
owned_by = model["owned_by"]
if "pipe" in model:
pipe = model["pipe"]
break
models.append(
......@@ -880,11 +994,11 @@ async def get_all_models():
"owned_by": owned_by,
"info": custom_model.model_dump(),
"preset": True,
**({"pipe": pipe} if pipe is not None else {}),
}
)
app.state.MODELS = {model["id"]: model for model in models}
webui_app.state.MODELS = app.state.MODELS
return models
......@@ -945,22 +1059,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
)
model = app.state.MODELS[model_id]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
sorted_filters = get_sorted_filters(model_id)
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
......@@ -1008,6 +1107,25 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
else:
pass
async def __event_emitter__(event_data):
await sio.emit(
"chat-events",
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"data": event_data,
},
to=data["session_id"],
)
async def __event_call__(event_data):
response = await sio.call(
"chat-events",
{"chat_id": data["chat_id"], "message_id": data["id"], "data": event_data},
to=data["session_id"],
)
return response
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
......@@ -1032,68 +1150,74 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if not filter:
continue
try:
if hasattr(function_module, "outlet"):
outlet = function_module.outlet
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Get the signature of the function
sig = inspect.signature(outlet)
params = {"body": data}
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
if not hasattr(function_module, "outlet"):
continue
try:
outlet = function_module.outlet
# Get the signature of the function
sig = inspect.signature(outlet)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data
......@@ -1169,19 +1293,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
model_id = get_task_model_id(model_id)
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
......@@ -1200,7 +1314,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
"stream": False,
"max_tokens": 50,
"chat_id": form_data.get("chat_id", None),
"title": True,
"task": TASKS.TITLE_GENERATION,
}
log.debug(payload)
......@@ -1213,6 +1327,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
content={"detail": e.args[1]},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
......@@ -1235,19 +1352,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
model_id = get_task_model_id(model_id)
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
......@@ -1260,7 +1367,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 30,
"task": True,
"task": TASKS.QUERY_GENERATION,
}
print(payload)
......@@ -1273,6 +1380,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
content={"detail": e.args[1]},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
......@@ -1289,19 +1399,9 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
model_id = get_task_model_id(model_id)
print(model_id)
model = app.state.MODELS[model_id]
template = '''
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
......@@ -1324,7 +1424,7 @@ Message: """{{prompt}}"""
"stream": False,
"max_tokens": 4,
"chat_id": form_data.get("chat_id", None),
"task": True,
"task": TASKS.EMOJI_GENERATION,
}
log.debug(payload)
......@@ -1337,6 +1437,9 @@ Message: """{{prompt}}"""
content={"detail": e.args[1]},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
......@@ -1353,22 +1456,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
model_id = get_task_model_id(model_id)
print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try:
context, citation, file_handler = await get_function_call_response(
context, _, _ = await get_function_call_response(
form_data["messages"],
form_data.get("files", []),
form_data["tool_id"],
......@@ -1432,6 +1526,7 @@ async def upload_pipeline(
os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename)
r = None
try:
# Save the uploaded file
with open(file_path, "wb") as buffer:
......@@ -1455,7 +1550,9 @@ async def upload_pipeline(
print(f"Connection error: {e}")
detail = "Pipeline not found"
status_code = status.HTTP_404_NOT_FOUND
if r is not None:
status_code = r.status_code
try:
res = r.json()
if "detail" in res:
......@@ -1464,7 +1561,7 @@ async def upload_pipeline(
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
status_code=status_code,
detail=detail,
)
finally:
......@@ -1563,8 +1660,6 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_user)):
r = None
try:
urlIdx
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
......@@ -1596,7 +1691,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
@app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
):
models = await get_all_models()
r = None
......@@ -1634,7 +1731,9 @@ async def get_pipeline_valves(
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
):
models = await get_all_models()
......@@ -1920,7 +2019,8 @@ async def oauth_callback(provider: str, request: Request, response: Response):
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_url = user_data.get("picture", "")
picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url:
# Download the profile image into a base64 string
try:
......@@ -1940,6 +2040,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
picture_url = ""
if not picture_url:
picture_url = "/user.png"
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
role = (
"admin"
if Users.get_num_users() == 0
......@@ -1950,7 +2051,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get("name", "User"),
name=user_data.get(username_claim, "User"),
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,
......@@ -2008,7 +2109,7 @@ async def get_opensearch_xml():
<ShortName>{WEBUI_NAME}</ShortName>
<Description>Search {WEBUI_NAME}</Description>
<InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/favicon.png</Image>
<Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/static/favicon.png</Image>
<Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
<moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
</OpenSearchDescription>
......@@ -2021,6 +2122,12 @@ async def healthcheck():
return {"status": True}
@app.get("/health/db")
async def healthcheck_with_db():
Session.execute(text("SELECT 1;")).all()
return {"status": True}
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
......
Generic single-database configuration.
Create new migrations with
DATABASE_URL=<replace with actual url> alembic revision --autogenerate -m "a description"
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from apps.webui.models.auths import Auth
from apps.webui.models.chats import Chat
from apps.webui.models.documents import Document
from apps.webui.models.memories import Memory
from apps.webui.models.models import Model
from apps.webui.models.prompts import Prompt
from apps.webui.models.tags import Tag, ChatIdTag
from apps.webui.models.tools import Tool
from apps.webui.models.users import User
from apps.webui.models.files import File
from apps.webui.models.functions import Function
from config import DATABASE_URL
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Auth.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
DB_URL = DATABASE_URL
if DB_URL:
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%"))
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment